
Multi-Head Low-Rank Attention

About

Long-context inference in large language models is bottlenecked by Key-Value (KV) cache loading during the decoding stage, where the sequential nature of generation requires repeatedly transferring the KV cache from off-chip High-Bandwidth Memory (HBM) to on-chip Static Random-Access Memory (SRAM) at each step. While Multi-Head Latent Attention (MLA) significantly reduces the total KV cache size, it suffers from a sharding bottleneck during distributed decoding via Tensor Parallelism (TP). Since its single latent head cannot be partitioned, each device is forced to redundantly load the complete KV cache for every token, consuming excessive memory traffic and diminishing TP benefits like weight sharding. In this work, we propose Multi-Head Low-Rank Attention (MLRA), which enables partitionable latent states for efficient 4-way TP decoding. Extensive experiments show that MLRA achieves state-of-the-art perplexity and downstream task performance, while also delivering a 2.8× decoding speedup over MLA. Code is available at https://github.com/SongtaoLiu0823/MLRA. Pretrained weights, along with the training and evaluation data, are available at https://huggingface.co/Soughing/MLRA.
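The sharding argument in the abstract can be made concrete with a little arithmetic. The sketch below is an illustration only, not the paper's implementation: all dimensions (sequence length, latent width, head count) are assumed values chosen for the example. It compares per-device KV-cache traffic per decode step when the latent cache must be replicated on every rank (a single latent head, as in MLA) versus when latent heads can be partitioned across TP ranks (as MLRA enables).

```python
# Illustrative back-of-envelope comparison of per-device KV-cache
# traffic during decoding. All dimensions are assumptions for the
# example, not values taken from the paper.

def kv_bytes_per_step(seq_len, latent_dim, n_latent_heads,
                      tp_degree, bytes_per_elem=2, partitionable=True):
    """Bytes of cached latent state each device loads per decode step.

    If the latent heads cannot be partitioned (MLA's single latent
    head), every TP rank must load the full cache; if they can be
    sharded (MLRA), each rank loads only its own slice.
    """
    total = seq_len * latent_dim * n_latent_heads * bytes_per_elem
    return total // tp_degree if partitionable else total

seq_len, latent_dim, tp = 32_768, 512, 4

# MLA-style: one latent head, replicated on all 4 ranks.
mla = kv_bytes_per_step(seq_len, latent_dim, 1, tp, partitionable=False)

# MLRA-style: four narrower latent heads, one per rank (hypothetical
# sizing chosen so the total latent cache matches MLA's for fairness).
mlra = kv_bytes_per_step(seq_len, latent_dim // 4, 4, tp, partitionable=True)

print(mla // mlra)  # → 4: each rank moves 4x less cache under 4-way TP
```

The same total cache is stored in both cases; the difference is purely in how much of it each device must stream from HBM per token, which is the bottleneck the abstract identifies.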

Songtao Liu, Hongwu Peng, Zhiwei Zhang, Zhengyu Chen, Yue Guo • 2026

Related benchmarks

| Task                  | Dataset        | Result                        | Rank |
|-----------------------|----------------|-------------------------------|------|
| Commonsense Reasoning | HellaSwag      | --                            | 1891 |
| Language Modeling     | C4 (val)       | PPL 16.286                    | 514  |
| Commonsense Reasoning | WinoGrande     | Accuracy 61.48                | 372  |
| Commonsense Reasoning | BoolQ          | Accuracy 61.74                | 212  |
| Commonsense Reasoning | ARC-C          | --                            | 172  |
| Language Modeling     | FineWeb (val)  | --                            | 159  |
| Commonsense Reasoning | ARC-E          | Accuracy 67.89                | 106  |
| Commonsense Reasoning | PIQA           | Accuracy 75.52                | 71   |
| Commonsense Reasoning | OpenBookQA     | Accuracy 43                   | 71   |
| Language Modeling     | The Pile (val) | Perplexity (bits/byte) 13.124 | 31   |

Showing 10 of 15 rows

Other info

GitHub: https://github.com/SongtaoLiu0823/MLRA