Building Normalizing Flows with Stochastic Interpolants

About

A generative model based on a continuous-time normalizing flow between any pair of base and target probability densities is proposed. The velocity field of this flow is inferred from the probability current of a time-dependent density that interpolates between the base and the target in finite time. Unlike conventional normalizing flow inference methods based on the maximum likelihood principle, which require costly backpropagation through ODE solvers, our interpolant approach leads to a simple quadratic loss for the velocity itself, which is expressed in terms of expectations that are readily amenable to empirical estimation. The flow can be used to generate samples from either the base or target, and to estimate the likelihood at any time along the interpolant. In addition, the flow can be optimized to minimize the path length of the interpolant density, thereby paving the way for building optimal transport maps. In situations where the base is a Gaussian density, we also show that the velocity of our normalizing flow can be used to construct a diffusion model to sample the target as well as estimate its score. However, our approach shows that we can bypass this diffusion completely and work at the level of the probability flow with greater simplicity, opening an avenue for methods based solely on ordinary differential equations as an alternative to those based on stochastic differential equations. Benchmarking on density estimation tasks illustrates that the learned flow can match and surpass conventional continuous flows at a fraction of the cost, and compares well with diffusions on image generation on CIFAR-10 and ImageNet $32\times32$. The method scales ab-initio ODE flows to previously unreachable image resolutions, demonstrated up to $128\times128$.
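The quadratic loss mentioned in the abstract can be illustrated with a small Monte Carlo sketch. The abstract does not spell out the loss, so the form used here, $\mathbb{E}\big[|v_t(x_t)|^2 - 2\,\partial_t x_t \cdot v_t(x_t)\big]$, is an assumption, as are the linear interpolant $x_t = (1-t)x_0 + t x_1$, the one-dimensional Gaussian base and target, and all function names. It only shows that such a loss is estimated from samples alone, with no ODE solve.

```python
import numpy as np

def interpolant(x0, x1, t):
    """Linear interpolant x_t = (1 - t) * x0 + t * x1 (an assumed choice)."""
    return (1.0 - t) * x0 + t * x1

def interpolant_dt(x0, x1, t):
    """Time derivative of the linear interpolant; it does not depend on t."""
    return x1 - x0

def velocity_loss(velocity, x0, x1, t):
    """Monte Carlo estimate of E[|v_t(x_t)|^2 - 2 * d/dt x_t . v_t(x_t)]."""
    xt = interpolant(x0, x1, t)
    v = velocity(xt, t)
    dxt = interpolant_dt(x0, x1, t)
    return np.mean(np.sum(v * v - 2.0 * dxt * v, axis=-1))

# Hypothetical toy problem: base N(0, 1), target N(shift, 1), in one dimension.
rng = np.random.default_rng(0)
n, shift = 20_000, 3.0
x0 = rng.normal(size=(n, 1))          # samples from the base
x1 = shift + rng.normal(size=(n, 1))  # samples from the target
t = rng.uniform(size=(n, 1))          # times drawn uniformly on [0, 1]

# Compare two candidate velocity fields: zero, and a constant drift of `shift`.
loss_zero = velocity_loss(lambda x, t: np.zeros_like(x), x0, x1, t)
loss_shift = velocity_loss(lambda x, t: np.full_like(x, shift), x0, x1, t)
print(loss_zero, loss_shift)
```

In this toy setting the constant drift scores lower than the zero field, since it accounts for the mean displacement between base and target. In practice the candidate velocity would be a neural network trained by minimizing the same sample average.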

Michael S. Albergo, Eric Vanden-Eijnden • 2022

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
| --- | --- | --- | --- | --- |
| Density Estimation | CIFAR-10 (test) | Bits/dim | 2.99 | 134 |
| Density Estimation | ImageNet 32x32 (test) | Bits per Sub-pixel | 3.48 | 66 |
| Likelihood Estimation | CIFAR-10 (test) | NLL (BPD) | 2.99 | 24 |
| Image Generation | ImageNet-32 | FID | 8.49 | 20 |
| Likelihood Estimation | ImageNet32 downsampled (test) | BPD | 3.48 | 11 |
| Image-to-Image Translation | Handbags to Shoes (test) | FID | 15.87 | 9 |
| Image-to-Image Translation | CelebA Male to Female (test) | FID | 16.39 | 9 |