
Latent Bottlenecked Attentive Neural Processes

About

Neural Processes (NPs) are popular meta-learning methods that estimate predictive uncertainty on target datapoints by conditioning on a context dataset. The previous state-of-the-art method, Transformer Neural Processes (TNPs), achieves strong performance but requires computation quadratic in the number of context datapoints, which significantly limits its scalability. Conversely, existing sub-quadratic NP variants perform significantly worse than TNPs. To address this issue, we propose Latent Bottlenecked Attentive Neural Processes (LBANPs), a new, computationally efficient, sub-quadratic NP variant whose querying computational complexity is independent of the number of context datapoints. The model encodes the context dataset into a constant number of latent vectors on which self-attention is performed. When making predictions, the model retrieves higher-order information from the context dataset via multiple cross-attention mechanisms on the latent vectors. We empirically show that LBANPs achieve results competitive with the state of the art on meta-regression, image completion, and contextual multi-armed bandits. We demonstrate that LBANPs can trade off computational cost and performance by varying the number of latent vectors. Finally, we show that LBANPs scale beyond existing attention-based NP variants to larger dataset settings.
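
The conditioning and querying steps described above can be illustrated with a minimal NumPy sketch. It assumes single-head scaled dot-product attention and random, untrained weights in place of learned parameters, and it omits the MLPs, layer norms, and training loop of a full model. The names `attention`, `LatentBottleneckNPSketch`, `condition`, and `query` are hypothetical and are not the authors' code. What the sketch shows is the data flow: L latent vectors cross-attend to the N context points and then self-attend to each other, and each target cross-attends only to the stored per-layer latents, so querying costs O(M * L) for M targets regardless of N.

```python
import numpy as np

rng = np.random.default_rng(0)


def attention(queries, keys, values):
    """Single-head scaled dot-product attention; cost is O(len(queries) * len(keys))."""
    scores = queries @ keys.T / np.sqrt(queries.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ values


class LatentBottleneckNPSketch:
    """Hypothetical, untrained sketch of the latent-bottleneck data flow (random weights only)."""

    def __init__(self, x_dim, y_dim, d_model=32, num_latents=8, num_layers=2):
        scale = 1.0 / np.sqrt(d_model)
        self.y_dim = y_dim
        self.num_layers = num_layers
        self.ctx_embed = rng.normal(size=(x_dim + y_dim, d_model)) * scale  # embeds (x, y) pairs
        self.tgt_embed = rng.normal(size=(x_dim, d_model)) * scale  # embeds target x only
        # In a trained model these latents would be learned parameters; here they are random.
        self.latents = rng.normal(size=(num_latents, d_model))
        self.head = rng.normal(size=(d_model, 2 * y_dim)) * scale  # outputs mean and raw std

    def condition(self, x_ctx, y_ctx):
        """Compress N context points into L latents per layer: O(N * L + L^2) per layer."""
        ctx = np.concatenate([x_ctx, y_ctx], axis=-1) @ self.ctx_embed  # (N, d_model)
        latents = self.latents
        per_layer = []
        for _ in range(self.num_layers):
            latents = latents + attention(latents, ctx, ctx)  # latents read the context
            latents = latents + attention(latents, latents, latents)  # self-attention among latents
            per_layer.append(latents)  # each entry has shape (L, d_model), independent of N
        return per_layer

    def query(self, per_layer_latents, x_tgt):
        """Predict M targets by cross-attending only to the latents: O(M * L) per layer."""
        h = x_tgt @ self.tgt_embed  # (M, d_model)
        for latents in per_layer_latents:
            h = h + attention(h, latents, latents)  # one cross-attention per latent layer
        out = h @ self.head
        mean = out[:, : self.y_dim]
        std = np.logaddexp(0.0, out[:, self.y_dim:]) + 1e-3  # softplus keeps std positive
        return mean, std


# Usage: untrained weights, so the predictions are meaningless; only shapes and flow matter.
model = LatentBottleneckNPSketch(x_dim=1, y_dim=1, d_model=32, num_latents=8)
x_ctx = rng.uniform(-2.0, 2.0, size=(500, 1))
y_ctx = np.sin(x_ctx)
per_layer = model.condition(x_ctx, y_ctx)
x_tgt = np.linspace(-2.0, 2.0, 50)[:, None]
mean, std = model.query(per_layer, x_tgt)
print(mean.shape, std.shape)  # (50, 1) (50, 1)
```

Because `condition` runs once per context set and returns a fixed-size result, the same latents can be reused for any number of target batches. In this sketch, `num_latents` is the knob corresponding to the cost/performance trade-off the abstract mentions: fewer latents make both conditioning and querying cheaper at the expense of capacity.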

Leo Feng, Hossein Hajimirsadeghi, Yoshua Bengio, Mohamed Osama Ahmed • 2022

Related benchmarks

Task | Dataset | Metric | Result | Rank
Regression | elevators (test) | -- | -- | 19
Regression | Protein (test) | Test Log Likelihood | -0.0239 | 18
Regression | Skillcraft (test) | Log Likelihood (Test) | -0.031 | 17
Forecasting | HADISD Forecast (test) | Log-Likelihood | 0.2681 | 11
Regression | 1D GP (test) | Log-Likelihood | 0.004 | 11
Interpolation | HADISD Interp (test) | Log-Likelihood | -0.0904 | 11
Regression | Tabular Synthetic (test) | Log-Likelihood | 0.096 | 10
Regression | Powerplant (test) | Log-Likelihood | -0.0244 | 10
Forecasting | 24-hour window OOD (test) | Avg Test Log-Likelihood | -0.0526 | 6
Regression | 1D GP synthetic (test) | Avg Test Log-Likelihood | 0.435 | 5

10 of 16 benchmark rows shown.
