GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
About
Multi-query attention (MQA), which uses only a single key-value head, drastically speeds up decoder inference. However, MQA can lead to quality degradation, and moreover it may not be desirable to train a separate model just for faster inference. We (1) propose a recipe for uptraining existing multi-head language model checkpoints into models with MQA using 5% of the original pre-training compute, and (2) introduce grouped-query attention (GQA), a generalization of multi-query attention which uses an intermediate number of key-value heads (more than one, but fewer than the number of query heads). We show that uptrained GQA achieves quality close to multi-head attention with speed comparable to MQA.
Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai • 2023
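The two ideas in the abstract can be illustrated with a minimal NumPy sketch. It assumes a single unbatched decoder layer with no output projection, and every function name is hypothetical. `mean_pool_kv_heads` converts a multi-head checkpoint by averaging the key and value projections within each group of heads, which is how the paper initializes the converted model before the short uptraining phase (not shown here). `grouped_query_attention` then lets each group of query heads share one key-value head, and `kv_cache_elements` shows why fewer key-value heads shrink the decoding cache.

```python
# Illustrative sketch only: names and shapes are assumptions, not the paper's code.
import numpy as np


def mean_pool_kv_heads(w, num_kv_heads):
    """Collapse per-head K or V projections into num_kv_heads grouped projections.

    w has shape (num_heads, d_model, head_dim). Consecutive heads form a group,
    and the group's shared projection is the mean of its members.
    num_kv_heads == 1 gives MQA; num_kv_heads == num_heads leaves MHA unchanged.
    """
    num_heads = w.shape[0]
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads
    return w.reshape(num_kv_heads, group_size, *w.shape[1:]).mean(axis=1)


def grouped_query_attention(x, wq, wk, wv):
    """Causal self-attention where each group of query heads shares one K/V head.

    x: (seq_len, d_model); wq: (num_heads, d_model, head_dim);
    wk, wv: (num_kv_heads, d_model, head_dim). Returns (num_heads, seq_len, head_dim).
    """
    num_heads, num_kv_heads = wq.shape[0], wk.shape[0]
    group_size = num_heads // num_kv_heads
    q = np.einsum("sm,hmd->hsd", x, wq)  # one query projection per head
    k = np.einsum("sm,gmd->gsd", x, wk)  # only num_kv_heads keys are computed and cached
    v = np.einsum("sm,gmd->gsd", x, wv)
    # Query head h uses key/value head h // group_size.
    k = np.repeat(k, group_size, axis=0)
    v = np.repeat(v, group_size, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    seq_len = x.shape[0]
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)  # decoder: no attending to later tokens
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def kv_cache_elements(seq_len, num_kv_heads, head_dim, num_layers=1):
    """Number of cached key plus value entries during decoding."""
    return 2 * num_layers * seq_len * num_kv_heads * head_dim
```

For a hypothetical configuration with 64 query heads, grouping them onto 8 key-value heads makes `kv_cache_elements` 8 times smaller than keeping 64 key-value heads, while MQA (1 head) shrinks it 64 times. The sketch copies the shared heads with `np.repeat` for readability; an efficient kernel would more likely index the shared head directly instead of materializing copies.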
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | HellaSwag | Accuracy | 39.4 | 1891 |
| Question Answering | ARC Challenge | -- | -- | 906 |
| Commonsense Reasoning | PIQA | Accuracy | 67.5 | 751 |
| Question Answering | ARC Easy | Accuracy | 42.2 | 597 |
| Physical Commonsense Reasoning | PIQA | Accuracy | 73.5 | 572 |
| Language Modeling | WikiText2 v1 (test) | Perplexity | 7.21 | 383 |
| Commonsense Reasoning | HellaSwag | Accuracy | 58.6 | 350 |
| Question Answering | SciQ | -- | -- | 283 |
| Language Modeling | LAMBADA | Accuracy | 37.4 | 268 |
| Language Modeling | WikiText-103 (val) | Perplexity | 21.87 | 214 |