StateX: Enhancing RNN Recall via Post-training State Expansion
About
Recurrent neural networks (RNNs), such as linear attention and state-space models, have gained popularity due to their constant per-token complexity when processing long contexts. However, these recurrent models struggle with tasks that require accurate recall of contextual information from long contexts, because all contextual information is compressed into a fixed-size recurrent state. Previous studies have shown that recall ability is positively correlated with recurrent state size, yet directly training RNNs with large recurrent states incurs high training costs. In this paper, we introduce StateX, a post-training framework that efficiently expands the states of pre-trained RNNs. For two popular classes of RNNs, linear attention and state-space models, StateX applies post-training architectural modifications that scale up the state size with no or only a negligible increase in model parameters. Experiments on models with up to 1.3B parameters demonstrate that StateX efficiently enhances the recall and in-context learning performance of RNNs without incurring high post-training costs or compromising other capabilities.
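To make the capacity bottleneck concrete, here is a minimal sketch (not the paper's method) of a plain linear-attention recurrence: the entire context is accumulated into a single fixed-size state matrix `S` of shape `(d_k, d_v)`, so recall capacity is bounded by `d_k * d_v` regardless of sequence length. All variable names are illustrative.

```python
import numpy as np

def linear_attention_recurrent(Q, K, V):
    """Process tokens one at a time with a fixed-size recurrent state.

    Q, K: arrays of shape (T, d_k); V: array of shape (T, d_v).
    Returns per-token outputs of shape (T, d_v).
    """
    T, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))          # the fixed-size recurrent state
    outputs = []
    for t in range(T):
        S = S + np.outer(K[t], V[t])  # write: accumulate key-value association
        outputs.append(S.T @ Q[t])    # read: query the compressed context
    return np.stack(outputs)

rng = np.random.default_rng(0)
T, d_k, d_v = 8, 4, 4
Q = rng.normal(size=(T, d_k))
K = rng.normal(size=(T, d_k))
V = rng.normal(size=(T, d_v))
out = linear_attention_recurrent(Q, K, V)
print(out.shape)  # (8, 4)
```

Because per-token cost depends only on `d_k` and `d_v`, enlarging the state (e.g. widening the key dimension) increases recall capacity without changing the constant per-token complexity; doing this on an already pre-trained model is the setting StateX addresses.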
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | HellaSwag | Accuracy | 45 | 1891 |
| Commonsense Reasoning | WinoGrande | Accuracy | 59.9 | 1085 |
| Commonsense Reasoning | ARC Challenge | Accuracy | 29.6 | 190 |
| Commonsense Reasoning | SIQA | Accuracy | 41.6 | 106 |
| Commonsense Reasoning | ARC Easy | Accuracy | 64 | 72 |
| Commonsense Reasoning | PIQA | Accuracy | 73.6 | 71 |
| In-Context Learning | 12 Downstream Classification Tasks | Accuracy | 53 | 15 |
| Passkey Retrieval | Passkey Retrieval | Retrieval Success Rate (4K) | 93 | 11 |
| Recall-intensive Question Answering | Recall-intensive tasks (SWDE, SQuAD, TQA, NQ, Drop), truncated to 2K tokens (test) | SWDE Accuracy | 56.1 | 5 |
| Needle-In-A-Haystack Retrieval | NIAH Single 2 | Success Rate (4K Context) | 94 | 2 |