
Tensor Product Attention Is All You Need

About

Scaling language models to handle longer input sequences typically necessitates large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, substantially shrinking the KV cache size at inference time. By factorizing these representations into contextual low-rank components and seamlessly integrating with Rotary Position Embedding (RoPE), TPA achieves improved model quality alongside memory efficiency. Based on TPA, we introduce the Tensor ProducT ATTenTion Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation on language modeling tasks, we demonstrate that T6 surpasses or matches the performance of standard Transformer baselines including Multi-Head Attention (MHA), Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-Head Latent Attention (MLA) across various metrics, including perplexity and a range of established evaluation benchmarks. Notably, TPA's memory and computational efficiency at the decoding stage enable processing longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern language models. Project Page: https://github.com/tensorgi/TPA.
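The core idea of the abstract can be sketched as follows: instead of caching full per-token key (and value) tensors, each token's keys are expressed as a sum of rank-1 tensor products of a head-axis factor and a dimension-axis factor, and only those small factors are cached. The sketch below is a minimal illustration of that factorization under assumed shapes and names (`Wa`, `Wb`, `factorized_keys` are illustrative, not the paper's actual implementation, and RoPE integration is omitted):

```python
import numpy as np

def factorized_keys(x, Wa, Wb, num_heads, head_dim, rank):
    """Tensor-product key factorization sketch.

    x: (seq, d_model) token activations.
    Returns keys of shape (seq, num_heads, head_dim), each built as an
    average of `rank` rank-1 (num_heads x head_dim) tensor products,
    plus the contextual factors (a, b) that a TPA-style KV cache would store.
    """
    seq = x.shape[0]
    a = (x @ Wa).reshape(seq, rank, num_heads)  # head-axis factors
    b = (x @ Wb).reshape(seq, rank, head_dim)   # dimension-axis factors
    # Sum of rank-1 outer products over the rank axis, averaged by rank.
    K = np.einsum('srh,srd->shd', a, b) / rank
    return K, a, b

rng = np.random.default_rng(0)
d_model, num_heads, head_dim, rank, seq = 64, 8, 16, 2, 5
Wa = rng.standard_normal((d_model, rank * num_heads))
Wb = rng.standard_normal((d_model, rank * head_dim))
x = rng.standard_normal((seq, d_model))

K, a, b = factorized_keys(x, Wa, Wb, num_heads, head_dim, rank)
# Per token, the cache holds rank*(num_heads + head_dim) = 48 numbers
# instead of the full num_heads*head_dim = 128, hence the smaller KV cache.
```

With these toy dimensions the cached factors are 48 floats per token versus 128 for the full keys; the same factorization applies to values, and the savings grow with head count and head dimension.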

Yifan Zhang, Yifeng Liu, Huizhuo Yuan, Zhen Qin, Yang Yuan, Quanquan Gu, Andrew Chi-Chih Yao• 2025

Related benchmarks

Task                  | Dataset        | Metric                 | Result | Rank
----------------------|----------------|------------------------|--------|-----
Commonsense Reasoning | HellaSwag      | --                     | --     | 1891
Language Modeling     | C4 (val)       | PPL                    | 16.622 | 514
Commonsense Reasoning | WinoGrande     | Accuracy               | 57.85  | 372
Commonsense Reasoning | BoolQ          | Accuracy               | 60.03  | 212
Commonsense Reasoning | ARC-C          | --                     | --     | 172
Language Modeling     | FineWeb (val)  | --                     | --     | 159
Commonsense Reasoning | ARC-E          | Accuracy               | 69.44  | 106
Commonsense Reasoning | PIQA           | Accuracy               | 74.54  | 71
Commonsense Reasoning | OpenBookQA     | Accuracy               | 41.6   | 71
Language Modeling     | The Pile (val) | Perplexity (bits/byte) | 13.333 | 31

Showing 10 of 15 rows
