Scalable Spatiotemporal Inference with Biased Scan Attention Transformer Neural Processes
About
Neural Processes (NPs) are a rapidly evolving class of models designed to directly model the posterior predictive distribution of stochastic processes. While early architectures were developed primarily as a scalable alternative to Gaussian Processes (GPs), modern NPs tackle far more complex and data-hungry applications spanning geology, epidemiology, climate, and robotics. These applications place increasing demands on scalability, and many architectures trade accuracy to meet them. In this paper, we demonstrate that this trade-off is often unnecessary, particularly when modeling fully or partially translation-invariant processes. We propose a versatile new architecture, the Biased Scan Attention Transformer Neural Process (BSA-TNP), which introduces Kernel Regression Blocks (KRBlocks), group-invariant attention biases, and memory-efficient Biased Scan Attention (BSA). BSA-TNP is able to: (1) match or exceed the accuracy of the best models while often training in a fraction of the time, (2) exhibit translation invariance, enabling learning at multiple resolutions simultaneously, (3) transparently model processes that evolve in both space and time, (4) support high-dimensional fixed effects, and (5) scale gracefully, running inference on over 1M test points and 100K context points in under a minute on a single 24GB GPU. Code is provided as part of the `dl4bi` package.
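The abstract names translation-invariant attention biases and memory-efficient Biased Scan Attention without showing how they fit together. The NumPy sketch below illustrates the general idea under stated assumptions; it is not the `dl4bi` implementation. The bias is a hypothetical choice (negative Euclidean distance between query and key locations, weighted by a hypothetical `scale` parameter). Because it depends only on location differences, shifting every location by the same vector leaves the output unchanged. The scan version visits context points in chunks with a running (online) softmax, so it never materializes the full test-by-context logit matrix, yet it matches dense attention up to floating-point error.

```python
import numpy as np


def distance_bias(q_loc, k_loc, scale=1.0):
    # Hypothetical bias: negative scaled Euclidean distance between locations.
    # It depends only on q_loc[i] - k_loc[j], hence it is translation invariant.
    diff = q_loc[:, None, :] - k_loc[None, :, :]
    return -scale * np.sqrt((diff ** 2).sum(axis=-1))


def dense_biased_attention(q, k, v, q_loc, k_loc):
    # Reference version: builds the full (num_queries, num_keys) logit matrix.
    logits = q @ k.T / np.sqrt(q.shape[-1]) + distance_bias(q_loc, k_loc)
    logits -= logits.max(axis=1, keepdims=True)
    w = np.exp(logits)
    return (w / w.sum(axis=1, keepdims=True)) @ v


def scan_biased_attention(q, k, v, q_loc, k_loc, chunk=64):
    # Visits keys in chunks with an online softmax, so only a
    # (num_queries, chunk) block of logits and biases exists at any time.
    n, d = q.shape
    run_max = np.full((n, 1), -np.inf)  # running max logit per query
    denom = np.zeros((n, 1))            # running softmax denominator
    acc = np.zeros((n, v.shape[-1]))    # running weighted sum of values
    for start in range(0, k.shape[0], chunk):
        ks = k[start:start + chunk]
        vs = v[start:start + chunk]
        ls = k_loc[start:start + chunk]
        # Bias is recomputed per chunk instead of being stored in full.
        logits = q @ ks.T / np.sqrt(d) + distance_bias(q_loc, ls)
        new_max = np.maximum(run_max, logits.max(axis=1, keepdims=True))
        correction = np.exp(run_max - new_max)  # rescale earlier partial sums
        p = np.exp(logits - new_max)
        denom = denom * correction + p.sum(axis=1, keepdims=True)
        acc = acc * correction + p @ vs
        run_max = new_max
    return acc / denom


rng = np.random.default_rng(0)
q = rng.normal(size=(5, 8))
k = rng.normal(size=(200, 8))
v = rng.normal(size=(200, 3))
q_loc = rng.uniform(size=(5, 2))
k_loc = rng.uniform(size=(200, 2))
out = scan_biased_attention(q, k, v, q_loc, k_loc, chunk=32)
```

In a compiled framework the Python loop would typically become a scan primitive such as `jax.lax.scan`; the recurrence over `run_max`, `denom`, and `acc` is the part that keeps peak memory proportional to the chunk size rather than to the number of context points.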
Related benchmarks
| Task | Dataset | NLL (lower is better) | Rank |
|---|---|---|---|
| Spatiotemporal forecasting | Beijing Multi-Site Air Quality UCI repository (test) | -1.82 | 5 |
| Spatiotemporal forecasting | Gneiting GP (test) | 1.38 | 5 |
| Epidemiological Spatiotemporal Inference | SIR Shifted domain (test) | 0.19 | 4 |
| Epidemiological Spatiotemporal Inference | SIR Scaled domain (test) | 0.18 | 4 |
| 2D Gaussian Process Regression | 2D GP (Shifted) | 0.32 | 4 |
| 2D Gaussian Process Regression | 2D GP (Scaled) | 0.28 | 4 |
| Epidemiological Spatiotemporal Inference | SIR Original domain (test) | 0.19 | 4 |
| 2D Gaussian Process Regression | 2D GP (Original) | 0.32 | 4 |
| Spatiotemporal forecasting | ERA5 western Europe CNW (test) | 0.07 | 3 |
| Climate Forecasting | ERA5 CWN (train-Central, val-Western, test-Northern) | 0.13 | 3 |
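All results above are negative log-likelihoods (NLL), so lower is better, and negative values such as -1.82 are legitimate: a continuous predictive density can exceed 1. The sketch below assumes, as is common for neural processes, that each model outputs an independent Gaussian per test point and that NLL is averaged over points; the exact averaging used by each benchmark is an assumption here, and `gaussian_nll` is a hypothetical helper name.

```python
import numpy as np


def gaussian_nll(y, mean, std):
    # Mean negative log-likelihood of observations y under independent
    # Gaussians N(mean, std**2), one per test point. Lower is better.
    var = np.asarray(std, dtype=float) ** 2
    y = np.asarray(y, dtype=float)
    mean = np.asarray(mean, dtype=float)
    return np.mean(0.5 * np.log(2 * np.pi * var) + (y - mean) ** 2 / (2 * var))


# A confident, accurate prediction (small std, small error) gives negative NLL;
# a vague prediction (std = 1) gives positive NLL even with zero error.
confident = gaussian_nll([0.0, 0.05], [0.0, 0.0], [0.1, 0.1])
vague = gaussian_nll([0.0, 0.05], [0.0, 0.0], [1.0, 1.0])
```

This is why NLL rewards calibrated uncertainty and not just accurate means: shrinking `std` lowers the log-normalizer term but inflates the squared-error term whenever the mean is wrong.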