Gated Linear Attention Transformers with Hardware-Efficient Training
About
Transformers with linear attention allow for efficient parallel training but can simultaneously be formulated as an RNN with 2D (matrix-valued) hidden states, thus enjoying linear-time inference complexity. However, linear attention generally underperforms ordinary softmax attention. Moreover, current implementations of linear attention lack I/O-awareness and are thus slower than highly optimized implementations of softmax attention. This work describes a hardware-efficient algorithm for linear attention that trades off memory movement against parallelizability. The resulting implementation, dubbed FLASHLINEARATTENTION, is faster than FLASHATTENTION-2 (Dao, 2023) as a standalone layer even on short sequence lengths (e.g., 1K). We then generalize this algorithm to a more expressive variant of linear attention with data-dependent gates. When used as a replacement for the standard attention layer in Transformers, the resulting gated linear attention (GLA) Transformer is found to perform competitively against the LLaMA-architecture Transformer (Touvron et al., 2023) as well as recent linear-time-inference baselines such as RetNet (Sun et al., 2023a) and Mamba (Gu & Dao, 2023) on moderate-scale language modeling experiments. The GLA Transformer is especially effective at length generalization, enabling a model trained on 2K-length sequences to generalize to sequences longer than 20K without significant perplexity degradation. In terms of training speed, the GLA Transformer has higher throughput than a similarly-sized Mamba model.
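The recurrent view described above can be illustrated with a short sketch: the hidden state is a matrix updated by a rank-1 outer product at each step, and the data-dependent gate multiplicatively decays that state. This is a minimal NumPy illustration of the recurrence, not the paper's chunkwise hardware-efficient implementation; the function name and shapes are illustrative assumptions. With the gate fixed to all-ones it reduces to vanilla (unnormalized) linear attention.

```python
import numpy as np

def gated_linear_attention(q, k, v, alpha):
    """Recurrent-form sketch of gated linear attention.

    q, k: (T, d_k); v: (T, d_v); alpha: (T, d_k) data-dependent
    forget gates in (0, 1]. The hidden state S is 2D (matrix-valued),
    of shape (d_k, d_v), so each step costs O(d_k * d_v): linear time
    in sequence length T. alpha = 1 recovers plain linear attention.
    """
    T, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))
    out = np.zeros((T, d_v))
    for t in range(T):
        # decay the matrix state with the gate, then add a rank-1 update
        S = alpha[t][:, None] * S + np.outer(k[t], v[t])
        out[t] = q[t] @ S  # read out against the current query
    return out
```

With `alpha` set to all-ones, the output matches the parallel causal form `tril(q @ k.T) @ v`, which is what makes efficient parallel training possible; the actual FLASHLINEARATTENTION algorithm computes this in I/O-aware chunks rather than step by step.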
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | HellaSwag | Accuracy | 34.5 | 1460 |
| Multi-task Language Understanding | MMLU | Accuracy | 22.9 | 842 |
| Commonsense Reasoning | WinoGrande | Accuracy | 51.4 | 776 |
| Commonsense Reasoning | PIQA | Accuracy | 64.8 | 647 |
| Time Series Forecasting | ETTh1 | MSE | 0.418 | 601 |
| Language Modeling | WikiText | PPL | 41.47 | 479 |
| Time Series Forecasting | ETTh2 | MSE | 0.342 | 438 |
| Question Answering | ARC Easy | Accuracy | 45.1 | 386 |
| Time Series Forecasting | ETTm2 | MSE | 0.25 | 382 |
| Question Answering | BoolQ | -- | -- | 240 |