Meta Flow Matching: Integrating Vector Fields on the Wasserstein Manifold
About
Numerous biological and physical processes can be modeled as systems of interacting entities evolving continuously over time, e.g. the dynamics of communicating cells or physical particles. Learning the dynamics of such systems is essential for predicting the temporal evolution of populations across novel samples and unseen environments. Flow-based models allow for learning these dynamics at the population level: they model the evolution of the entire distribution of samples. However, current flow-based models are limited to a single initial population and a set of predefined conditions which describe different dynamics. We argue that many processes in the natural sciences have to be represented as vector fields on the Wasserstein manifold of probability densities. That is, the change of the population at any moment in time depends on the population itself due to the interactions between samples. In particular, this is crucial for personalized medicine, where the development of diseases and their respective treatment responses depend on the microenvironment of cells specific to each patient. We propose Meta Flow Matching (MFM), a practical approach to integrate along these vector fields on the Wasserstein manifold by amortizing the flow model over the initial populations. Namely, we embed the population of samples using a Graph Neural Network (GNN) and use these embeddings to train a Flow Matching model. This gives MFM the ability to generalize over the initial distributions, unlike previously proposed methods. We demonstrate the ability of MFM to improve the prediction of individual treatment responses on a large-scale multi-patient single-cell drug screen dataset.
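The idea of conditioning a Flow Matching model on an embedding of the initial population can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the mean/std pooling in `embed_population` is a simple permutation-invariant stand-in for the GNN, the straight-line interpolation path and independent source/target pairing are assumptions (the abstract does not specify them), and all function names are hypothetical.

```python
import numpy as np


def embed_population(x0):
    """Permutation-invariant population embedding.

    Assumption: mean/std pooling stands in for the GNN described in the abstract.
    """
    return np.concatenate([x0.mean(axis=0), x0.std(axis=0)])


def flow_matching_batch(x0, x1, rng):
    """Build one regression batch for a (source, target) population pair.

    Assumption: straight-line paths x_t = (1 - t) * x0 + t * x1 with
    independently paired source and target samples.
    """
    n = min(len(x0), len(x1))
    src = x0[rng.permutation(len(x0))[:n]]
    tgt = x1[rng.permutation(len(x1))[:n]]
    t = rng.uniform(size=(n, 1))
    x_t = (1.0 - t) * src + t * tgt          # point on the interpolation path
    u_t = tgt - src                          # target velocity along that path
    c = np.tile(embed_population(x0), (n, 1))  # same embedding for every sample
    inputs = np.concatenate([x_t, t, c], axis=1)
    return inputs, u_t


def fm_loss(predict, inputs, targets):
    """Mean squared error between predicted and target velocities."""
    return float(np.mean(np.sum((predict(inputs) - targets) ** 2, axis=1)))


rng = np.random.default_rng(0)
x0 = rng.normal(size=(64, 2))            # hypothetical initial population
x1 = rng.normal(loc=3.0, size=(80, 2))   # hypothetical treated population
inputs, targets = flow_matching_batch(x0, x1, rng)
```

Because the embedding is computed from the initial population and fed to the vector field as an input, a single model trained on batches from many population pairs can be queried with the embedding of a population it has not seen, which is the amortization the abstract refers to. Any regressor mapping `inputs` to velocities can be plugged in as `predict`.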
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Trajectory Generation | Mouse Organogenesis Spatiotemporal Atlas (MOSTA) Stereo-seq Interpolation (time 5) | Weighted W2 1.839 | 12 |
| Trajectory Extrapolation | Axolotl Brain Regeneration Extrapolation (last holdout time point) | Weighted W2 6.868 | 12 |
| Trajectory Interpolation | Axolotl Brain Regeneration (middle holdout time point) | Weighted W2 5.809 | 12 |
| Causal Perturbation Prediction | SCM dataset linear SCMs (test) | W2 13.72 | 10 |
| Measure-to-measure regression | Multi-measure objects Diffusion corruption (unseen measures) | W1 Score 0.2306 | 9 |
| Measure-to-measure regression | Multi-measure objects Kernel interactions corruption (unseen measures) | W1 0.2292 | 9 |
| McKean-Vlasov system prediction | Kuramoto 100D (test) | W1 7.303 | 8 |
| McKean-Vlasov system prediction | FitzHugh-Nagumo 100D (test) | W1 Distance 18.683 | 8 |
| McKean-Vlasov system prediction | Atlas 100D (test) | W1 Distance 24.891 | 8 |
| Causal Perturbation Prediction | Real single-cell melanoma data (test) | W2 21.12 | 7 |