
Practical Equivariances via Relational Conditional Neural Processes

About

Conditional Neural Processes (CNPs) are a class of meta-learning models popular for combining the runtime efficiency of amortized inference with reliable uncertainty quantification. Many relevant machine learning tasks, such as spatio-temporal modeling, Bayesian optimization, and continuous control, inherently contain equivariances (for example, to translation) that the model can exploit for maximal performance. However, prior attempts to include equivariances in CNPs do not scale effectively beyond two input dimensions. In this work, we propose Relational Conditional Neural Processes (RCNPs), an effective approach to incorporating equivariances into any neural process model. Our proposed method extends the applicability and impact of equivariant neural processes to higher dimensions. We empirically demonstrate the competitive performance of RCNPs on a large array of tasks that naturally contain equivariances.
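To make the relational idea concrete, here is a minimal sketch in plain Python (standard library only). It assumes a simplified variant in which a target input x* is compared with each context input x_n through the comparison function g(x*, x_n) = x* - x_n, and the encoded pairs f_r(g(x*, x_n), y_n) are summed into one representation. The names `f_r`, `g_translation`, `relational_encoding` and `absolute_encoding`, the random untrained weights, and the toy data are all hypothetical illustrations, not the authors' implementation. Because inputs enter only through differences, translating every input by the same vector leaves the representation unchanged, so a decoder that reads only this representation gives translation-equivariant predictions. A CNP-style baseline that encodes absolute inputs has no such guarantee.

```python
import math
import random

# Toy, untrained weights for a hypothetical encoder network f_r.
# In a real model these would be learned; here they are random but fixed.
random.seed(0)
D_X = 2      # input dimension (assumed for the example)
HIDDEN = 8   # size of the representation vector (assumed)

WEIGHTS = [[random.gauss(0.0, 1.0) for _ in range(D_X + 1)] for _ in range(HIDDEN)]
BIAS = [random.gauss(0.0, 1.0) for _ in range(HIDDEN)]


def f_r(features, y):
    """One tanh layer applied to the concatenation [features, y]."""
    z = list(features) + [y]
    return [
        math.tanh(sum(w * v for w, v in zip(row, z)) + bias)
        for row, bias in zip(WEIGHTS, BIAS)
    ]


def g_translation(x_target, x_context):
    """Comparison function: the difference x* - x_n (unchanged by a common shift)."""
    return [a - c for a, c in zip(x_target, x_context)]


def relational_encoding(x_target, context):
    """Sum of f_r(g(x*, x_n), y_n) over the context set {(x_n, y_n)}."""
    rep = [0.0] * HIDDEN
    for x_n, y_n in context:
        h = f_r(g_translation(x_target, x_n), y_n)
        rep = [r + h_j for r, h_j in zip(rep, h)]
    return rep


def absolute_encoding(context):
    """CNP-style baseline: sum of f_r(x_n, y_n) using absolute inputs."""
    rep = [0.0] * HIDDEN
    for x_n, y_n in context:
        rep = [r + h_j for r, h_j in zip(rep, f_r(x_n, y_n))]
    return rep


def shift_points(points, t):
    """Translate every point by the same vector t."""
    return [[a + s for a, s in zip(p, t)] for p in points]


# A small made-up context set and target location.
xs = [[0.1, -0.4], [1.2, 0.7], [-0.5, 2.0]]
ys = [0.3, -1.1, 0.8]
x_star = [0.3, 0.2]
t = [5.0, -3.0]  # common translation applied to every input

context = list(zip(xs, ys))
shifted_context = list(zip(shift_points(xs, t), ys))
shifted_x_star = shift_points([x_star], t)[0]

r_orig = relational_encoding(x_star, context)
r_shift = relational_encoding(shifted_x_star, shifted_context)
a_orig = absolute_encoding(context)
a_shift = absolute_encoding(shifted_context)

print("relational max change:", max(abs(u - v) for u, v in zip(r_orig, r_shift)))
print("absolute max change:  ", max(abs(u - v) for u, v in zip(a_orig, a_shift)))
```

The relational representation should change only by floating-point round-off, while the absolute one changes substantially. Swapping `g_translation` for the Euclidean distance between x* and x_n would also make the representation unchanged under rotations and reflections, at the cost of discarding direction information.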

Daolang Huang, Manuel Haussmann, Ulpu Remes, ST John, Grégoire Clarté, Kevin Sebastian Luck, Samuel Kaski, Luigi Acerbi • 2023

Related benchmarks

Task | Dataset | Metric | Result | Rank
Regression | Synthetic weakly-periodic Interpolation (INT) | Normalized KL Divergence | 0.03 | 43
Regression | Synthetic (weakly-periodic) Out-of-input-distribution (OOID) | NKL Divergence | 0.03 | 39
Regression | Synthetic Matérn-5/2 Interpolation (INT) | Normalized KL Divergence | 0.01 | 32
Synthetic Regression | Synthetic GP Matérn-5/2 kernel Out-of-input-distribution | KL Divergence (Normalized) | 0.01 | 31
Regression | Synthetic Sawtooth (Interpolation) | Normalized Log-Likelihood | 3.9 | 29
Regression | Synthetic Sawtooth (Out-of-input-distribution) | Normalized Log-Likelihood | 3.9 | 24
Regression | Synthetic Sawtooth INT | Log-Likelihood | 3.9 | 14
Regression | Synthetic Mixture INT | Log-Likelihood | 0.37 | 14
Synthetic Regression | Synthetic mixture dx=2 | Normalized Log-Likelihood | 0.46 | 13
Completion | Reaction-Diffusion (test) | Normalized Log-Likelihood | 1.38 | 12
Showing 10 of 38 rows
