CoST: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting
About
Deep learning has been actively studied for time series forecasting, and the mainstream paradigm is based on the end-to-end training of neural network architectures, ranging from classical LSTM/RNNs to more recent TCNs and Transformers. Motivated by the recent success of representation learning in computer vision and natural language processing, we argue that a more promising paradigm for time series forecasting is to first learn disentangled feature representations and then apply a simple regression fine-tuning step; we justify this paradigm from a causal perspective. Following this principle, we propose CoST, a new representation learning framework for time series forecasting that applies contrastive learning methods to learn disentangled seasonal-trend representations. CoST comprises both time domain and frequency domain contrastive losses to learn discriminative trend and seasonal representations, respectively. Extensive experiments on real-world datasets show that CoST consistently outperforms the state-of-the-art methods by a considerable margin, achieving a 21.3% improvement in MSE on multivariate benchmarks. It is also robust to various choices of backbone encoders, as well as downstream regressors. Code is available at https://github.com/salesforce/CoST.
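The abstract names contrastive losses but does not spell them out. As a rough illustration of what such a loss computes, the following minimal sketch implements a generic InfoNCE-style loss in plain Python. It is an assumption for illustration only, not CoST's exact time domain or frequency domain loss; the names `cosine` and `info_nce` and the `temperature` default are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(queries, keys, temperature=0.1):
    """Mean InfoNCE loss over a batch (hypothetical helper).

    keys[i] is treated as the positive for queries[i]; every other key
    in the batch acts as a negative.
    """
    total = 0.0
    for i, q in enumerate(queries):
        logits = [cosine(q, k) / temperature for k in keys]
        # Numerically stable log-sum-exp over all keys.
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
        # Negative log-probability of picking the positive key.
        total += log_denominator - logits[i]
    return total / len(queries)


# Toy representations of two series under two views (made-up values).
view_a = [[1.0, 0.0], [0.0, 1.0]]
view_b = [[1.0, 0.0], [0.0, 1.0]]
aligned_loss = info_nce(view_a, view_b)
mismatched_loss = info_nce(view_a, list(reversed(view_b)))
```

The loss is small when each representation is closest to its own counterpart and large when it is closer to another sample's, which is the property a contrastive objective rewards.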
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Multivariate long-term forecasting | ETTh1 | MSE 0.71 | 344 |
| Time Series Forecasting | ETTm1 | MSE 0.059 | 334 |
| Multivariate long-term series forecasting | ETTh2 | MSE 1.664 | 319 |
| Multivariate long-term series forecasting | Weather | MSE 1.111 | 288 |
| Multivariate long-term series forecasting | ETTm1 | MSE 0.477 | 257 |
| Multivariate long-term forecasting | Electricity | MSE 0.228 | 183 |
| Time Series Forecasting | Exchange | MSE 0.054 | 176 |
| Multivariate long-term series forecasting | ETTm2 | MSE 0.825 | 175 |
| Multivariate long-term forecasting | Traffic | MSE 0.76 | 159 |
| Multivariate long-term forecasting | ETTm1 (test) | MSE 0.253 | 134 |
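The Result column reports mean squared error (MSE), and the abstract quotes a percentage improvement in MSE. The sketch below shows how both quantities are commonly computed. It assumes the improvement is the relative reduction (baseline minus model, divided by baseline), which the page itself does not state; the function names and all numbers are hypothetical toy values, not results from the paper.

```python
def mse(y_true, y_pred):
    """Mean squared error between two equal-length, non-empty sequences."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("sequences must be non-empty and of equal length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def relative_improvement(baseline_mse, model_mse):
    """Fractional MSE reduction relative to a baseline (assumed definition)."""
    return (baseline_mse - model_mse) / baseline_mse


# Toy forecast: only the last point is off, by 2.0.
toy_mse = mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])

# Toy comparison: a model at MSE 0.40 against a baseline at MSE 0.50.
toy_gain = relative_improvement(0.50, 0.40)
```

Lower MSE is better, so a positive relative improvement means the model's error is smaller than the baseline's.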