Gated KalmaNet: A Fading Memory Layer Through Test-Time Ridge Regression

About

Linear State-Space Models (SSMs) offer an efficient alternative to softmax Attention with constant memory and linear compute, but their lossy, fading summary of the past hurts recall-oriented tasks. We propose Gated KalmaNet (GKA, pronounced "gee-ka"), a layer that accounts for the full past while retaining SSM-style efficiency. We ground our approach in the Kalman Filter (KF), and show that several existing SSM layers (DeltaNet, Gated DeltaNet, Kimi Delta Attention) are approximations to the KF recurrence under an identity error covariance assumption, which ignores how past keys and values should optimally influence state updates. In contrast, GKA maintains the full error covariance and computes the exact Kalman gain. Under a steady-state assumption that enables parallelization, this reduces to an online ridge regression with constant memory and linear compute. The standard KF equations are numerically unstable in low-precision settings (e.g., bfloat16) and hard to parallelize on GPUs. We address this with (1) adaptive regularization via input-dependent gating to control the ridge regression's condition number, and (2) Chebyshev Iteration, which we show is more stable than conventional iterative solvers in low precision. We further develop hardware-aware chunk-wise kernels for efficient training. Empirically, GKA outperforms existing SSM layers (e.g., Mamba2, Gated DeltaNet) on short-context tasks and achieves more than 10% relative improvement on long-context RAG and LongQA up to 128k tokens. We further show GKA outperforms Mamba when extended to ImageNet classification. Our code, including Triton kernels for training and inference (vLLM), along with a model zoo of GKA-based Hybrid models at 8B and 32B scale on HuggingFace, is released under Apache 2.0.
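
To make the "online ridge regression with constant memory" and "Chebyshev Iteration" ideas above concrete, here is a minimal toy sketch in plain Python. It is not the released GKA code: the class `FadingRidgeMemory`, the helper `chebyshev_solve`, the fixed scalar `decay` and `ridge` values (the abstract describes input-dependent gating instead), and the trace-based eigenvalue bound are all assumptions made for illustration.

```python
def matvec(M, x):
    """Dense matrix-vector product on nested lists."""
    return [sum(m * xj for m, xj in zip(row, x)) for row in M]


def chebyshev_solve(H, b, lam_min, lam_max, iters=60):
    """Approximately solve H x = b for symmetric positive definite H.

    Classical Chebyshev iteration: needs bounds lam_min <= eig(H) <= lam_max
    and runs a fixed number of steps with no inner products.
    """
    d = 0.5 * (lam_max + lam_min)   # centre of the eigenvalue interval
    c = 0.5 * (lam_max - lam_min)   # half-width of the interval
    x = [0.0] * len(b)
    r = list(b)                     # residual b - H x for x = 0
    p = [0.0] * len(b)
    alpha = 0.0
    for i in range(1, iters + 1):
        if i == 1:
            p = list(r)
            alpha = 1.0 / d
        else:
            beta = 0.5 * (c * alpha) ** 2 if i == 2 else (0.5 * c * alpha) ** 2
            alpha = 1.0 / (d - beta / alpha)
            p = [ri + beta * pi for ri, pi in zip(r, p)]
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [bi - hi for bi, hi in zip(b, matvec(H, x))]
    return x


class FadingRidgeMemory:
    """Hypothetical toy layer state: exponentially faded normal equations.

    A = sum_t decay^(T-t) k_t k_t^T   (d_k x d_k)
    B = sum_t decay^(T-t) k_t v_t^T   (d_k x d_v)
    The state size does not grow with sequence length.
    """

    def __init__(self, d_k, d_v, decay=0.99, ridge=0.1):
        self.decay, self.ridge = decay, ridge   # assumed fixed scalars
        self.A = [[0.0] * d_k for _ in range(d_k)]
        self.B = [[0.0] * d_v for _ in range(d_k)]

    def write(self, k, v):
        """Fade the old statistics, then add the new key/value pair."""
        for i, ki in enumerate(k):
            for j, kj in enumerate(k):
                self.A[i][j] = self.decay * self.A[i][j] + ki * kj
            for j, vj in enumerate(v):
                self.B[i][j] = self.decay * self.B[i][j] + ki * vj

    def read(self, q, iters=60):
        """Return B^T (A + ridge*I)^{-1} q, the ridge-regression prediction."""
        n = len(q)
        H = [[self.A[i][j] + (self.ridge if i == j else 0.0) for j in range(n)]
             for i in range(n)]
        # A is positive semidefinite, so eig(H) lies in [ridge, ridge + trace(A)].
        lam_max = self.ridge + sum(self.A[i][i] for i in range(n))
        x = chebyshev_solve(H, q, self.ridge, lam_max, iters)
        return [sum(self.B[i][j] * x[i] for i in range(n))
                for j in range(len(self.B[0]))]


mem = FadingRidgeMemory(d_k=2, d_v=1, decay=0.9, ridge=0.5)
mem.write([1.0, 0.0], [2.0])
mem.write([0.6, 0.8], [-1.0])
out = mem.read([1.0, 0.0])
```

In this sketch the ridge term sets the lower eigenvalue bound, so a larger `ridge` shrinks the condition number and lets the fixed-step solver converge in fewer iterations. That is the same lever the abstract attributes to adaptive regularization, here reduced to a constant.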

Liangzu Peng, Aditya Chattopadhyay, Luca Zancato, Elvis Nunez, Wei Xia, Stefano Soatto • 2025

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Question Answering | ARC-E | -- | 523 |
| Common Sense Reasoning | COPA | Accuracy: 85 | 256 |
| Commonsense Reasoning | PIQA | Accuracy: 74.81 | 213 |
| Question Answering | BoolQ | Accuracy: 61.68 | 201 |
| Question Answering | ARC-C | Accuracy (ARC-C): 32.51 | 36 |
| Commonsense Reasoning | WinoGrande | Accuracy: 64.17 | 23 |
| Language Modeling | LM-Harness | ARC-C Accuracy: 32.51 | 17 |
| Question Answering | SciQ | Normalized Accuracy: 83.2 | 14 |
| Throughput Measurement | System Throughput Evaluation (test) | Throughput (tokens/GPU/sec): 6.90e+3 | 12 |
| Commonsense Reasoning | HellaSwag | HellaSwag Normalized Accuracy: 63.84 | 6 |