
Gated Linear Attention Transformers with Hardware-Efficient Training

About

Transformers with linear attention allow for efficient parallel training but can simultaneously be formulated as an RNN with 2D (matrix-valued) hidden states, thus enjoying linear-time inference complexity. However, linear attention generally underperforms ordinary softmax attention. Moreover, current implementations of linear attention lack I/O-awareness and are thus slower than highly optimized implementations of softmax attention. This work describes a hardware-efficient algorithm for linear attention that trades off memory movement against parallelizability. The resulting implementation, dubbed FLASHLINEARATTENTION, is faster than FLASHATTENTION-2 (Dao, 2023) as a standalone layer even on short sequence lengths (e.g., 1K). We then generalize this algorithm to a more expressive variant of linear attention with data-dependent gates. When used as a replacement for the standard attention layer in Transformers, the resulting gated linear attention (GLA) Transformer is found to perform competitively against the LLaMA-architecture Transformer (Touvron et al., 2023) as well as recent linear-time-inference baselines such as RetNet (Sun et al., 2023a) and Mamba (Gu & Dao, 2023) on moderate-scale language modeling experiments. The GLA Transformer is especially effective at length generalization, enabling a model trained on 2K-length sequences to generalize to sequences longer than 20K without significant perplexity degradation. In terms of training speed, the GLA Transformer has higher throughput than a similarly-sized Mamba model.

Songlin Yang, Bailin Wang, Yikang Shen, Rameswar Panda, Yoon Kim• 2023
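To illustrate the RNN view described in the abstract, here is a minimal NumPy sketch of the recurrent (inference-mode) form of gated linear attention: a 2D matrix-valued hidden state is decayed by a data-dependent gate at each step and updated with an outer product. This is only an illustrative recurrence under assumed shapes, not the paper's hardware-efficient chunked FLASHLINEARATTENTION algorithm; the function name and the per-dimension gate parameterization `alpha` are placeholders.

```python
import numpy as np

def gated_linear_attention(q, k, v, alpha):
    """Recurrent form of gated linear attention (illustrative sketch).

    q, k:   (T, d_k) query/key vectors per timestep
    v:      (T, d_v) value vectors per timestep
    alpha:  (T, d_k) data-dependent gates in (0, 1); all-ones
            recovers plain (ungated) linear attention
    """
    T, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))  # 2D (matrix-valued) hidden state
    out = np.empty((T, d_v))
    for t in range(T):
        # Decay each row of the state with the gate, add new outer product.
        S = alpha[t][:, None] * S + np.outer(k[t], v[t])
        out[t] = q[t] @ S  # constant-time readout per step => linear-time inference
    return out
```

With the gate fixed to ones, each output equals the query applied to the running sum of key-value outer products, which is the standard linear-attention identity.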

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Commonsense Reasoning | HellaSwag | Accuracy: 34.5 | 1891 |
| Commonsense Reasoning | WinoGrande | Accuracy: 51.4 | 1085 |
| Multi-task Language Understanding | MMLU | Accuracy: 22.9 | 876 |
| Commonsense Reasoning | PIQA | Accuracy: 64.8 | 751 |
| Language Modeling | WikiText | PPL: 41.47 | 732 |
| Time Series Forecasting | ETTh1 | MSE: 0.418 | 729 |
| Question Answering | ARC Easy | Accuracy: 45.1 | 597 |
| Time Series Forecasting | ETTh2 | MSE: 0.342 | 561 |
| Time Series Forecasting | ETTm2 | MSE: 0.25 | 382 |
| Question Answering | BoolQ | -- | 317 |
Showing 10 of 32 rows
