
Attention as an RNN

About

The advent of Transformers marked a significant breakthrough in sequence modelling, providing a highly performant architecture capable of leveraging GPU parallelism. However, Transformers are computationally expensive at inference time, limiting their applications, particularly in low-resource settings (e.g., mobile and embedded devices). Addressing this, we (1) begin by showing that attention can be viewed as a special Recurrent Neural Network (RNN) with the ability to compute its many-to-one RNN output efficiently. We then (2) show that popular attention-based models such as Transformers can be viewed as RNN variants. However, unlike traditional RNNs (e.g., LSTMs), these models cannot be updated efficiently with new tokens, an important property in sequence modelling. Tackling this, we (3) introduce a new efficient method of computing attention's many-to-many RNN output based on the parallel prefix scan algorithm. Building on this new attention formulation, we (4) introduce Aaren, an attention-based module that can not only (i) be trained in parallel (like Transformers) but also (ii) be updated efficiently with new tokens, requiring only constant memory for inference (like traditional RNNs). Empirically, we show that Aarens achieve performance comparable to Transformers on 38 datasets spread across four popular sequential problem settings (reinforcement learning, event forecasting, time series classification, and time series forecasting) while being more time- and memory-efficient.
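To make point (1) concrete, below is a minimal sketch (not the authors' code) of how single-query softmax attention can be evaluated token by token as a many-to-one RNN. The recurrent state is just a numerically stable running softmax numerator and denominator, so memory stays constant in the sequence length; for simplicity the scores are unscaled dot products (no 1/sqrt(d) factor), and all names are illustrative.

```python
import numpy as np

def attention_as_rnn(q, ks, vs):
    """Compute softmax attention for one query q over (ks, vs),
    processing one key/value pair at a time with O(1) memory."""
    m = -np.inf                   # running max of attention scores (for stability)
    num = np.zeros(vs.shape[1])   # running weighted sum of values (numerator)
    den = 0.0                     # running softmax normaliser (denominator)
    for k, v in zip(ks, vs):
        s = float(q @ k)          # unscaled dot-product score
        m_new = max(m, s)
        scale = np.exp(m - m_new)         # rescale old state to the new max
        num = num * scale + np.exp(s - m_new) * v
        den = den * scale + np.exp(s - m_new)
        m = m_new
    return num / den              # equals softmax(q . K^T) @ V
```

Each recurrent step also accepts a new token without recomputing the past, which is the efficient-update property the abstract contrasts with standard Transformer inference.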

Leo Feng, Frederick Tung, Hossein Hajimirsadeghi, Mohamed Osama Ahmed, Yoshua Bengio, Greg Mori • 2024

Related benchmarks

Task                     Dataset          Result      Rank
Time Series Forecasting  ETTh1            MSE 0.59    729
Time Series Forecasting  ETTh2            MSE 0.49    561
Time Series Forecasting  ETTm2            MSE 0.34    382
Time Series Forecasting  ETTh1 (test)     MSE 0.53    348
Time Series Forecasting  ETTm1            MSE 0.51    334
Time Series Forecasting  Weather          MSE 0.25    295
Time Series Forecasting  ETTm1 (test)     MSE 0.48    278
Time Series Forecasting  Traffic (test)   MSE 0.63    251
Time Series Forecasting  ETTh2 (test)     MSE 0.38    232
Time Series Forecasting  ECL              MSE 0.37    211

Showing 10 of 45 rows
