Learning When to Attend: Conditional Memory Access for Long-Context LLMs

About

Language models struggle to generalize beyond pretraining context lengths, limiting long-horizon reasoning and retrieval. Continued pretraining on long-context data can help but is expensive due to the quadratic scaling of Attention. We observe that most tokens do not require (Global) Attention over the entire sequence and can rely on local context. Based on this, we propose L2A (Learning To Attend), a layer that enables conditional (token-wise) long-range memory access by deciding when to invoke global attention. We evaluate L2A on Qwen 2.5 and Qwen 3 models, extending their effective context length from 32K to 128K tokens. L2A matches the performance of standard long-context training to within 3% while skipping Global Attention for $\sim$80% of tokens, outperforming prior baselines. We also design custom Triton kernels to efficiently implement this token-wise conditional Attention on GPUs, achieving up to $\sim$2x improvements in training throughput and time-to-first-token over FlashAttention. Moreover, L2A enables post-training pruning of highly sparse Global Attention layers, reducing KV cache memory by up to 50% with negligible performance loss.
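To make the idea concrete, below is a minimal, illustrative sketch of token-wise conditional attention gating: a learned per-token gate routes each query either to cheap local (sliding-window) attention or to full global causal attention. All names here (TokenGatedAttention, window, threshold) are hypothetical; the actual L2A layer, its training procedure, and its fused Triton kernels are not reproduced here, and a real implementation would skip the global computation entirely for gated-off tokens rather than compute both branches.

```python
# Illustrative sketch only -- NOT the paper's L2A implementation.
# Assumptions: PyTorch >= 2.0, a hard gate threshold, and a naive version that
# computes both branches and selects per token (a fused kernel would instead
# skip global attention for the ~80% of tokens the gate turns off).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TokenGatedAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, window: int = 512, threshold: float = 0.5):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)
        self.gate = nn.Linear(dim, 1)   # scores how much each token needs global context
        self.window, self.threshold = window, threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        shape = (B, T, self.n_heads, self.head_dim)
        q, k, v = (t.view(shape).transpose(1, 2) for t in (q, k, v))  # (B, H, T, d)

        # Local branch: causal attention restricted to a sliding window.
        i = torch.arange(T, device=x.device)
        local_mask = (i[None, :] <= i[:, None]) & (i[:, None] - i[None, :] < self.window)
        local_out = F.scaled_dot_product_attention(q, k, v, attn_mask=local_mask)

        # Global branch: full causal attention over the entire sequence.
        global_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Token-wise gate: queries whose gate score clears the threshold take the
        # global result; the rest keep the local result. (The hard threshold is
        # non-differentiable; the trained gating mechanism in the paper differs.)
        use_global = torch.sigmoid(self.gate(x)) > self.threshold   # (B, T, 1)
        use_global = use_global.transpose(1, 2).unsqueeze(-1)       # (B, 1, T, 1)
        out = torch.where(use_global, global_out, local_out)
        return self.out(out.transpose(1, 2).reshape(B, T, D))
```

Usage is the same as a standard attention layer, e.g. `TokenGatedAttention(dim=1024, n_heads=16)(x)` on a `(batch, seq_len, dim)` tensor; the efficiency gain in practice comes from only materializing global attention for the gated-on tokens.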

Sakshi Choudhary, Aditya Chattopadhyay, Luca Zancato, Elvis Nunez, Matthew Trager, Wei Xia, Stefano Soatto • 2026

Related benchmarks

Task | Dataset | Metric | Result | Rank
Multi-round co-reference resolution | Long Context Benchmarks | Score (8k Context) | 36.3 | 21
Synthetic recall | Long Context Benchmarks | Synthetic Recall (8k context) | 99.8 | 21
Many-shot in-context learning | Long Context Benchmarks | ICL Performance (8k Context) | 71.4 | 21
Passage re-ranking | Long Context Benchmarks | Performance (8k Context) | 49 | 21
Fact chaining & relational reasoning | Long Context Benchmarks | Accuracy (8k Context) | 47.2 | 21
Retrieval-Augmented Generation | Long Context Benchmarks | RAG Score (8k Context) | 49.1 | 16
Generation w/ citations | Generation w/ citations | Citation Quality (8k Context) | 30.4 | 13
Average across tasks | Average across tasks | Performance @ 8k Context Length | 55.7 | 13
Retrieval-Augmented Generation | Retrieval-Augmented Generation | Performance at 8k Context Length | 63.2 | 13
Average across tasks | Long Context Benchmarks | Performance (8k Context) | 45.9 | 8
