Long-Context Generalization with Sparse Attention
About
Transformer-based architectures traditionally employ softmax to compute attention weights, which produces dense distributions over all tokens in a sequence. While effective in many settings, this density has been shown to be detrimental for tasks that demand precise focus on fixed-size patterns: as sequence length increases, non-informative tokens accumulate attention probability mass, leading to dispersion and representational collapse. We show in this paper that dynamically sparse attention mechanisms using $\alpha$-entmax can avoid these issues, due to their ability to assign exact zeros to irrelevant tokens. Furthermore, we introduce Adaptive-Scalable Entmax (ASEntmax), which endows $\alpha$-entmax with a learnable temperature parameter, allowing the attention distribution to interpolate between sparse (pattern-focused) and dense (softmax-like) regimes. Our empirical evaluation on synthetic tasks and language modeling demonstrates that ASEntmax substantially outperforms softmax, scalable softmax, and fixed-temperature $\alpha$-entmax baselines, achieving up to 1000$\times$ length extrapolation on synthetic benchmarks and superior long-context generalization on language modeling while preserving short-context performance, including better perplexity trends and higher retrieval accuracies at 8$\times$ training length.
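To make the contrast between dense and sparse attention concrete, here is a minimal NumPy sketch of $\alpha$-entmax computed by bisection, compared with softmax on a toy score vector. It is not the authors' implementation. The function names, the toy scores (one relevant token scored 4.0 among distractors scored 0), and the `temperature` argument are illustrative assumptions. In particular, `temperature` is a plain fixed scalar that divides the scores; it does not reproduce how ASEntmax parameterizes or learns its temperature.

```python
import numpy as np


def softmax(z):
    """Dense baseline: every token receives strictly positive weight."""
    z = np.asarray(z, dtype=np.float64)
    e = np.exp(z - z.max())
    return e / e.sum()


def entmax_bisect(z, alpha=1.5, temperature=1.0, n_iter=50):
    """alpha-entmax (alpha > 1) via bisection on the threshold tau.

    Solves sum_i max((alpha - 1) * z_i - tau, 0) ** (1 / (alpha - 1)) = 1.
    `temperature` is a hypothetical fixed scalar used only for illustration.
    """
    z = np.asarray(z, dtype=np.float64) / temperature
    z = (alpha - 1.0) * z
    n = z.size
    # At lo the largest entry alone already has mass 1, so the total is >= 1.
    lo = z.max() - 1.0
    # At hi every entry has mass <= 1/n, so the total is <= 1.
    hi = z.max() - (1.0 / n) ** (alpha - 1.0)
    for _ in range(n_iter):
        tau = 0.5 * (lo + hi)
        mass = (np.clip(z - tau, 0.0, None) ** (1.0 / (alpha - 1.0))).sum()
        if mass >= 1.0:
            lo = tau
        else:
            hi = tau
    p = np.clip(z - lo, 0.0, None) ** (1.0 / (alpha - 1.0))
    return p / p.sum()  # absorb the remaining bisection error


def toy_scores(n, relevant_score=4.0):
    """One relevant token at position 0, n - 1 distractors scored 0."""
    z = np.zeros(n)
    z[0] = relevant_score
    return z


if __name__ == "__main__":
    for n in (10, 100, 1000):
        z = toy_scores(n)
        print(
            f"n={n:5d}  softmax mass on relevant token: {softmax(z)[0]:.4f}  "
            f"entmax nonzero weights: {np.count_nonzero(entmax_bisect(z))}"
        )
    # Higher temperature flattens the scores and widens the support;
    # lower temperature sharpens them and shrinks it.
    graded = np.linspace(0.0, 2.0, 8)
    for t in (0.1, 1.0, 10.0):
        support = np.count_nonzero(entmax_bisect(graded, temperature=t))
        print(f"temperature={t:5.1f}  entmax support size: {support}")
```

In this toy setting the softmax weight on the relevant token shrinks as distractors are added, while $\alpha$-entmax assigns the distractors exactly zero weight. The second loop shows how a temperature moves the same scores between a sparse and a denser regime.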
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | HellaSwag | Accuracy | 33.4 | 1891 |
| Commonsense Reasoning | WinoGrande | Accuracy | 50 | 1085 |
| Commonsense Reasoning | PIQA | Accuracy | 63.8 | 751 |
| Language Modeling | LAMBADA | Accuracy | 34.3 | 268 |
| Language Modeling | Arxiv Proof-pile | Perplexity | 16.86 | 40 |
| Language Modeling | Pubmed | Perplexity | 17.69 | 38 |
| Copy | Copy OOD lengths: 2x, 4x, 8x, 16x, 32x, 64x | Exact Match Accuracy | 100 | 30 |
| MQMTAR | MQMTAR OOD lengths: 2x, 4x, 16x, 64x, 256x, 1024x | Exact Match Accuracy | 100 | 30 |
| Reverse | Reverse OOD lengths: 1.5x, 2x, 4x, 8x | Exact Match Accuracy | 100 | 20 |
| Sort | Sort OOD lengths: 2x, 4x, 8x | Exact Match Accuracy | 100 | 15 |