PredRNN++: Towards A Resolution of the Deep-in-Time Dilemma in Spatiotemporal Predictive Learning
About
We present PredRNN++, an improved recurrent network for video predictive learning. In pursuit of greater spatiotemporal modeling capability, our approach increases the transition depth between adjacent states by leveraging a novel recurrent unit, named Causal LSTM, which reorganizes the spatial and temporal memories in a cascaded mechanism. However, a dilemma remains in video predictive learning: increasingly deep-in-time models have been designed to capture complex variations, but they also make gradient back-propagation more difficult. To alleviate this undesirable effect, we propose a Gradient Highway architecture, which provides shorter alternative routes for gradient flows from outputs back to long-range inputs. This architecture works seamlessly with Causal LSTMs, enabling PredRNN++ to capture short-term and long-term dependencies adaptively. We evaluate our model on both synthetic and real video datasets, showing that it eases the vanishing gradient problem and yields state-of-the-art prediction results, even in a difficult object-occlusion scenario.
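To make the idea of a "shorter route for gradients" concrete, here is a minimal sketch of a single gradient highway unit (GHU) step. It assumes a gated form in which a candidate state P_t = tanh(W_px X_t + W_pz Z_{t-1}) and a switch gate S_t = sigmoid(W_sx X_t + W_sz Z_{t-1}) are mixed as Z_t = S_t * P_t + (1 - S_t) * Z_{t-1}. Plain matrix-vector products over Python lists stand in for any convolutional gates, and the names `ghu_step`, `ghu_unroll`, and `w_px`/`w_pz`/`w_sx`/`w_sz` are hypothetical, chosen for illustration rather than taken from the authors' code.

```python
import math


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def matvec(w, x):
    """Dense matrix-vector product; w is a list of rows, x is a list."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def ghu_step(x_t, z_prev, w_px, w_pz, w_sx, w_sz):
    """One hypothetical gradient highway unit update.

    Dense products stand in for convolutions (an assumption of this sketch).
    """
    # Candidate state P_t = tanh(W_px x_t + W_pz z_{t-1})
    p = [math.tanh(a + b) for a, b in zip(matvec(w_px, x_t), matvec(w_pz, z_prev))]
    # Switch gate S_t = sigmoid(W_sx x_t + W_sz z_{t-1}), elementwise in (0, 1)
    s = [sigmoid(a + b) for a, b in zip(matvec(w_sx, x_t), matvec(w_sz, z_prev))]
    # Z_t = S_t * P_t + (1 - S_t) * z_{t-1}: S_t near 0 carries z_{t-1} forward
    return [si * pi + (1.0 - si) * zi for si, pi, zi in zip(s, p, z_prev)]


def ghu_unroll(xs, z0, weights):
    """Apply ghu_step over a sequence of inputs, returning every Z_t."""
    zs, z = [], z0
    for x_t in xs:
        z = ghu_step(x_t, z, *weights)
        zs.append(z)
    return zs


if __name__ == "__main__":
    zero = [[0.0, 0.0], [0.0, 0.0]]
    # With all-zero weights, P_t = 0 and S_t = 0.5, so Z_t halves each step.
    print(ghu_unroll([[1.0, 2.0]] * 3, [4.0, -2.0], (zero, zero, zero, zero)))
    # prints [[2.0, -1.0], [1.0, -0.5], [0.5, -0.25]]
```

The design point is the `(1 - S_t) * Z_{t-1}` term: when the switch gate saturates near 0, the unit copies its previous state almost unchanged, so the Jacobian of Z_t with respect to Z_{t-1} stays close to the identity and gradients can travel across many time steps without passing through the deeper recurrent stack.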
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Video Prediction | KTH 10 → 20 steps (test) | PSNR | 28.62 | 88 |
| Video Prediction | Moving MNIST (test) | MSE | 46.5 | 82 |
| Video Prediction | KTH 10 → 40 steps (test) | PSNR | 26.94 | 77 |
| Video Prediction | Moving MNIST | SSIM | 0.898 | 52 |
| Human Motion Prediction | Human3.6M | -- | -- | 46 |
| Video Prediction | Moving-MNIST 10 → 10 (test) | MSE | 22.45 | 39 |
| Video Prediction | KTH | PSNR | 28.13 | 35 |
| Video Prediction | UCF Sports t+1 (test) | PSNR | 27.26 | 32 |
| Spatio-temporal forecasting | TaxiBJ | MSE | 0.3348 | 30 |
| Traffic Forecasting | TaxiBJ (test) | MAE | 16.9 | 29 |
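For readers unfamiliar with the error metrics in the table, the sketch below shows the textbook definitions of MSE, MAE, and PSNR (in dB) over flattened pixel lists; SSIM is omitted because it involves windowed local statistics. This is an illustrative assumption, not the evaluation code behind these results: benchmarks differ in pixel range, in whether MSE is a per-pixel mean or a per-frame sum, and in how scores are averaged over predicted frames, and the table does not state which convention each entry uses. The function names are hypothetical.

```python
import math


def mse(pred, target):
    """Mean squared error over two equal-length flattened pixel lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred, target):
    """Mean absolute error over two equal-length flattened pixel lists."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB; max_val is the largest possible pixel value."""
    err = mse(pred, target)
    if err == 0:
        return math.inf  # identical frames have unbounded PSNR
    return 10.0 * math.log10(max_val ** 2 / err)


if __name__ == "__main__":
    # Toy frames with pixels scaled to [0, 1] (an assumption of this example).
    target = [0.0, 0.5, 1.0, 0.25]
    pred = [0.1, 0.5, 0.9, 0.25]
    print(mse(pred, target), mae(pred, target), psnr(pred, target))
```

Because PSNR is a logarithm of MSE, a lower MSE and a higher PSNR express the same improvement; the reported numbers are only comparable between methods that share the same preprocessing and averaging convention.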