Parallel Recursive LSTM
About
Transformers have become the dominant architecture for sequence modeling by using self-attention to enable expressive and highly parallel processing. However, the resulting time and memory costs, which are quadratic in sequence length, limit efficiency in long-context settings. Recurrent models such as LSTMs provide explicit nonlinear state updates and strong state-tracking capabilities, yet their strictly sequential computation limits parallelism. We introduce the Parallel Recursive LSTM (PR-LSTM), a hierarchical recurrent architecture that replaces left-to-right recurrence with recursive nonlinear state composition over a balanced computation tree. Tokens are first mapped independently to latent states, which are then recursively merged by a learned gated composition block. This structure uses the reduction pattern underlying parallel scans as a fixed execution schedule, rather than assuming an associative recurrence. As a result, PR-LSTM retains nonlinear gated state representations while reducing the parallel depth of the recurrence from linear to logarithmic in sequence length. Empirically, PR-LSTM achieves strong sequence-length generalization on formal-language benchmarks, solving more tasks than standard RNN, LSTM, and Transformer baselines, while avoiding the quadratic scaling of attention. These results suggest that recurrent computation can be reorganized hierarchically to expose parallelism without restricting the transition dynamics to linear or associative forms.
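To make the tree-structured schedule concrete, here is a minimal pure-Python sketch. It is not the authors' implementation: the per-token encoder `encode`, the initializer `make_params`, the zero initial cell state for leaves, and the gate layout of `merge` (a binary Tree-LSTM-style cell with separate forget gates for the left and right children) are hypothetical choices made for illustration. What the sketch does show is the schedule described above: leaves are computed independently, each level merges adjacent pairs whose computations do not depend on one another, and an odd leftover state is carried up unchanged, so n tokens need ceil(log2 n) merge levels instead of n sequential steps.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]
State = Tuple[Vector, Vector]  # (h, c): hidden and cell vectors


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def make_params(d: int, seed: int = 0):
    """Hypothetical merge parameters: one weight row per gate unit over [h_left; h_right]."""
    rng = random.Random(seed)
    # 5 gate groups (input, left forget, right forget, output, candidate), d units each.
    W = [[rng.uniform(-0.5, 0.5) for _ in range(2 * d)] for _ in range(5 * d)]
    b = [0.0] * (5 * d)
    return W, b


def encode(token: int, d: int) -> State:
    """Hypothetical per-token encoder: each token maps to a state independently of its neighbours."""
    rng = random.Random(token)
    h = [math.tanh(rng.uniform(-1.0, 1.0)) for _ in range(d)]
    return h, [0.0] * d  # assumption: leaves start with a zero cell state


def merge(left: State, right: State, params) -> State:
    """Gated composition of two child states (binary Tree-LSTM-style cell, an assumption)."""
    W, b = params
    (h_l, c_l), (h_r, c_r) = left, right
    d = len(h_l)
    x = h_l + h_r  # list concatenation: [h_left; h_right]
    z = [sum(w * v for w, v in zip(row, x)) + bias for row, bias in zip(W, b)]
    i = [sigmoid(v) for v in z[0:d]]
    f_l = [sigmoid(v) for v in z[d:2 * d]]
    f_r = [sigmoid(v) for v in z[2 * d:3 * d]]
    o = [sigmoid(v) for v in z[3 * d:4 * d]]
    g = [math.tanh(v) for v in z[4 * d:5 * d]]
    c = [i[k] * g[k] + f_l[k] * c_l[k] + f_r[k] * c_r[k] for k in range(d)]
    h = [o[k] * math.tanh(c[k]) for k in range(d)]
    return h, c


def reduce_tree(states: List[State], params) -> Tuple[State, int]:
    """Merge adjacent pairs level by level; return the root state and the number of levels."""
    depth = 0
    while len(states) > 1:
        # Every merge within one level is independent, so a real implementation
        # could batch them; here they simply run in a loop.
        nxt = [merge(states[k], states[k + 1], params) for k in range(0, len(states) - 1, 2)]
        if len(states) % 2 == 1:
            nxt.append(states[-1])  # odd element is carried up unchanged
        states = nxt
        depth += 1
    return states[0], depth


if __name__ == "__main__":
    d = 4
    params = make_params(d)
    tokens = [3, 1, 4, 1, 5, 9, 2, 6]
    root, depth = reduce_tree([encode(t, d) for t in tokens], params)
    print("merge levels:", depth, "for", len(tokens), "tokens")
    print("root hidden state:", [round(v, 3) for v in root[0]])
```

Because `merge` is a nonlinear gated function, it is not associative: regrouping the same leaves, for example `merge(merge(a, b), c)` versus `merge(a, merge(b, c))`, generally yields a different root state. That is why, in this sketch, the balanced tree acts as a fixed execution schedule rather than as a parallel scan that would be free to rebracket the computation.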
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Bucket Sort | Formal-language benchmark, lengths 41-500 (test) | Accuracy (%) | 99.4 | 6 |
| Missing Duplicate | Formal-language benchmark, lengths 41-500 (test) | Accuracy (%) | 100 | 6 |
| Binary Multiplication | Formal-language benchmark, lengths 41-500 (test) | Accuracy (%) | 52.9 | 6 |
| Compute Sqrt | Formal-language benchmark, lengths 41-500 (test) | Accuracy (%) | 56.8 | 6 |
| Duplicate String | Formal-language benchmark, lengths 41-500 (test) | Accuracy (%) | 54.7 | 6 |
| Odds First | Formal-language benchmark, lengths 41-500 (test) | Accuracy (%) | 55 | 6 |
| Binary Addition | Formal-language benchmark, lengths 41-500 (test) | Accuracy (%) | 51.8 | 6 |
| Cycle Navigation | Formal-language benchmark, lengths 41-500 (test) | Accuracy (%) | 100 | 6 |
| Even Pairs | Formal-language benchmark, lengths 41-500 (test) | Accuracy (%) | 100 | 6 |
| Modular Arithmetic | Formal-language benchmark, lengths 41-500 (test) | Accuracy (%) | 36.4 | 6 |