Gated KalmaNet: A Fading Memory Layer Through Test-Time Ridge Regression
About
Linear State-Space Models (SSMs) offer an efficient alternative to softmax Attention with constant memory and linear compute, but their lossy, fading summary of the past hurts recall-oriented tasks. We propose Gated KalmaNet (GKA, pronounced "gee-ka"), a layer that accounts for the full past while retaining SSM-style efficiency. We ground our approach in the Kalman Filter (KF), and show that several existing SSM layers (DeltaNet, Gated DeltaNet, Kimi Delta Attention) are approximations to the KF recurrence under an identity error covariance assumption, which ignores how past keys and values should optimally influence state updates. In contrast, GKA maintains the full error covariance and computes the exact Kalman gain. Under a steady-state assumption that enables parallelization, this reduces to an online ridge regression with constant memory and linear compute. The standard KF equations are numerically unstable in low-precision settings (e.g., bfloat16) and hard to parallelize on GPUs. We address this with (1) adaptive regularization via input-dependent gating to control the ridge regression's condition number, and (2) Chebyshev Iteration, which we show is more stable than conventional iterative solvers in low precision. We further develop hardware-aware chunk-wise kernels for efficient training. Empirically, GKA outperforms existing SSM layers (e.g., Mamba2, Gated DeltaNet) on short-context tasks and achieves more than 10% relative improvement on long-context RAG and LongQA up to 128k tokens. We further show GKA outperforms Mamba when extended to ImageNet classification. Our code, including Triton kernels for training and inference (vLLM), along with a model zoo of GKA-based Hybrid models at 8B and 32B scale on HuggingFace, is released under Apache 2.0.
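To make the "online ridge regression with constant memory" idea concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the released kernels: the function names `chebyshev_solve` and `fading_ridge_readout` are hypothetical, the decay and ridge strength are fixed scalars (the abstract describes them as input-dependent gates), and the eigenvalue bounds passed to the solver (the ridge strength below, ridge plus trace above) are a simple choice made for this example. The state is two fixed-size matrices regardless of sequence length, and each step solves a regularized linear system with Chebyshev iteration.

```python
import numpy as np


def chebyshev_solve(A, b, lam_min, lam_max, num_iters=30):
    """Approximately solve A x = b for symmetric positive definite A.

    Textbook Chebyshev iteration; it needs bounds lam_min <= eig(A) <= lam_max
    and uses no inner products, only matrix-vector products.
    """
    theta = 0.5 * (lam_max + lam_min)   # center of the spectrum interval
    delta = 0.5 * (lam_max - lam_min)   # half-width of the interval
    if delta == 0.0:                    # A is a multiple of the identity
        return b / theta
    sigma1 = theta / delta
    rho = 1.0 / sigma1
    x = np.zeros_like(b)
    r = b.copy()                        # residual for the initial guess x = 0
    d = r / theta                       # first update direction
    for _ in range(num_iters):
        x = x + d
        r = r - A @ d
        rho_next = 1.0 / (2.0 * sigma1 - rho)
        d = rho_next * rho * d + (2.0 * rho_next / delta) * r
        rho = rho_next
    return x


def fading_ridge_readout(keys, values, queries, decay=0.9, ridge=1.0, num_iters=30):
    """Fading-memory ridge regression readout (illustrative sketch).

    Assumption: scalar `decay` and `ridge` stand in for learned, input-dependent gates.
    State: H = sum_i decay^(t-i) k_i k_i^T and U = sum_i decay^(t-i) v_i k_i^T.
    Output at step t: U (H + ridge * I)^(-1) q_t.
    """
    T, d_k = keys.shape
    d_v = values.shape[1]
    H = np.zeros((d_k, d_k))            # decayed key covariance, fixed size
    U = np.zeros((d_v, d_k))            # decayed value-key correlation, fixed size
    eye = np.eye(d_k)
    outputs = np.zeros((T, d_v))
    for t in range(T):
        k, v, q = keys[t], values[t], queries[t]
        H = decay * H + np.outer(k, k)
        U = decay * U + np.outer(v, k)
        # H is positive semidefinite, so eig(H + ridge*I) lies in
        # [ridge, ridge + trace(H)]; the trace over-estimates the top eigenvalue.
        x = chebyshev_solve(H + ridge * eye, q, ridge, ridge + np.trace(H), num_iters)
        outputs[t] = U @ x
    return outputs
```

Calling `fading_ridge_readout(K, V, Q)` with arrays of shape `(T, d_k)`, `(T, d_v)`, and `(T, d_k)` returns a `(T, d_v)` array. A larger `ridge` tightens the eigenvalue interval and lowers the condition number, so the solver needs fewer iterations, which is the role the abstract assigns to adaptive regularization.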
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Question Answering | ARC-E | -- | -- | 523 |
| Common Sense Reasoning | COPA | Accuracy | 85 | 256 |
| Commonsense Reasoning | PIQA | Accuracy | 74.81 | 213 |
| Question Answering | BoolQ | Accuracy | 61.68 | 201 |
| Question Answering | ARC-C | Accuracy (ARC-C) | 32.51 | 36 |
| Commonsense Reasoning | WinoGrande | Accuracy | 64.17 | 23 |
| Language Modeling | LM-Harness | ARC-C Accuracy | 32.51 | 17 |
| Question Answering | SciQ | Normalized Accuracy | 83.2 | 14 |
| Throughput Measurement | System Throughput Evaluation (test) | Throughput (tokens/GPU/sec) | 6.90e+3 | 12 |
| Commonsense Reasoning | HellaSwag | HellaSWAG Normalized Accuracy | 63.84 | 6 |