
Scalable Spatiotemporal Inference with Biased Scan Attention Transformer Neural Processes

About

Neural Processes (NPs) are a rapidly evolving class of models designed to directly model the posterior predictive distribution of stochastic processes. While early architectures were developed primarily as a scalable alternative to Gaussian Processes (GPs), modern NPs tackle far more complex and data-hungry applications spanning geology, epidemiology, climate, and robotics. These applications have placed increasing pressure on the scalability of these models, with many architectures compromising accuracy for scalability. In this paper, we demonstrate that this trade-off is often unnecessary, particularly when modeling fully or partially translation-invariant processes. We propose a versatile new architecture, the Biased Scan Attention Transformer Neural Process (BSA-TNP), which introduces Kernel Regression Blocks (KRBlocks), group-invariant attention biases, and memory-efficient Biased Scan Attention (BSA). BSA-TNP is able to: (1) match or exceed the accuracy of the best models while often training in a fraction of the time, (2) exhibit translation invariance, enabling learning at multiple resolutions simultaneously, (3) transparently model processes that evolve in both space and time, (4) support high-dimensional fixed effects, and (5) scale gracefully, running inference on over 1M test points and 100K context points in under a minute on a single 24GB GPU. Code is provided as part of the `dl4bi` package.

Daniel Jenson, Jhonathan Navott, Piotr Grynfelder, Mengyan Zhang, Makkunda Sharma, Elizaveta Semenova, Seth Flaxman• 2025
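The abstract names group-invariant attention biases and memory-efficient Biased Scan Attention (BSA) but does not spell out how they work. As a rough illustration only, the sketch below shows two generic ideas those terms suggest. The first is an attention bias that depends only on the difference between query and key locations, so it does not change when every location is translated by the same vector. The second is a scan over chunks of keys that uses an online softmax, so the full query-by-key score matrix is never stored. The squared-distance bias, the names `distance_bias`, `naive_biased_attention` and `scan_biased_attention`, and the `chunk_size` and `length_scale` parameters are all hypothetical. This is not the BSA-TNP or `dl4bi` implementation.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def distance_bias(q_loc: Sequence[float], k_loc: Sequence[float], length_scale: float = 1.0) -> float:
    # Hypothetical bias: a function of (q_loc - k_loc) only, so it is
    # unchanged when every location is shifted by the same vector.
    sq_dist = sum((a - b) ** 2 for a, b in zip(q_loc, k_loc))
    return -sq_dist / (2.0 * length_scale ** 2)


def _score(qi, kj, q_loc, k_loc, length_scale):
    # Scaled dot-product score plus the location-dependent bias.
    dot = sum(a * b for a, b in zip(qi, kj))
    return dot / math.sqrt(len(qi)) + distance_bias(q_loc, k_loc, length_scale)


def naive_biased_attention(q, k, v, q_locs, k_locs, length_scale=1.0) -> Matrix:
    # Reference version: materializes the full score row for each query.
    out = []
    for qi, ql in zip(q, q_locs):
        scores = [_score(qi, kj, ql, kl, length_scale) for kj, kl in zip(k, k_locs)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / z for c in range(len(v[0]))])
    return out


def scan_biased_attention(q, k, v, q_locs, k_locs, chunk_size=2, length_scale=1.0) -> Matrix:
    # Scan over key chunks with an online softmax: per query, only a running
    # max, a running normalizer and a running weighted sum of values are kept.
    out = []
    for qi, ql in zip(q, q_locs):
        running_max = -math.inf
        denom = 0.0
        acc = [0.0] * len(v[0])
        for start in range(0, len(k), chunk_size):
            stop = start + chunk_size
            scores = [
                _score(qi, kj, ql, kl, length_scale)
                for kj, kl in zip(k[start:stop], k_locs[start:stop])
            ]
            new_max = max(running_max, max(scores))
            # Rescale earlier partial sums to the new max (exp(-inf) == 0.0 on the first chunk).
            correction = math.exp(running_max - new_max)
            denom *= correction
            acc = [a * correction for a in acc]
            for s, vj in zip(scores, v[start:stop]):
                w = math.exp(s - new_max)
                denom += w
                acc = [a + w * x for a, x in zip(acc, vj)]
            running_max = new_max
        out.append([a / denom for a in acc])
    return out


if __name__ == "__main__":
    q = [[1.0, 0.0], [0.0, 1.0]]
    k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    v = [[1.0], [2.0], [3.0]]
    q_locs = [[0.0, 0.0], [1.0, 1.0]]
    k_locs = [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]

    ref = naive_biased_attention(q, k, v, q_locs, k_locs)
    scanned = scan_biased_attention(q, k, v, q_locs, k_locs, chunk_size=2)
    shift = [5.0, -3.0]
    shifted = scan_biased_attention(
        q, k, v,
        [[a + s for a, s in zip(p, shift)] for p in q_locs],
        [[a + s for a, s in zip(p, shift)] for p in k_locs],
    )
    print(ref, scanned, shifted)
```

Because the bias depends only on `q_loc - k_loc`, shifting all locations by the same vector leaves the output unchanged, and the scanned version agrees with the naive one for any `chunk_size` up to floating-point rounding. In this sketch the per-query score buffer holds at most `chunk_size` entries instead of one entry per key, which is the memory saving a scan-based attention is after.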

Related benchmarks

| Task | Dataset | Metric (lower is better) | Result | Rank |
|---|---|---|---|---|
| Spatiotemporal forecasting | Beijing Multi-Site Air Quality UCI repository (test) | NLL | -1.82 | 5 |
| Spatiotemporal forecasting | Gneiting GP (test) | NLL | 1.38 | 5 |
| Epidemiological Spatiotemporal Inference | SIR Shifted domain (test) | NLL | 0.19 | 4 |
| Epidemiological Spatiotemporal Inference | SIR Scaled domain (test) | NLL | 0.18 | 4 |
| 2D Gaussian Process Regression | 2D GP (Shifted) | NLL | 0.32 | 4 |
| 2D Gaussian Process Regression | 2D GP Scaled | NLL | 0.28 | 4 |
| Epidemiological Spatiotemporal Inference | SIR Original domain (test) | NLL | 0.19 | 4 |
| 2D Gaussian Process Regression | 2D GP Original | NLL | 0.32 | 4 |
| Spatiotemporal forecasting | ERA5 western Europe CNW (test) | NLL | 0.07 | 3 |
| Climate Forecasting | ERA5 CWN (train-Central, val-Western, test-Northern) | NLL | 0.13 | 3 |
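Every result above is a negative log-likelihood (NLL), so a negative value such as -1.82 is not an error. The exact NLL definition used by each benchmark (per-point averaging, target normalization, choice of likelihood) is not given on this page. Assuming independent Gaussian predictive distributions averaged over test points, the sketch below shows why the sign can flip: a Gaussian density exceeds 1 when the predictive standard deviation is small, so its negative log falls below zero. The function name `gaussian_nll` is hypothetical.

```python
import math
from typing import Sequence


def gaussian_nll(y: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> float:
    # Average negative log-likelihood of targets y under independent
    # Gaussian predictive distributions N(mean_i, std_i^2).
    total = 0.0
    for yi, mi, si in zip(y, mean, std):
        total += 0.5 * math.log(2.0 * math.pi * si ** 2) + (yi - mi) ** 2 / (2.0 * si ** 2)
    return total / len(y)


if __name__ == "__main__":
    # A confident, accurate prediction gives a negative NLL.
    print(gaussian_nll([1.0], [1.0], [0.05]))
    # A wide predictive distribution gives a positive NLL even when the mean is exact.
    print(gaussian_nll([1.0], [1.0], [1.0]))
```

Because NLL depends on the scale of the targets and on how it is averaged, values are comparable within a row's benchmark but not across different datasets.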
