
Neural Processes with Stochastic Attention: Paying more attention to the context dataset

About

Neural processes (NPs) aim to stochastically complete unseen data points based on a given context dataset. NPs essentially leverage a given dataset as a context representation to derive a suitable identifier for a novel task. To improve prediction accuracy, many variants of NPs have investigated context embedding approaches, which generally design novel network architectures and aggregation functions that satisfy permutation invariance. In this work, we propose a stochastic attention mechanism for NPs to capture appropriate context information. From an information-theoretic perspective, we demonstrate that the proposed method encourages the context embedding to be differentiated from the target dataset, allowing NPs to consider the features of the target dataset and the context embedding independently. We observe that the proposed method can appropriately capture context embeddings even with noisy datasets and restricted task distributions, where typical NPs suffer from a lack of context embeddings. We empirically show that our approach substantially outperforms conventional NPs in various domains: 1D regression, a predator-prey model, and image completion. Moreover, the proposed method is also validated on the MovieLens-10k dataset, a real-world problem.
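The abstract does not spell out the attention mechanism, so the following is only a rough illustration of the general idea, not the paper's method. It contrasts ordinary softmax cross-attention from a target input to a context set with a stochastic variant in which the unnormalized attention weights are sampled from a Weibull distribution whose mean equals the softmax numerator, in the spirit of Bayesian attention modules. The distance-based score, `length_scale`, `shape_k`, and every function name are hypothetical, and the information-theoretic regularization described above is omitted entirely.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

# A context set is a sequence of (x_i, y_i) pairs (scalars here for simplicity).
Context = Sequence[Tuple[float, float]]


def attention_logits(x_target: float, context: Context, length_scale: float = 0.5) -> List[float]:
    """Hypothetical similarity score: negative scaled squared distance to each context input."""
    return [-((x_target - x) ** 2) / length_scale for x, _ in context]


def deterministic_attention(x_target: float, context: Context) -> List[float]:
    """Standard softmax attention weights over the context set (max-subtracted for stability)."""
    logits = attention_logits(x_target, context)
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def stochastic_attention(
    x_target: float,
    context: Context,
    shape_k: float = 10.0,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Sampled attention weights (assumed Weibull parameterization, for illustration only).

    Each unnormalized weight is drawn from Weibull(scale, shape_k) with
    scale = exp(logit - max) / Gamma(1 + 1/shape_k), so its mean is exp(logit - max).
    Larger shape_k means less noise, and the weights approach the softmax weights.
    """
    if rng is None:
        rng = random.Random()
    logits = attention_logits(x_target, context)
    m = max(logits)
    gamma_term = math.gamma(1.0 + 1.0 / shape_k)
    raw = [rng.weibullvariate(math.exp(l - m) / gamma_term, shape_k) for l in logits]
    total = sum(raw)
    if total == 0.0:  # practically impossible; fall back to the deterministic weights
        return deterministic_attention(x_target, context)
    return [r / total for r in raw]


def aggregate(weights: Sequence[float], context: Context) -> float:
    """Context representation for one target: attention-weighted sum of the context outputs."""
    return sum(w * y for w, (_, y) in zip(weights, context))


if __name__ == "__main__":
    context = [(-1.0, 0.2), (0.0, 1.0), (1.5, -0.5)]
    rng = random.Random(0)
    print("deterministic r:", aggregate(deterministic_attention(0.1, context), context))
    for _ in range(3):
        print("sampled r:", aggregate(stochastic_attention(0.1, context, rng=rng), context))
```

Because the deterministic representation is a weighted sum over (x_i, y_i) pairs, reordering the context set leaves it unchanged, which is the permutation invariance the abstract mentions. The stochastic version is invariant in distribution rather than sample by sample, since each draw depends on the random numbers assigned to each point.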

Mingyu Kim, Kyeongryeol Go, Se-Young Yun • 2022

Related benchmarks

Task | Dataset | Metric | Result | Rank
--- | --- | --- | --- | ---
1D Regression | Synthetic 1D Regression RBF kernel with noises | Context Likelihood | 1.374 | 16
1D Regression | Synthetic 1D Regression RBF kernel | Context Likelihood | 1.363 | 16
1D Regression | Synthetic 1D Regression Matern kernel GP | Context Likelihood | 1.365 | 16
1D Regression | Synthetic 1D Regression Periodic kernel GP | Context Likelihood | 1.372 | 16
Sim2Real Regression | Predator-Prey Simulation | Context Likelihood | 271.1 | 16
Sim2Real Regression | Predator-Prey Real | Context Likelihood | 2.429 | 16
Rating Prediction | MovieLens 100k U1 (test) | RMSE | 0.895 | 15
Image Completion | CelebA | Context Likelihood (Avg) | 4.119 | 14
Likelihood Estimation | MovieLens 10k (test) | Context Likelihood | -0.349 | 14
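The listing does not define its metrics. Assuming the convention common in the neural-process literature, "Context Likelihood" is the mean per-point Gaussian log-likelihood of the context points under the model's predictive distribution (higher is better), and RMSE is the usual root-mean-square error (lower is better). The sketch below computes both from hypothetical predicted means and standard deviations; the function names are illustrative and do not come from any benchmark's code.

```python
import math
from typing import Sequence


def gaussian_log_likelihood(y: float, mean: float, std: float) -> float:
    """Log-density of y under a Gaussian N(mean, std**2)."""
    if std <= 0:
        raise ValueError("std must be positive")
    z = (y - mean) / std
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * z * z


def context_log_likelihood(
    ys: Sequence[float], means: Sequence[float], stds: Sequence[float]
) -> float:
    """Average per-point log-likelihood of the observed outputs (assumed metric definition)."""
    if not ys or not (len(ys) == len(means) == len(stds)):
        raise ValueError("inputs must be non-empty and of equal length")
    total = sum(gaussian_log_likelihood(y, m, s) for y, m, s in zip(ys, means, stds))
    return total / len(ys)


def rmse(ys: Sequence[float], preds: Sequence[float]) -> float:
    """Root-mean-square error between observed values and point predictions."""
    if not ys or len(ys) != len(preds):
        raise ValueError("inputs must be non-empty and of equal length")
    return math.sqrt(sum((y - p) ** 2 for y, p in zip(ys, preds)) / len(ys))


if __name__ == "__main__":
    print(gaussian_log_likelihood(0.0, 0.0, 1.0))  # ≈ -0.919
    print(gaussian_log_likelihood(0.0, 0.0, 0.1))  # ≈ 1.384
    print(rmse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```

Positive log-likelihoods, as in most rows of the table, are possible because a Gaussian density exceeds 1 when the predicted standard deviation is small: an exact mean with a standard deviation of 0.1 already gives a log-density of about 1.384.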

