Scaling Attention via Feature Sparsity
About
Scaling Transformers to ultra-long contexts is bottlenecked by the $O(n^2 d)$ cost of self-attention. Existing methods reduce this cost along the sequence axis through local windows, kernel approximations, or token-level sparsity, but these approaches consistently degrade accuracy. In this paper, we instead explore an orthogonal axis: feature sparsity. We propose Sparse Feature Attention (SFA), where queries and keys are represented as $k$-sparse codes that preserve high-dimensional expressivity while reducing the cost of attention from $\Theta(n^2 d)$ to $\Theta(n^2 k^2/d)$. To make this efficient at scale, we introduce FlashSFA, an IO-aware kernel that extends FlashAttention to operate directly on sparse overlaps without materializing dense score matrices. Across GPT-2 and Qwen3 pretraining, SFA matches dense baselines while improving speed by up to $2.5\times$ and reducing FLOPs and KV-cache by nearly 50\%. On synthetic and downstream benchmarks, SFA preserves retrieval accuracy and robustness at long contexts, outperforming short-embedding baselines that collapse feature diversity. These results establish feature-level sparsity as a complementary and underexplored axis for efficient attention, enabling Transformers to scale to orders-of-magnitude longer contexts with minimal quality loss. Code is available at https://github.com/YannX1e/Sparse-Feature-Attention.
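The core idea is easiest to see in a dense reference implementation: if queries and keys are made $k$-sparse, each attention score reduces to a dot product over the (small) overlap of their active feature sets. Below is a minimal sketch of that idea, not the repository's implementation. The function name `sparse_feature_attention`, the parameter `k_active`, the per-token top-$k$ magnitude selection, and the $1/\sqrt{d}$ scaling are illustrative assumptions; this reference also still materializes the dense score matrix, so it shows the math but not the speedup, which in the paper comes from the IO-aware FlashSFA kernel.

```python
import torch
import torch.nn.functional as F

def sparse_feature_attention(q, k, v, k_active=16):
    """Reference (non-IO-aware) sketch of attention over k-sparse q/k codes.

    q, k: (n, d) query/key projections; v: (n, d_v) values.
    k_active: number of features kept per token (the paper's k); an assumption here.
    Scores come only from features active in both codes, which is simply the
    dot product of the sparsified vectors.
    """
    def sparsify(x, budget):
        # Keep the `budget` largest-magnitude features per row, zero the rest.
        _, idx = x.abs().topk(budget, dim=-1)
        mask = torch.zeros_like(x).scatter_(-1, idx, 1.0)
        return x * mask

    q_s, k_s = sparsify(q, k_active), sparsify(k, k_active)
    d = q.shape[-1]
    scores = (q_s @ k_s.T) / d ** 0.5  # only overlapping features contribute
    n = q.shape[0]
    causal_mask = torch.triu(torch.ones(n, n), diagonal=1).bool()
    scores = scores.masked_fill(causal_mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

# Toy usage: as k_active approaches d, the output approaches dense attention.
n, d = 8, 64
q, k, v = (torch.randn(n, d) for _ in range(3))
out = sparse_feature_attention(q, k, v, k_active=16)
print(out.shape)  # torch.Size([8, 64])
```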
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Language Modeling | The Pile | Perplexity | 4.81 | 94 |
| Language Modeling | OpenWebText | Perplexity | 16.78 | 91 |
| Zero-shot Reasoning | Reasoning Suite (PiQA, LAMBADA, ARC, HellaSwag) | PiQA Score | 61.73 | 20 |
| Mathematical Reasoning | GSM-8K | Accuracy | 89.11 | 16 |
| Language Modeling | OWT Pile | Decode Latency | 5.23 | 12 |
| Zero-shot Classification | Evaluation Suite Zero-shot (PiQA, LAMBADA, ARC-e, ARC-c, HellaSwag) | Decode Latency | 5.23 | 12 |
| Document Question Answering | Arxiv Sci-papers | Accuracy | 54.26 | 9 |
| Document Question Answering | PubMed Sci-papers | Accuracy | 55.07 | 9 |
| Synthetic Retrieval | NIAH | Score @ 4096 Context | 100 | 9 |
| Zero-shot Evaluation | Evaluation Suite (PiQA, LAMBADA, ARC, HellaSwag) zero-shot | PiQA Accuracy | 61.73 | 9 |