
M$^2$RNN: Non-Linear RNNs with Matrix-Valued States for Scalable Language Modeling

About

Transformers are highly parallel but are limited to computations in the TC$^0$ complexity class, excluding tasks such as entity tracking and code execution that provably require greater expressive power. Motivated by this limitation, we revisit non-linear Recurrent Neural Networks (RNNs) for language modeling and introduce Matrix-to-Matrix RNN (M$^2$RNN): an architecture with matrix-valued hidden states and expressive non-linear state transitions. We demonstrate that the language modeling performance of non-linear RNNs is limited by their state size, and show how M$^2$RNN's state-size expansion mechanism enables efficient use of tensor cores. Empirically, M$^2$RNN achieves perfect state-tracking generalization at sequence lengths not seen during training. These benefits also translate to large-scale language modeling. In hybrid settings that interleave recurrent layers with attention, Hybrid M$^2$RNN outperforms equivalent Gated DeltaNet hybrids by $0.4$-$0.5$ perplexity points on a 7B MoE model, while using $3\times$ smaller state sizes for the recurrent layers. Notably, replacing even a single recurrent layer with M$^2$RNN in an existing hybrid architecture yields accuracy gains comparable to Hybrid M$^2$RNN with minimal impact on training throughput. These Hybrid Gated DeltaNet models with a single M$^2$RNN layer also achieve superior long-context generalization, outperforming state-of-the-art hybrid linear attention architectures by up to $8$ points on LongBench. Together, these results establish non-linear RNN layers as a compelling building block for efficient and scalable language models.
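To make the core idea concrete, here is a minimal sketch of a recurrent cell whose hidden state is a $d \times d$ matrix rather than a length-$d$ vector, with a non-linearity applied to the state transition. This is an illustrative toy, not the paper's actual M$^2$RNN update rule; the parameter names (`W_left`, `W_right`, `W_in`) and the specific update form are assumptions for demonstration only.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4  # the state holds d^2 scalars, vs. d for a vector-state RNN

# Hypothetical parameters: two matrices that mix the state's rows and
# columns, and one that lifts the input into the state space.
W_left = rng.normal(scale=0.1, size=(d, d))
W_right = rng.normal(scale=0.1, size=(d, d))
W_in = rng.normal(scale=0.1, size=(d, d))

def step(H, x):
    """One non-linear update of the matrix-valued state H given input x."""
    # The outer product lifts the d-dim input to a d x d update; tanh makes
    # the transition non-linear, unlike linear-attention-style recurrences.
    return np.tanh(W_left @ H @ W_right + W_in @ np.outer(x, x))

H = np.zeros((d, d))
for t in range(8):  # unroll over a toy sequence
    x = rng.normal(size=d)
    H = step(H, x)

print(H.shape)  # the state stays a d x d matrix throughout
```

Note that the state update is dominated by dense matrix-matrix products, which is what makes matrix-valued states a natural fit for tensor-core hardware.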

Mayank Mishra, Shawn Tan, Ion Stoica, Joseph Gonzalez, Tri Dao • 2026

Related benchmarks

| Task                  | Dataset       | Metric     | Result | Rank |
| --------------------- | ------------- | ---------- | ------ | ---- |
| Commonsense Reasoning | WinoGrande    | Accuracy   | 54.46  | 1085 |
| Question Answering    | ARC Challenge | Accuracy   | 31.83  | 906  |
| Commonsense Reasoning | PIQA          | Accuracy   | 71.6   | 751  |
| Question Answering    | ARC Easy      | Accuracy   | 64.14  | 597  |
| Commonsense Reasoning | HellaSwag     | Accuracy   | 48.83  | 350  |
| Question Answering    | SciQ          | Accuracy   | 87.4   | 283  |
| Language Modeling     | LAMBADA       | Accuracy   | 37.18  | 268  |
| Commonsense Reasoning | COPA          | Accuracy   | 72     | 197  |
| Language Modeling     | LAMBADA       | Perplexity | 10.29  | 150  |
| Language Modeling     | WikiText      | Perplexity | 12.85  | 45   |

Showing 10 of 20 rows.
