
Gated Linear Attention Transformers with Hardware-Efficient Training

About

Transformers with linear attention allow for efficient parallel training but can simultaneously be formulated as an RNN with 2D (matrix-valued) hidden states, thus enjoying linear-time inference complexity. However, linear attention generally underperforms ordinary softmax attention. Moreover, current implementations of linear attention lack I/O-awareness and are thus slower than highly optimized implementations of softmax attention. This work describes a hardware-efficient algorithm for linear attention that trades off memory movement against parallelizability. The resulting implementation, dubbed FLASHLINEARATTENTION, is faster than FLASHATTENTION-2 (Dao, 2023) as a standalone layer even at short sequence lengths (e.g., 1K). We then generalize this algorithm to a more expressive variant of linear attention with data-dependent gates. When used as a replacement for the standard attention layer in Transformers, the resulting gated linear attention (GLA) Transformer is found to perform competitively against the LLaMA-architecture Transformer (Touvron et al., 2023) as well as recent linear-time-inference baselines such as RetNet (Sun et al., 2023a) and Mamba (Gu & Dao, 2023) on moderate-scale language modeling experiments. The GLA Transformer is especially effective at length generalization, enabling a model trained on 2K-length sequences to generalize to sequences longer than 20K without significant perplexity degradation. For training speed, the GLA Transformer has higher throughput than a similarly-sized Mamba model.
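The parallel/recurrent duality described above can be illustrated with a minimal NumPy sketch. This is not the paper's hardware-efficient FLASHLINEARATTENTION kernel; it is a naive reference showing that (ungated) linear attention computed as a causal-masked (QKᵀ)V product equals an RNN whose 2D state is updated by rank-1 outer products, with an optional data-dependent gate on the state as in GLA. Function and argument names here are illustrative, not from the paper's code.

```python
import numpy as np

def linear_attention_recurrent(q, k, v, gates=None):
    """Recurrent (RNN) form of linear attention, one token at a time.

    q, k: (T, d_k); v: (T, d_v); gates: optional (T, d_k) values in (0, 1)
    giving a data-dependent decay per key dimension (GLA-style).
    The hidden state S is a 2D (d_k x d_v) matrix.
    """
    T, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))
    out = np.zeros((T, d_v))
    for t in range(T):
        if gates is not None:
            # gated update: elementwise decay of the matrix-valued state
            S = gates[t][:, None] * S
        S = S + np.outer(k[t], v[t])  # rank-1 state update
        out[t] = q[t] @ S             # constant-cost readout per token
    return out

def linear_attention_parallel(q, k, v):
    """Equivalent parallel form (no gates): causal-masked (QK^T)V."""
    T = q.shape[0]
    scores = q @ k.T
    mask = np.tril(np.ones((T, T)))   # causal mask, no softmax
    return (scores * mask) @ v
```

Without gates the two forms agree exactly, since o_t = q_t Σ_{s≤t} k_sᵀ v_s in both; the gated recurrence is strictly more expressive and is what the paper's chunkwise algorithm parallelizes efficiently.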

Songlin Yang, Bailin Wang, Yikang Shen, Rameswar Panda, Yoon Kim • 2023

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Commonsense Reasoning | HellaSwag | Accuracy 34.5 | 1460 |
| Multi-task Language Understanding | MMLU | Accuracy 22.9 | 842 |
| Commonsense Reasoning | WinoGrande | Accuracy 51.4 | 776 |
| Commonsense Reasoning | PIQA | Accuracy 64.8 | 647 |
| Time Series Forecasting | ETTh1 | MSE 0.418 | 601 |
| Language Modeling | WikiText | PPL 41.47 | 479 |
| Time Series Forecasting | ETTh2 | MSE 0.342 | 438 |
| Question Answering | ARC Easy | Accuracy 45.1 | 386 |
| Time Series Forecasting | ETTm2 | MSE 0.25 | 382 |
| Question Answering | BoolQ | -- | 240 |

(10 of 29 rows shown)
