
StateX: Enhancing RNN Recall via Post-training State Expansion

About

Recurrent neural networks (RNNs), such as linear attention and state-space models, have gained popularity due to their constant per-token complexity when processing long contexts. However, these recurrent models struggle with tasks that require accurate recall of contextual information from long contexts, because all contextual information is compressed into a fixed-size recurrent state. Previous studies have shown that recall ability is positively correlated with the recurrent state size, yet directly training RNNs with large recurrent states results in high training costs. In this paper, we introduce StateX, a post-training framework that efficiently expands the states of pre-trained RNNs. For two popular classes of RNNs, linear attention and state-space models, we design post-training architectural modifications in StateX that scale up the state size with no or negligible increase in model parameters. Experiments on models with up to 1.3B parameters demonstrate that StateX efficiently enhances the recall and in-context learning performance of RNNs without incurring high post-training costs or compromising other capabilities.
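To make "constant per-token complexity" and "fixed-size recurrent state" concrete, here is a minimal sketch of plain (ungated, unnormalized) linear attention written as an RNN. It is an illustration under stated assumptions, not the StateX implementation: real models such as gated linear attention or Mamba add decay/gating terms omitted here, and all function names are hypothetical. The last lines show, as arithmetic only, one way a state can grow without new weights: reading the same q/k/v projections as one wide head instead of H narrow heads multiplies the state by H. StateX's actual modifications are described in the paper and may differ.

```python
"""Toy linear attention viewed as an RNN (illustrative sketch only).

Assumptions, not taken from the StateX paper: no gating or decay,
no normalization, tiny dimensions, and plain Python lists for tensors.
"""


def step(S, q, k, v):
    """One recurrent update: S <- S + k v^T, then read out o = S^T q.

    The work is O(d_k * d_v) per token no matter how long the context is.
    """
    d_k, d_v = len(k), len(v)
    for i in range(d_k):
        for j in range(d_v):
            S[i][j] += k[i] * v[j]
    return [sum(q[i] * S[i][j] for i in range(d_k)) for j in range(d_v)]


def recurrent_form(qs, ks, vs):
    """Process a sequence while keeping only a fixed-size d_k x d_v state."""
    d_k, d_v = len(ks[0]), len(vs[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    return [step(S, q, k, v) for q, k, v in zip(qs, ks, vs)]


def parallel_form(qs, ks, vs):
    """Same outputs computed attention-style: o_t = sum_{s<=t} (q_t . k_s) v_s."""
    outs = []
    for t, q in enumerate(qs):
        o = [0.0] * len(vs[0])
        for s in range(t + 1):
            score = sum(a * b for a, b in zip(q, ks[s]))
            o = [oj + score * vj for oj, vj in zip(o, vs[s])]
        outs.append(o)
    return outs


def state_entries(num_heads, d_k, d_v):
    """Recurrent-state size of one multi-head linear-attention layer."""
    return num_heads * d_k * d_v


# Hypothetical example: 8 heads with d_k = d_v = 64. Treating the same
# 512-wide q/k/v projections as ONE head reuses every weight but gives an
# 8x larger state (H^2 * d_k * d_v instead of H * d_k * d_v).
multi_head = state_entries(8, 64, 64)     # 32768
single_head = state_entries(1, 512, 512)  # 262144
print(single_head // multi_head)          # 8
```

The recurrent and parallel forms produce identical outputs, but only the recurrent one runs in constant memory per token; everything the model can later recall must fit in `S`, which is why a larger state tends to help recall.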

Xingyu Shen, Yingfa Chen, Zhen Leng Thai, Xu Han, Zhiyuan Liu, Maosong Sun • 2025

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | HellaSwag | Accuracy | 45 | 1891 |
| Commonsense Reasoning | WinoGrande | Accuracy | 59.9 | 1085 |
| Commonsense Reasoning | ARC Challenge | Accuracy | 29.6 | 190 |
| Commonsense Reasoning | SIQA | Accuracy | 41.6 | 106 |
| Common Sense Reasoning | ARC Easy | ARC (easy) Accuracy | 64 | 72 |
| Common Sense Reasoning | PIQA | Accuracy | 73.6 | 71 |
| In-Context Learning | 12 Downstream Classification Tasks | Accuracy | 53 | 15 |
| Passkey Retrieval | Passkey Retrieval | Retrieval Success Rate (4K) | 93 | 11 |
| Recall-intensive Question Answering | Recall-intensive tasks (SWDE, SQuAD, TQA, NQ, Drop) truncated to 2K tokens (test) | SWDE | 56.1 | 5 |
| Needle-In-A-Haystack Retrieval | NIAH Single 2 | Success Rate (4K Context) | 94 | 2 |
