
Trajectory Flow Matching with Applications to Clinical Time Series Modeling

About

Modeling stochastic and irregularly sampled time series is a challenging problem that arises in a wide range of applications, especially in medicine. Neural stochastic differential equations (Neural SDEs), which parameterize the drift and diffusion terms of an SDE with neural networks, are an attractive modeling technique for this problem. However, current algorithms for training Neural SDEs require backpropagation through the SDE dynamics, greatly limiting their scalability and stability. To address this, we propose Trajectory Flow Matching (TFM), which trains a Neural SDE in a simulation-free manner, bypassing backpropagation through the dynamics. TFM leverages the flow matching technique from generative modeling to model time series. In this work, we first establish necessary conditions for TFM to learn time series data. Next, we present a reparameterization trick that improves training stability. Finally, we adapt TFM to the clinical time series setting, demonstrating improved performance on three clinical time series datasets, both in absolute performance and in uncertainty prediction.
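The idea of simulation-free training can be illustrated with a small sketch. It assumes a scalar series, a constant-variance Gaussian path between each pair of consecutive observations (as in conditional flow matching), and a linear velocity model v(x, t) = a*x + b*t + c fitted by least squares in place of a neural network trained by gradient descent. The function names are hypothetical, and the sketch leaves out history conditioning, the learned diffusion term, and the reparameterization trick, so it is not the authors' implementation.

```python
import random


def sample_tfm_pairs(times, values, n_samples, sigma=0.1, seed=0):
    """Draw (x, t, u) regression triples from one irregularly sampled series.

    For a random pair of consecutive observations, pick s ~ U(0, 1), place a
    point on the straight line between them plus constant-variance Gaussian
    noise, and use the segment slope as the target velocity. No solver is run.
    """
    rng = random.Random(seed)
    triples = []
    for _ in range(n_samples):
        i = rng.randrange(len(times) - 1)
        t0, t1 = times[i], times[i + 1]
        x0, x1 = values[i], values[i + 1]
        s = rng.random()
        t = t0 + s * (t1 - t0)  # interpolated observation time
        x = (1.0 - s) * x0 + s * x1 + sigma * rng.gauss(0.0, 1.0)
        u = (x1 - x0) / (t1 - t0)  # target velocity in real time units
        triples.append((x, t, u))
    return triples


def fit_linear_velocity(triples):
    """Least-squares fit of the hypothetical model v(x, t) = a*x + b*t + c.

    Requires sigma > 0 in sampling; otherwise x and t are collinear on linear
    data and the 3x3 normal equations below are singular.
    """
    # Accumulate normal equations (F^T F) w = F^T u with features f = [x, t, 1].
    ata = [[0.0] * 3 for _ in range(3)]
    atb = [0.0] * 3
    for x, t, u in triples:
        f = (x, t, 1.0)
        for r in range(3):
            atb[r] += f[r] * u
            for c in range(3):
                ata[r][c] += f[r] * f[c]
    # Gaussian elimination with partial pivoting on the augmented 3x4 system.
    m = [row[:] + [rhs] for row, rhs in zip(ata, atb)]
    for col in range(3):
        pivot = max(range(col, 3), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, 3):
            factor = m[r][col] / m[col][col]
            for c in range(col, 4):
                m[r][c] -= factor * m[col][c]
    # Back substitution.
    w = [0.0] * 3
    for r in range(2, -1, -1):
        w[r] = (m[r][3] - sum(m[r][c] * w[c] for c in range(r + 1, 3))) / m[r][r]
    return w  # [a, b, c]


# Hypothetical toy series x(t) = 3 + 2t observed at irregular times.
times = [0.0, 0.13, 0.4, 0.45, 0.9, 1.0]
values = [3.0 + 2.0 * t for t in times]
a, b, c = fit_linear_velocity(sample_tfm_pairs(times, values, n_samples=500))
# The learned drift at an arbitrary (x, t) should match the true slope 2.0.
print(a * 3.5 + b * 0.25 + c)
```

No ODE or SDE solver is called while fitting: every regression target comes directly from a pair of observations, which is what makes this kind of objective simulation-free and avoids backpropagating through integrated dynamics.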

Xi Zhang, Yuan Pu, Yuki Kawamura, Andrew Loza, Yoshua Bengio, Dennis L. Shung, Alexander Tong • 2024

Related benchmarks

Task for all rows: probabilistic time series forecasting. Metric: average NCRPS.

| Dataset (split) | Average NCRPS | Rank |
|---|---|---|
| Weather Regular (test) | 0.912 | 11 |
| ETTm1 Irregular (test) | 0.578 | 11 |
| ETTm2 Irregular (test) | 2.821 | 11 |
| Weather Irregular (test) | 0.749 | 11 |
| ETTm1 Regular (test) | 0.604 | 11 |
| ETTm2 Regular (test) | 2.297 | 11 |
| Electricity (test) | 0.302 | 10 |
| Traffic Regular (test) | 0.556 | 10 |
| Electricity Irregular (test) | 0.278 | 10 |
| Traffic Irregular (test) | 0.476 | 10 |

Showing 10 of 17 rows.
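The benchmark results are reported as average NCRPS, which the page does not define. The sketch below assumes the normalized CRPS convention common in probabilistic forecasting work: the empirical CRPS of a sample-based forecast, E|X - y| - 0.5 * E|X - X'|, summed over forecast steps and divided by the sum of |y|. Lower values are better. How the benchmark averages across runs or variables is not stated here, and the function names are illustrative.

```python
def crps_from_samples(samples, y):
    """Empirical CRPS of an ensemble forecast: E|X - y| - 0.5 * E|X - X'|.

    The pairwise term is O(n^2) in the number of samples, fine for small ensembles.
    """
    n = len(samples)
    spread_to_target = sum(abs(x - y) for x in samples) / n
    spread_within = sum(abs(a - b) for a in samples for b in samples) / (n * n)
    return spread_to_target - 0.5 * spread_within


def normalized_crps(sample_paths, targets):
    """Assumed NCRPS: sum of per-step CRPS divided by sum of |y_t|.

    sample_paths[t] holds the forecast samples for step t; targets[t] is y_t.
    """
    total = sum(crps_from_samples(s, y) for s, y in zip(sample_paths, targets))
    scale = sum(abs(y) for y in targets)
    return total / scale


# Two forecast steps, each with a small ensemble of sampled values.
sample_paths = [[0.9, 1.1, 1.0], [2.2, 1.8, 2.1]]
targets = [1.0, 2.0]
print(normalized_crps(sample_paths, targets))
```

With a single sample per step the pairwise term vanishes and CRPS reduces to the absolute error, so NCRPS then matches a normalized mean absolute error.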
