
SeerAttention: Learning Intrinsic Sparse Attention in Your LLMs

About

Attention is the cornerstone of modern Large Language Models (LLMs). Yet its quadratic complexity hinders efficiency and scalability, especially for long-context processing. A promising approach is to leverage sparsity in attention. However, existing sparsity-based solutions predominantly rely on predefined patterns or heuristics at the attention head level, struggling to adapt dynamically to different contexts efficiently. We propose SeerAttention, a simple yet effective attention mechanism that directly learns the block-level attention sparsity from the LLM itself. Inspired by the gating mechanism in Mixture of Experts (MoE), SeerAttention augments the conventional attention with a learnable gate that selectively activates important blocks within the attention map. Specifically, the gate first pools the query (Q) and key (K) tensors along the sequence dimension and processes them through learnable linear layers. The resulting matrices are then multiplied together to produce the gating scores, which are used to predict block-level attention sparsity. Combined with our block-sparse FlashAttention kernel, SeerAttention can achieve significant speedup on GPUs. When applied to pre-trained LLMs, SeerAttention only requires training the gate parameters in a lightweight self-distillation manner, allowing rapid convergence. Our evaluation results demonstrate that SeerAttention achieves better model accuracy and lower latency for long-context pre-filling compared to prior methods. Code is available at: https://github.com/microsoft/SeerAttention
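The gating mechanism described in the abstract can be sketched in NumPy. This is a minimal reference sketch under stated assumptions, not the SeerAttention implementation from the linked repository. It assumes single-head causal attention, uses plain average pooling for both Q and K (the abstract says only that Q and K are pooled along the sequence dimension), and selects a fixed top-k budget of key blocks per query block. It also computes attention densely with a mask instead of using a block-sparse FlashAttention kernel. The self-distillation target (a 2D max-pool of the full attention map, row-normalized and compared to the gate with KL divergence) is likewise an assumption made for illustration. All function names and the gate weights `w_q`/`w_k` are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)  # numerically stable
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def pool_blocks(x, block):
    # (seq_len, d) -> (num_blocks, d): average each run of `block` tokens (assumed pooling)
    n, d = x.shape
    assert n % block == 0, "sketch assumes seq_len is a multiple of block"
    return x.reshape(n // block, block, d).mean(axis=1)

def attn_gate(q, k, w_q, w_k, block):
    # Pool Q and K along the sequence, project with learnable weights, multiply.
    qp = pool_blocks(q, block) @ w_q                     # (nb, d_gate)
    kp = pool_blocks(k, block) @ w_k                     # (nb, d_gate)
    scores = qp @ kp.T / np.sqrt(w_q.shape[1])           # (nb, nb) block scores
    nb = scores.shape[0]
    causal = np.tril(np.ones((nb, nb), dtype=bool))      # prefill is causal
    return softmax(np.where(causal, scores, -np.inf))    # each row sums to 1

def topk_block_mask(gate, k_blocks):
    # Keep the k_blocks highest-scoring key blocks per query block, plus the diagonal.
    nb = gate.shape[0]
    mask = np.zeros((nb, nb), dtype=bool)
    for i in range(nb):
        keep = np.argsort(-gate[i])[: min(k_blocks, i + 1)]  # only causal blocks have mass
        mask[i, keep] = True
        mask[i, i] = True
    return mask

def block_sparse_attention(q, k, v, block_mask, block):
    # Dense reference: expand the block mask to tokens and drop masked blocks via -inf.
    # A block-sparse kernel would skip loading and computing those blocks entirely.
    n, d = q.shape
    token_mask = np.repeat(np.repeat(block_mask, block, axis=0), block, axis=1)
    token_mask &= np.tril(np.ones((n, n), dtype=bool))
    s = np.where(token_mask, q @ k.T / np.sqrt(d), -np.inf)
    return softmax(s) @ v

def distill_target(q, k, block):
    # Assumed self-distillation target: 2D max-pool of the full causal attention map,
    # renormalized so each query-block row is a distribution over key blocks.
    n, d = q.shape
    s = np.where(np.tril(np.ones((n, n), dtype=bool)), q @ k.T / np.sqrt(d), -np.inf)
    p = softmax(s)
    nb = n // block
    pooled = p.reshape(nb, block, nb, block).max(axis=(1, 3))
    return pooled / pooled.sum(axis=-1, keepdims=True)

def kl_loss(target, gate, eps=1e-9):
    # KL(target || gate), averaged over query blocks.
    return float(np.sum(target * (np.log(target + eps) - np.log(gate + eps))) / target.shape[0])

rng = np.random.default_rng(0)
n, d, d_gate, block = 64, 16, 8, 8
nb = n // block
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
w_q = rng.standard_normal((d, d_gate))  # hypothetical gate weights (random, untrained)
w_k = rng.standard_normal((d, d_gate))

gate = attn_gate(q, k, w_q, w_k, block)
mask = topk_block_mask(gate, k_blocks=3)
out = block_sparse_attention(q, k, v, mask, block)
loss = kl_loss(distill_target(q, k, block), gate)
print(out.shape, int(mask.sum()), "of", nb * (nb + 1) // 2, "causal blocks kept")
```

Setting every entry of the block mask to True recovers ordinary dense causal attention, which makes a convenient sanity check. In an actual training loop, `kl_loss` would be minimized with respect to the gate weights only (here `w_q` and `w_k`), using an autodiff framework rather than NumPy, while the LLM's own weights stay frozen.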

Yizhao Gao, Zhichen Zeng, Dayou Du, Shijie Cao, Peiyuan Zhou, Jiaxing Qi, Junjie Lai, Hayden Kwok-Hay So, Ting Cao, Fan Yang, Mao Yang (2024)

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Long-context language modeling | LongBench | Average Score | 30.8 | 328 |
| Long-context Language Understanding | LongBench | M-Avg | 41.57 | 294 |
| Long-context evaluation | RULER | Average Accuracy Score | 84 | 54 |
| Latency Measurement | LLaMA-8B-Instruct Chunked Prefill 3.1 (inference) | Attention Latency (ms) | 1.00e+3 | 49 |
| Long-context language modeling evaluation | RULER | Score (4K) | 86.9 | 49 |
| Long-context Understanding | RULER 32k | Accuracy | 90.23 | 38 |
| Long-context Understanding | RULER 64k | Accuracy | 83.82 | 37 |
| Long-context Understanding | RULER 128k | Accuracy | 73.34 | 27 |
| Long-context Understanding | RULER | -- | -- | 27 |
| Long-context language modeling | RULER (test) | Sparsity | 77 | 13 |

Showing 10 of 13 rows.
