Latent Bottlenecked Attentive Neural Processes
About
Neural Processes (NPs) are popular methods in meta-learning that can estimate predictive uncertainty on target datapoints by conditioning on a context dataset. The previous state-of-the-art method, Transformer Neural Processes (TNPs), achieves strong performance but requires computation quadratic in the number of context datapoints, which significantly limits its scalability. Conversely, existing sub-quadratic NP variants perform significantly worse than TNPs. To address this, we propose Latent Bottlenecked Attentive Neural Processes (LBANPs), a new computationally efficient sub-quadratic NP variant whose querying computational complexity is independent of the number of context datapoints. The model encodes the context dataset into a constant number of latent vectors on which self-attention is performed. When making predictions, the model retrieves higher-order information from the context dataset via multiple cross-attention mechanisms on the latent vectors. We empirically show that LBANPs achieve results competitive with the state-of-the-art on meta-regression, image completion, and contextual multi-armed bandits. We demonstrate that LBANPs can trade off computational cost against performance by varying the number of latent vectors. Finally, we show that LBANPs can scale beyond existing attention-based NP variants to larger dataset settings.
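The latent bottleneck described above can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the function names (`attention`, `encode_context`, `query_targets`), the residual connections, the two-layer depth, and the omission of learned projections, feed-forward layers, and multiple heads are all simplifying assumptions made for illustration. The sketch only shows why the query step costs O(M·L) for M targets and L latents, regardless of the number of context points N.

```python
import numpy as np

def attention(queries, keys, values):
    """Single-head scaled dot-product attention (no learned projections; an assumption)."""
    scores = queries @ keys.T / np.sqrt(queries.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # each row of weights sums to 1
    return weights @ values

def encode_context(context, latents, num_layers=2):
    """Compress N context embeddings into L latent vectors; return the latents of every layer."""
    per_layer = []
    for _ in range(num_layers):
        # Latents read from the context: cost O(N * L), paid once per context set.
        latents = latents + attention(latents, context, context)
        # Self-attention among the latents only: cost O(L^2).
        latents = latents + attention(latents, latents, latents)
        per_layer.append(latents)
    return per_layer

def query_targets(targets, per_layer_latents):
    """Update M target embeddings by cross-attending to the latents only: cost O(M * L) per layer."""
    for latents in per_layer_latents:
        targets = targets + attention(targets, latents, latents)
    return targets

rng = np.random.default_rng(0)
dim, num_latents, num_targets = 8, 4, 5
initial_latents = rng.normal(size=(num_latents, dim))
target_embeddings = rng.normal(size=(num_targets, dim))
context_embeddings = rng.normal(size=(1000, dim))  # N = 1000 context points

cached_latents = encode_context(context_embeddings, initial_latents)
predictions = query_targets(target_embeddings, cached_latents)
print(predictions.shape)
```

In this sketch the encoded latents can be cached after a single pass over the context, so every later query touches only the `num_latents` vectors. Raising `num_latents` increases cost and capacity together, which is the trade-off the abstract refers to.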
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Regression | elevators (test) | -- | -- | 19 |
| Regression | Protein (test) | Test Log Likelihood | -0.0239 | 18 |
| Regression | Skillcraft (test) | Log Likelihood (Test) | -0.031 | 17 |
| Forecasting | HADISD Forecast (test) | Log-Likelihood | 0.2681 | 11 |
| Regression | 1D GP (test) | Log-Likelihood | 0.004 | 11 |
| Interpolation | HADISD Interp (test) | Log-Likelihood | -0.0904 | 11 |
| Regression | Tabular Synthetic (test) | Log-Likelihood | 0.096 | 10 |
| Regression | Powerplant (test) | Log-Likelihood | -0.0244 | 10 |
| Forecasting | 24-hour window OOD (test) | Avg Test Log-Likelihood | -0.0526 | 6 |
| Regression | 1D GP synthetic (test) | Avg Test Log-Likelihood | 0.435 | 5 |