
Practical Conditional Neural Processes Via Tractable Dependent Predictions

About

Conditional Neural Processes (CNPs; Garnelo et al., 2018a) are meta-learning models which leverage the flexibility of deep learning to produce well-calibrated predictions and naturally handle off-the-grid and missing data. CNPs scale to large datasets and train with ease. Due to these features, CNPs appear well-suited to tasks from environmental sciences or healthcare. Unfortunately, CNPs do not produce correlated predictions, making them fundamentally inappropriate for many estimation and decision-making tasks. Predicting heat waves or floods, for example, requires modelling dependencies in temperature or precipitation over time and space. Existing approaches which model output dependencies, such as Neural Processes (NPs; Garnelo et al., 2018b) or the FullConvGNP (Bruinsma et al., 2021), are either complicated to train or prohibitively expensive. What is needed is an approach which provides dependent predictions but is simple to train and computationally tractable. In this work, we present a new class of Neural Process models that make correlated predictions and support exact maximum likelihood training that is simple and scalable. We extend the proposed models with invertible output transformations to capture non-Gaussian output distributions. Our models can be used in downstream estimation tasks which require dependent function samples. By accounting for output dependencies, our models show improved predictive performance on a range of experiments with synthetic and real data.
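To make the contrast in the abstract concrete, here is a minimal NumPy sketch, not the authors' implementation. It compares the mean-field (factorized) Gaussian log-likelihood used by a CNP with the exact joint Gaussian log-likelihood of a model that outputs a full covariance, draws dependent function samples from that joint, and applies one invertible output transformation (an elementwise log, so outputs are modelled as exp of a Gaussian) through the change-of-variables formula. The helper names and the low-rank covariance parametrization Z Z^T / r + noise_var * I are assumptions chosen for illustration; in an actual model the mean and the feature matrix Z would be produced by a neural network from the context set, which is omitted here.

```python
import numpy as np


def diagonal_gaussian_loglik(y, mean, var):
    """Mean-field (CNP-style) log-likelihood: each output is an independent Gaussian."""
    return float(-0.5 * np.sum(np.log(2 * np.pi * var) + (y - mean) ** 2 / var))


def correlated_gaussian_loglik(y, mean, cov):
    """Exact joint Gaussian log-likelihood of y under N(mean, cov)."""
    n = y.shape[0]
    chol = np.linalg.cholesky(cov)                  # cov = L L^T
    alpha = np.linalg.solve(chol, y - mean)         # alpha = L^{-1} (y - mean)
    logdet = 2.0 * np.sum(np.log(np.diag(chol)))    # log|cov| from the Cholesky factor
    return float(-0.5 * (n * np.log(2 * np.pi) + logdet + alpha @ alpha))


def low_rank_covariance(features, noise_var):
    """Hypothetical covariance: Z Z^T / r plus diagonal noise, Z of shape (n, r)."""
    n, r = features.shape
    return features @ features.T / r + noise_var * np.eye(n)


def sample_dependent(mean, cov, num_samples, rng):
    """Draw correlated samples mean + L eps with eps ~ N(0, I)."""
    chol = np.linalg.cholesky(cov)
    eps = rng.standard_normal((num_samples, mean.shape[0]))
    return mean + eps @ chol.T


def log_transformed_loglik(y_pos, mean, cov):
    """Assumed output transform y = exp(z) with z Gaussian.

    Change of variables: log p(y) = log p_Z(log y) - sum(log y).
    """
    z = np.log(y_pos)
    return correlated_gaussian_loglik(z, mean, cov) - float(np.sum(z))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, r = 5, 3
    mean = np.zeros(n)
    cov = low_rank_covariance(rng.standard_normal((n, r)), noise_var=0.1)
    y = sample_dependent(mean, cov, 1, rng)[0]
    print("joint log-likelihood:     ", correlated_gaussian_loglik(y, mean, cov))
    print("mean-field log-likelihood:", diagonal_gaussian_loglik(y, mean, np.diag(cov)))
```

Because the joint log-likelihood in this sketch is computed exactly from a Cholesky factor, it can be maximized directly with gradient-based training rather than through a variational bound; the price is the O(n^3) factorization in the number of target points. When the covariance is diagonal, the joint and mean-field log-likelihoods coincide, which is a quick sanity check.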

Stratis Markou, James Requeima, Wessel P. Bruinsma, Anna Vaughan, Richard E. Turner • 2022

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Regression | Synthetic weakly-periodic Interpolation (INT) | Normalized KL divergence | 0.01 | 43 |
| Regression | Synthetic (weakly-periodic) Out-of-input-distribution (OOID) | Normalized KL divergence | 0.01 | 39 |
| Regression | Synthetic Matérn-5/2 Interpolation (INT) | Normalized KL divergence | 0.00e+0 | 32 |
| Synthetic Regression | Synthetic GP Matérn-5/2 kernel Out-of-input-distribution | Normalized KL divergence | 0.00e+0 | 31 |
| Regression | Synthetic Sawtooth (Interpolation) | Normalized log-likelihood | 4.11 | 29 |
| Regression | Synthetic Sawtooth (Out-of-input-distribution) | Normalized log-likelihood | 3.99 | 24 |
| Regression | Synthetic Sawtooth INT | Log-likelihood | 3.94 | 14 |
| Regression | Synthetic Mixture INT | Log-likelihood | 0.49 | 14 |
| Synthetic Regression | Synthetic mixture dx=2 | Normalized log-likelihood | 0.87 | 13 |
| Interpolation | Lotka-Volterra Simulated | Normalized log-likelihood | -3.46 | 12 |
Showing 10 of 33 rows
