Training-Inference Consistent Segmented Execution for Long-Context LLMs
About
Transformer-based large language models face severe scalability challenges in long-context generation because the computational and memory costs of full-context attention grow with context length. To stay within practical compute and memory budgets, many inference-efficient long-context methods adopt bounded-context or segment-level execution only at inference time, while still training the model with full-context attention. This creates a mismatch between training and inference in both execution and state-transition semantics. Motivated by this observation, we propose a training-inference consistent segment-level generation framework, in which training and inference follow the same segment-level forward execution semantics. During training, consistency with inference is enforced by restricting gradient propagation to the KV states carried over from the immediately preceding segment. In the forward pass, heads may additionally access older KV states in a head-specific manner, but those states do not participate in gradient propagation. Across long-context benchmarks, our approach achieves performance comparable to full-context attention, offers competitive latency-memory trade-offs against strong inference-efficient baselines, and substantially improves scalability at very long context lengths (e.g., approximately 6x lower peak prefill memory at 128K tokens than full-context attention with FlashAttention).
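The gradient rule described above can be illustrated with a small access-mask sketch. This is an illustrative reading of the abstract, not the authors' implementation. It assumes a fixed segment length and causal attention, and that every head sees the current and immediately preceding segment with gradients. It also assumes a hypothetical per-head flag `head_sees_history` that decides whether a head may additionally read all older KV states in forward-only mode; the abstract does not state how heads select past states. The names `Access`, `segment_access`, and `head_sees_history` are hypothetical.

```python
from enum import IntEnum


class Access(IntEnum):
    """How a query position may use a key/value position (hypothetical labels)."""
    NONE = 0          # key is not visible to the query
    FORWARD_ONLY = 1  # key is read in the forward pass, but gradients are stopped
    GRAD = 2          # key is read and gradients flow through it


def segment_access(seq_len, seg_len, head_sees_history):
    """Return access[h][q][k] for head h, query position q, key position k.

    Assumed rule (a sketch, not the paper's exact method):
      * causal attention: keys after the query are never visible;
      * keys in the query's own segment or the immediately preceding
        segment are visible to every head and carry gradients;
      * keys in older segments are visible only to heads whose
        head_sees_history flag is True, and only as forward-only (detached).
    """
    access = [
        [[Access.NONE] * seq_len for _ in range(seq_len)]
        for _ in head_sees_history
    ]
    for h, sees_history in enumerate(head_sees_history):
        for q in range(seq_len):
            q_seg = q // seg_len
            for k in range(q + 1):  # causal: keys at or before the query
                k_seg = k // seg_len
                if k_seg >= q_seg - 1:
                    # Current segment or the carried-over previous segment.
                    access[h][q][k] = Access.GRAD
                elif sees_history:
                    # Older segment: readable by this head, excluded from backprop.
                    access[h][q][k] = Access.FORWARD_ONLY
    return access


if __name__ == "__main__":
    acc = segment_access(seq_len=6, seg_len=2, head_sees_history=[False, True])
    # Query at position 5 (segment 2) for a head that may read history:
    print([a.name for a in acc[1][5]])
    # ['FORWARD_ONLY', 'FORWARD_ONLY', 'GRAD', 'GRAD', 'GRAD', 'GRAD']
```

In an autodiff framework, keys and values marked `FORWARD_ONLY` would be passed through a stop-gradient (for example `Tensor.detach()` in PyTorch or `jax.lax.stop_gradient` in JAX) before attention. Under these assumptions, the number of gradient-carrying keys per query is bounded by two segment lengths regardless of the total context length.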
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Long-context Language Understanding | LongBench v2 | Overall Accuracy | 29.8 | 62 |
| Long-context Language Understanding | RULER 32k context length | FWE | 38.17 | 39 |
| Long-context Language Understanding | RULER 64k context length | FWE (Error) | 34.17 | 22 |
| Long-context Language Understanding | RULER 16k context length | FWE Score | 44.83 | 21 |
| Long-context Language Understanding | RULER 4k context length | FWE Rate | 53.83 | 16 |
| Long-context Understanding | RULER 8k context | CWE | 54.15 | 13 |
| Long-context Language Understanding | LongBench-E 2024 (test) | Short Context QA Score | 7.58 | 12 |
| Long-context Information Extraction | RULER 4K-32K Average | CWE Score | 46.39 | 6 |
| Long-context Language Understanding | LongBench (standard) | NQA | 5.89 | 6 |