
Attention as an RNN

About

The advent of Transformers marked a significant breakthrough in sequence modelling, providing a highly performant architecture capable of leveraging GPU parallelism. However, Transformers are computationally expensive at inference time, limiting their applications, particularly in low-resource settings (e.g., mobile and embedded devices). Addressing this, we (1) begin by showing that attention can be viewed as a special Recurrent Neural Network (RNN) with the ability to compute its many-to-one RNN output efficiently. We then (2) show that popular attention-based models such as Transformers can be viewed as RNN variants. However, unlike traditional RNNs (e.g., LSTMs), these models cannot be updated efficiently with new tokens, an important property in sequence modelling. Tackling this, we (3) introduce a new efficient method of computing attention's many-to-many RNN output based on the parallel prefix scan algorithm. Building on the new attention formulation, we (4) introduce Aaren, an attention-based module that can not only (i) be trained in parallel (like Transformers) but also (ii) be updated efficiently with new tokens, requiring only constant memory for inference (like traditional RNNs). Empirically, we show that Aarens achieve performance comparable to Transformers on 38 datasets spread across four popular sequential problem settings (reinforcement learning, event forecasting, time series classification, and time series forecasting), while being more time- and memory-efficient.
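The "attention as a many-to-one RNN" view in point (1) can be illustrated with a short sketch: softmax attention for a single query is a recurrence over tokens that carries only a running maximum score, a scaled numerator, and a scaled denominator, giving constant-memory inference. This is a minimal NumPy illustration of that idea, not the authors' implementation; all function and variable names here are our own.

```python
import numpy as np

def attention_direct(q, keys, values):
    """Standard softmax attention for a single query (computed all at once)."""
    s = keys @ q                          # scores for each token
    w = np.exp(s - s.max())               # numerically stable softmax weights
    return (w[:, None] * values).sum(axis=0) / w.sum()

def attention_as_rnn(q, keys, values):
    """Same attention output, computed token-by-token with constant memory.

    RNN state: m (running max score), c (scaled numerator), d (scaled denominator).
    Each new token only updates this fixed-size state.
    """
    m = -np.inf
    c = np.zeros(values.shape[1])
    d = 0.0
    for k, v in zip(keys, values):
        s = k @ q
        m_new = max(m, s)
        scale_old = np.exp(m - m_new) if np.isfinite(m) else 0.0
        scale_new = np.exp(s - m_new)
        c = c * scale_old + scale_new * v  # rescale old numerator, add new token
        d = d * scale_old + scale_new      # rescale old denominator likewise
        m = m_new
    return c / d

rng = np.random.default_rng(0)
q = rng.normal(size=8)
keys = rng.normal(size=(16, 8))
values = rng.normal(size=(16, 4))
print(np.allclose(attention_direct(q, keys, values),
                  attention_as_rnn(q, keys, values)))  # True
```

Because the per-token update is associative, the same recurrence can also be evaluated in parallel with a prefix scan, which is the basis of the many-to-many formulation in point (3).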

Leo Feng, Frederick Tung, Hossein Hajimirsadeghi, Mohamed Osama Ahmed, Yoshua Bengio, Greg Mori • 2024

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Time Series Forecasting | ETTh1 | MSE 0.59 | 601 |
| Time Series Forecasting | ETTh2 | MSE 0.49 | 438 |
| Time Series Forecasting | ETTm2 | MSE 0.34 | 382 |
| Time Series Forecasting | ETTm1 | MSE 0.51 | 334 |
| Time Series Forecasting | ETTh1 (test) | MSE 0.53 | 262 |
| Time Series Forecasting | Weather | MSE 0.25 | 223 |
| Time Series Forecasting | ETTm1 (test) | MSE 0.48 | 196 |
| Time Series Forecasting | Traffic (test) | MSE 0.63 | 192 |
| Time Series Forecasting | ECL | MSE 0.37 | 183 |
| Time Series Forecasting | Exchange | MSE 0.25 | 176 |

Showing 10 of 45 rows.
