Tensor Product Attention Is All You Need
About
Scaling language models to handle longer input sequences typically necessitates large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, substantially shrinking the KV cache size at inference time. By factorizing these representations into contextual low-rank components and seamlessly integrating with Rotary Position Embedding (RoPE), TPA achieves improved model quality alongside memory efficiency. Based on TPA, we introduce the Tensor ProducT ATTenTion Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation on language modeling tasks, we demonstrate that T6 surpasses or matches the performance of standard Transformer baselines including Multi-Head Attention (MHA), Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-Head Latent Attention (MLA) across various metrics, including perplexity and a range of established evaluation benchmarks. Notably, TPA's memory and computational efficiency at the decoding stage enable processing longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern language models. Project Page: https://github.com/tensorgi/TPA.
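The core idea, as the abstract describes it, is to store low-rank factors of each token's key/value heads rather than the full per-head matrices. The sketch below is an illustrative reconstruction under assumed shapes (head count `h`, head dimension `d_h`, rank `R` are placeholders, not the paper's settings), showing how a rank-R tensor product shrinks the per-token cache from `h * d_h` floats to `R * (h + d_h)`:

```python
import numpy as np

# Illustrative sketch of TPA's factored KV cache (keys only; values are analogous).
# Each token's per-head key matrix K_t in R^{h x d_h} is represented as
#   K_t = (1/R) * sum_r outer(a_r(x_t), b_r(x_t)),
# where a_r in R^h mixes heads and b_r in R^{d_h} lives in the head dimension.
# Only the factors a and b are cached, not the full K_t.

rng = np.random.default_rng(0)
h, d_h, R, T = 8, 64, 2, 5   # heads, head dim, factor rank, sequence length (assumed)

# Contextual factors; in the real model these come from learned projections of x_t.
a = rng.standard_normal((T, R, h))      # head-mixing factors
b = rng.standard_normal((T, R, d_h))    # head-dimension factors

# Reconstruct the full per-head keys from the cached factors.
K = np.einsum('trh,trd->thd', a, b) / R  # shape (T, h, d_h)

cache_full = h * d_h        # floats per token in a standard per-head key cache
cache_tpa = R * (h + d_h)   # floats per token in the factored cache
print(cache_full, cache_tpa)  # 512 vs 144 under these illustrative sizes
```

The memory saving grows with `h * d_h`, which is why the factored cache admits longer sequences under a fixed memory budget.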
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Commonsense Reasoning | HellaSwag | -- | 1891 |
| Language Modeling | C4 (val) | PPL 16.622 | 514 |
| Commonsense Reasoning | WinoGrande | Accuracy 57.85 | 372 |
| Commonsense Reasoning | BoolQ | Accuracy 60.03 | 212 |
| Commonsense Reasoning | ARC-C | -- | 172 |
| Language Modeling | FineWeb (val) | -- | 159 |
| Commonsense Reasoning | ARC-E | Accuracy 69.44 | 106 |
| Commonsense Reasoning | PIQA | Accuracy 74.54 | 71 |
| Commonsense Reasoning | OpenBookQA | Accuracy 41.6 | 71 |
| Language Modeling | The Pile (val) | Perplexity (bits/byte) 13.333 | 31 |