
Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers

About

Recurrent neural networks (RNNs), temporal convolutions, and neural differential equations (NDEs) are popular families of deep learning models for time-series data, each with unique strengths and tradeoffs in modeling power and computational efficiency. We introduce a simple sequence model inspired by control systems that generalizes these approaches while addressing their shortcomings. The Linear State-Space Layer (LSSL) maps a sequence $u \mapsto y$ by simply simulating a linear continuous-time state-space representation $\dot{x} = Ax + Bu, y = Cx + Du$. Theoretically, we show that LSSL models are closely related to the three aforementioned families of models and inherit their strengths. For example, they generalize convolutions to continuous-time, explain common RNN heuristics, and share features of NDEs such as time-scale adaptation. We then incorporate and generalize recent theory on continuous-time memorization to introduce a trainable subset of structured matrices $A$ that endow LSSLs with long-range memory. Empirically, stacking LSSL layers into a simple deep neural network obtains state-of-the-art results across time series benchmarks for long dependencies in sequential image classification, real-world healthcare regression tasks, and speech. On a difficult speech classification task with length-16000 sequences, LSSL outperforms prior approaches by 24 accuracy points, and even outperforms baselines that use hand-crafted features on 100x shorter sequences.
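To make the recurrent and convolutional views concrete, here is a minimal NumPy sketch of simulating $\dot{x} = Ax + Bu, y = Cx + Du$ on a sampled input. It assumes a bilinear (Tustin) discretization with step size `dt`, which is one common choice and is not specified in the abstract above. The function names, the single-input single-output shapes, and the zero initial state are all illustrative assumptions, not the authors' implementation.

```python
import numpy as np


def discretize_bilinear(A, B, dt):
    """Bilinear (Tustin) discretization of x' = Ax + Bu with step size dt.

    Assumed choice of discretization; returns (A_bar, B_bar) such that
    x_k = A_bar @ x_{k-1} + B_bar * u_k.
    """
    n = A.shape[0]
    I = np.eye(n)
    inv = np.linalg.inv(I - (dt / 2.0) * A)
    A_bar = inv @ (I + (dt / 2.0) * A)
    B_bar = inv @ (dt * B)
    return A_bar, B_bar


def ssm_recurrent(A_bar, B_bar, C, D, u):
    """Recurrent view: step the hidden state through the sequence."""
    x = np.zeros(A_bar.shape[0])  # zero initial state (assumption)
    ys = []
    for u_k in u:
        x = A_bar @ x + B_bar[:, 0] * u_k
        ys.append((C @ x).item() + D * u_k)
    return np.array(ys)


def ssm_convolutional(A_bar, B_bar, C, D, u):
    """Convolutional view: y = K * u + D u with K_i = C A_bar^i B_bar."""
    L = len(u)
    kernel = np.empty(L)
    v = B_bar[:, 0].copy()
    for i in range(L):
        kernel[i] = (C @ v).item()
        v = A_bar @ v  # advance to A_bar^(i+1) B_bar
    # Causal convolution: keep only the first L outputs.
    return np.convolve(u, kernel)[:L] + D * u


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, L = 4, 64
    A = -np.eye(n) + 0.1 * rng.standard_normal((n, n))  # toy, roughly stable A
    B = rng.standard_normal((n, 1))
    C = rng.standard_normal((1, n))
    D = 0.5
    u = rng.standard_normal(L)

    A_bar, B_bar = discretize_bilinear(A, B, dt=0.1)
    y_rec = ssm_recurrent(A_bar, B_bar, C, D, u)
    y_conv = ssm_convolutional(A_bar, B_bar, C, D, u)
    print(np.allclose(y_rec, y_conv))
```

Both functions compute the same output because unrolling the recurrence from a zero state gives $y_k = \sum_{j \le k} C \bar{A}^{k-j} \bar{B} u_j + D u_k$, which is a causal convolution of $u$ with the kernel $K_i = C \bar{A}^i \bar{B}$. The toy matrix `A` here is random; the structured, trainable matrices $A$ described in the abstract are not reproduced.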

Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré • 2021

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Time Series Forecasting | ETTm2 | MSE 0.243 | 382 |
| Time Series Forecasting | Weather | MSE 0.174 | 223 |
| Anomaly Detection | SMD | F1 Score 71.31 | 217 |
| Time Series Forecasting | ETTm1 (test) | MSE 0.45 | 196 |
| Time Series Forecasting | Exchange | MSE 0.395 | 176 |
| Anomaly Detection | SWaT | F1 Score 85.76 | 174 |
| Time Series Forecasting | Electricity | MSE 0.297 | 161 |
| Time Series Forecasting | Traffic | MSE 0.798 | 145 |
| Time Series Forecasting | ETTm2 (test) | MSE 0.243 | 89 |
| Pixel-by-pixel Image Classification | Permuted Sequential MNIST (pMNIST) (test) | Accuracy 98.76 | 79 |
Showing 10 of 29 rows
