Parallelizing Linear Transformers with the Delta Rule over Sequence Length

About

Transformers with linear attention (i.e., linear transformers) and state-space models have recently been suggested as a viable linear-time alternative to transformers with softmax attention. However, these models still underperform transformers, especially on tasks that require in-context retrieval. While more expressive variants of linear transformers, which replace the additive update in linear transformers with the delta rule (DeltaNet), have been found to be more effective at associative recall, existing algorithms for training such models do not parallelize over sequence length and are thus inefficient to train on modern hardware. This work describes a hardware-efficient algorithm for training linear transformers with the delta rule, which exploits a memory-efficient representation for computing products of Householder matrices. This algorithm allows us to scale up DeltaNet to standard language modeling settings. We train a 1.3B model for 100B tokens and find that it outperforms recent linear-time baselines such as Mamba and GLA in terms of perplexity and zero-shot performance on downstream tasks. We also experiment with two hybrid models which combine DeltaNet layers with (1) sliding-window attention layers every other layer or (2) two global attention layers, and find that these hybrids outperform strong transformer baselines.
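To make the delta-rule update concrete, here is a minimal sequential NumPy sketch of the recurrence the abstract describes (this is a reference implementation, not the paper's chunkwise-parallel algorithm; the function names, shapes, and unit-dimension choices are illustrative assumptions). Additive linear attention accumulates S_t = S_{t-1} + v_t k_t^T, whereas the delta rule first reads out the value currently stored under key k_t and overwrites it with an error-driven correction:

```python
import numpy as np

def delta_rule_step(S, k, v, beta):
    """One DeltaNet state update (illustrative shapes).

    S:    (d, d) associative memory matrix
    k, v: (d,) key and value vectors
    beta: scalar write strength in [0, 1]

    Equivalent to S @ (I - beta * k k^T) + beta * v k^T, i.e. the state is
    multiplied by a (generalized) Householder matrix -- the structure the
    paper's memory-efficient representation exploits.
    """
    v_old = S @ k                             # value currently associated with k
    return S + beta * np.outer(v - v_old, k)  # delta rule: correct the stored value

def deltanet_recurrent(q, k, v, beta):
    """Sequential reference pass. q, k, v: (T, d); beta: (T,). Returns (T, d)."""
    T, d = q.shape
    S = np.zeros((d, d))
    out = np.empty_like(q)
    for t in range(T):
        S = delta_rule_step(S, k[t], v[t], beta[t])
        out[t] = S @ q[t]                     # read the memory with the query
    return out
```

With beta = 1 and orthonormal keys, querying with a previously written key returns its value exactly, which is the associative-recall behavior the abstract attributes to DeltaNet. The sequential loop above is exactly what the paper's algorithm avoids by parallelizing over sequence length.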

Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, Yoon Kim • 2024

Related benchmarks

Task                  | Dataset    | Metric   | Result | Rank
Commonsense Reasoning | HellaSwag  | Accuracy | 44.5   | 1891
Commonsense Reasoning | WinoGrande | Accuracy | 54.7   | 1085
Commonsense Reasoning | PIQA       | Accuracy | 70.95  | 751
Language Modeling     | WikiText   | PPL      | 18.38  | 732
Question Answering    | ARC Easy   | --       | --     | 597
Commonsense Reasoning | HellaSwag  | Accuracy | 51.09  | 350
Question Answering    | BoolQ      | --       | --     | 317
Question Answering    | SciQ       | Accuracy | 82.6   | 283
Language Modeling     | LAMBADA    | Accuracy | 41.8   | 268
Commonsense Reasoning | BoolQ      | Accuracy | 61.19  | 212

Showing 10 of 52 rows
