Scalable Gradients for Stochastic Differential Equations
About
The adjoint sensitivity method scalably computes gradients of solutions to ordinary differential equations. We generalize this method to stochastic differential equations, allowing time-efficient and constant-memory computation of gradients with high-order adaptive solvers. Specifically, we derive a stochastic differential equation whose solution is the gradient, a memory-efficient algorithm for caching noise, and conditions under which numerical solutions converge. In addition, we combine our method with gradient-based stochastic variational inference for latent stochastic differential equations. We use our method to fit stochastic dynamics defined by neural networks, achieving competitive performance on a 50-dimensional motion capture dataset.
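To make the constant-memory idea concrete, here is a minimal sketch in Python. It is not the paper's algorithm: it backpropagates through a discretized Euler-Maruyama scheme rather than solving the continuous backward SDE the abstract describes. It also assumes a toy scalar model, geometric Brownian motion dX = mu X dt + sigma X dW, chosen because each Euler step can be inverted exactly. The function names (`brownian_increment`, `forward`, `backward_grad_mu`) are hypothetical, and the per-step reseeding is only a simple stand-in for the paper's noise-caching algorithm. The shared point is that the backward pass stores no trajectory and no noise: states are reconstructed in reverse and Brownian increments are recomputed from a seed.

```python
import numpy as np

def brownian_increment(seed, n, dt):
    """Recompute the n-th Brownian increment from (seed, n) instead of storing it."""
    rng = np.random.default_rng([seed, n])
    return np.sqrt(dt) * rng.standard_normal()

def forward(x0, mu, sigma, seed, n_steps, dt):
    """Euler-Maruyama for dX = mu*X dt + sigma*X dW; only the final state is kept."""
    x = x0
    for n in range(n_steps):
        dw = brownian_increment(seed, n, dt)
        x = x * (1.0 + mu * dt + sigma * dw)
    return x

def backward_grad_mu(x_final, mu, sigma, seed, n_steps, dt):
    """Gradient of the final state w.r.t. mu, computed backward in O(1) memory."""
    x = x_final   # state, reconstructed step by step in reverse
    a = 1.0       # adjoint: d(final state) / d(current state)
    grad_mu = 0.0
    for n in reversed(range(n_steps)):
        dw = brownian_increment(seed, n, dt)   # same noise as the forward pass
        factor = 1.0 + mu * dt + sigma * dw
        x = x / factor              # invert the Euler step to recover X_n
        grad_mu += a * x * dt       # contribution of step n to d X_N / d mu
        a *= factor                 # propagate the adjoint one step back
    return grad_mu

# Illustrative values only; none of these come from the paper.
x0, mu, sigma, seed, n_steps, dt = 1.0, 0.5, 0.2, 0, 100, 0.01
x_T = forward(x0, mu, sigma, seed, n_steps, dt)
grad = backward_grad_mu(x_T, mu, sigma, seed, n_steps, dt)

# Sanity check against a central finite difference on the same noise path.
eps = 1e-6
fd = (forward(x0, mu + eps, sigma, seed, n_steps, dt)
      - forward(x0, mu - eps, sigma, seed, n_steps, dt)) / (2 * eps)
print(grad, fd)
```

The exact inversion of each step is a special property of this toy model. For general drift and diffusion functions, such as neural networks, the state must instead be recovered by solving a backward equation numerically, which is the setting the abstract addresses.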
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| RUL prediction | N-CMAPSS | RMSE 21.13 | 72 |
| Remaining Useful Life Estimation | C-MAPSS FD002 (test) | RMSE 20.28 | 44 |
| Remaining Useful Life prediction | C-MAPSS FD004 (test) | RMSE 21.55 | 24 |
| Remaining Useful Life prediction | C-MAPSS FD001 (test) | RMSE 20.57 | 24 |
| Remaining Useful Life prediction | C-MAPSS FD003 (test) | RMSE 21.13 | 24 |
| Trajectory Inference | EB dataset 5D (test) | W1 (t=1) 0.91 | 23 |
| Time Series Forecasting | NDBC Wave-Height | MAE 0.3526 | 18 |
| Time Series Forecasting | XAU/USD | MAE 0.0063 | 18 |
| Forecasting | Synthetic partially observed jump-diffusion process (test) | MAE 0.1118 | 11 |
| Continuous sequence prediction | COVID-19 SIR dynamics in Japan (standard) | MSE 1.086 | 8 |