M$^2$RNN: Non-Linear RNNs with Matrix-Valued States for Scalable Language Modeling
About
Transformers are highly parallel but are limited to computations in the TC$^0$ complexity class, which excludes tasks such as entity tracking and code execution that provably require greater expressive power. Motivated by this limitation, we revisit non-linear Recurrent Neural Networks (RNNs) for language modeling and introduce the Matrix-to-Matrix RNN (M$^2$RNN), an architecture with matrix-valued hidden states and expressive non-linear state transitions. We show that the language modeling performance of non-linear RNNs is limited by their state size, and we show how M$^2$RNN's state-size expansion mechanism enables efficient use of tensor cores. Empirically, M$^2$RNN achieves perfect state-tracking generalization at sequence lengths not seen during training. These benefits also carry over to large-scale language modeling. In hybrid settings that interleave recurrent layers with attention, Hybrid M$^2$RNN outperforms equivalent Gated DeltaNet hybrids by $0.4$ to $0.5$ perplexity points on a 7B MoE model, while using recurrent-layer states that are $3\times$ smaller. Notably, replacing even a single recurrent layer of an existing hybrid architecture with M$^2$RNN yields accuracy gains comparable to Hybrid M$^2$RNN, with minimal impact on training throughput. Hybrid Gated DeltaNet models with a single M$^2$RNN layer also achieve superior long-context generalization, outperforming state-of-the-art hybrid linear attention architectures by up to $8$ points on LongBench. Together, these results establish non-linear RNN layers as a compelling building block for efficient and scalable language models.
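The abstract does not give the M$^2$RNN recurrence itself, so the following is only a sketch of the general idea under assumed equations, not the authors' method. It assumes a hidden state $H_t \in \mathbb{R}^{d_k \times d_v}$ updated as $H_t = \tanh(f_t H_{t-1} W + k_t v_t^\top)$, with a scalar forget gate $f_t$, a shared transition matrix $W$, and a readout $o_t = H_t^\top q_t$. These equations, the function names `matrix_rnn_step` and `run_matrix_rnn`, and all shapes are hypothetical. The sketch illustrates why a matrix-valued state suits the hardware. The state-to-state product $H_{t-1} W$ is a matrix-matrix multiply (a GEMM that tensor cores can accelerate). With a vector state, $h_{t-1} W$ would only be a matrix-vector multiply.

```python
import numpy as np


def matrix_rnn_step(H, q, k, v, W, f):
    """One step of a hypothetical non-linear RNN with a matrix-valued state.

    Shapes (assumed for illustration): H (d_k, d_v), q and k (d_k,), v (d_v,),
    W (d_v, d_v), f a scalar forget gate in (0, 1).
    """
    # State-to-state transition: (d_k, d_v) @ (d_v, d_v) is a matrix-matrix product.
    transition = f * (H @ W)
    # Write the current token into the state as a rank-1 outer product.
    write = np.outer(k, v)
    # Element-wise tanh makes the update non-linear in the previous state.
    H_new = np.tanh(transition + write)
    # Read out a d_v-dimensional output with the query.
    o = H_new.T @ q
    return H_new, o


def run_matrix_rnn(queries, keys, values, W, gates):
    """Scan the hypothetical recurrence over a sequence, one token at a time."""
    d_k, d_v = keys.shape[1], values.shape[1]
    H = np.zeros((d_k, d_v))  # start from an all-zero state
    outputs = []
    for q, k, v, f in zip(queries, keys, values, gates):
        H, o = matrix_rnn_step(H, q, k, v, W, f)
        outputs.append(o)
    return np.stack(outputs), H


rng = np.random.default_rng(0)
T, d_k, d_v = 6, 4, 3
queries = rng.standard_normal((T, d_k))
keys = rng.standard_normal((T, d_k))
values = rng.standard_normal((T, d_v))
W = rng.standard_normal((d_v, d_v)) / np.sqrt(d_v)
gates = np.full(T, 0.9)

outputs, final_state = run_matrix_rnn(queries, keys, values, W, gates)
print(outputs.shape, final_state.shape)  # (6, 3) (4, 3)
```

Replacing `np.tanh` with the identity would make the update linear in $H_{t-1}$, like the linear-attention layers (such as Gated DeltaNet) that the abstract compares against. In this sketch, the non-linearity rules out computing the sequence with a parallel associative scan, so the loop over tokens is inherently sequential.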
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | WinoGrande | Accuracy (%) | 54.46 | 1085 |
| Question Answering | ARC Challenge | Accuracy (%) | 31.83 | 906 |
| Commonsense Reasoning | PIQA | Accuracy (%) | 71.6 | 751 |
| Question Answering | ARC Easy | Accuracy (%) | 64.14 | 597 |
| Commonsense Reasoning | HellaSwag | Accuracy (%) | 48.83 | 350 |
| Question Answering | SciQ | Accuracy (%) | 87.4 | 283 |
| Language Modeling | LAMBADA | Accuracy (%) | 37.18 | 268 |
| Commonsense Reasoning | COPA | Accuracy (%) | 72 | 197 |
| Language Modeling | LAMBADA | Perplexity | 10.29 | 150 |
| Language Modeling | WikiText | Perplexity | 12.85 | 45 |