Translation Equivariant Transformer Neural Processes
About
The effectiveness of neural processes (NPs) in modelling posterior prediction maps (the mapping from data to posterior predictive distributions) has significantly improved since their inception. This improvement can be attributed to two principal factors: (1) advancements in the architecture of permutation invariant set functions, which are intrinsic to all NPs; and (2) leveraging symmetries present in the true posterior prediction map, which are problem dependent. Transformers are a notable development in permutation invariant set functions, and their utility within NPs has been demonstrated through the family of models we refer to as transformer neural processes (TNPs). Despite significant interest in TNPs, little attention has been given to incorporating symmetries into them. Notably, the posterior prediction maps for data that are stationary (a common assumption in spatio-temporal modelling) exhibit translation equivariance. In this paper, we introduce a new family of translation equivariant TNPs (TE-TNPs). Through an extensive range of experiments on synthetic and real-world spatio-temporal data, we demonstrate the effectiveness of TE-TNPs relative to their non-translation-equivariant counterparts and other NP baselines.
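As a toy illustration of the two properties named above (this is not the TE-TNP architecture), the sketch below uses a hypothetical 1-D cross-attention layer, `te_cross_attention`, whose attention logits depend on the inputs only through the pairwise differences `x_t - x_c`. A fixed negative squared distance stands in for a learned function of those differences, and the context values, target locations, `shift` and `lengthscale` are all made up for illustration. Translating every context and target input by the same amount leaves the predictions unchanged (the prediction function moves with the data), and reordering the context set does not change them either.

```python
import math
from typing import List, Sequence


def te_cross_attention(
    x_context: Sequence[float],
    y_context: Sequence[float],
    x_target: Sequence[float],
    lengthscale: float = 1.0,
) -> List[float]:
    """Toy 1-D cross-attention from target locations to a context set.

    The logit for a (target, context) pair depends on the inputs only through
    the difference x_t - x_c (a fixed negative squared distance, standing in
    for a learned function of that difference), so shifting all inputs by the
    same amount leaves the outputs unchanged. The softmax-weighted sum over
    the context set does not depend on the order of the context points.
    """
    predictions = []
    for xt in x_target:
        logits = [-0.5 * ((xt - xc) / lengthscale) ** 2 for xc in x_context]
        m = max(logits)  # subtract the max logit for numerical stability
        weights = [math.exp(logit - m) for logit in logits]
        total = sum(weights)
        predictions.append(sum(w * yc for w, yc in zip(weights, y_context)) / total)
    return predictions


# Hypothetical context set and target locations.
xc = [0.0, 1.0, 2.5]
yc = [1.0, -0.5, 2.0]
xt = [0.5, 3.0]

base = te_cross_attention(xc, yc, xt)

# Translate all context and target inputs by the same amount.
shift = 10.0
shifted = te_cross_attention([x + shift for x in xc], yc, [x + shift for x in xt])

# Reorder the context points (inputs and outputs together).
order = [2, 0, 1]
permuted = te_cross_attention([xc[i] for i in order], [yc[i] for i in order], xt)

print(base, shifted, permuted)
```

Logits computed from absolute input locations rather than differences would, in general, fail the translation check, which is why it can be useful to build the symmetry into the architecture rather than rely on it being learned from data.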
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Spatiotemporal forecasting | Beijing Multi-Site Air Quality UCI repository (test) | NLL | -1.6 | 5 |
| Spatiotemporal forecasting | Gneiting GP (test) | NLL | 1.46 | 5 |
| 2D Gaussian Process Regression | 2D GP (Shifted) | NLL | 0.4 | 4 |
| Epidemiological Spatiotemporal Inference | SIR Shifted domain (test) | NLL | 0.27 | 4 |
| 2D Gaussian Process Regression | 2D GP Original | NLL | 0.4 | 4 |
| 2D Gaussian Process Regression | 2D GP Scaled | NLL | 1.45 | 4 |
| Epidemiological Spatiotemporal Inference | SIR Original domain (test) | NLL | 0.27 | 4 |
| Epidemiological Spatiotemporal Inference | SIR Scaled domain (test) | NLL | 0.44 | 4 |
| Climate Forecasting | ERA5 CWN (train-Central, val-Western, test-Northern) | NLL | 0.13 | 3 |
| Spatiotemporal forecasting | ERA5 western Europe CNW (test) | NLL | 0.17 | 3 |