
Reflected Diffusion Models

About

Score-based diffusion models learn to reverse a stochastic differential equation that maps data to noise. However, for complex tasks, numerical error can compound and result in highly unnatural samples. Previous work mitigates this drift with thresholding, which projects to the natural data domain (such as pixel space for images) after each diffusion step, but this leads to a mismatch between the training and generative processes. To incorporate data constraints in a principled manner, we present Reflected Diffusion Models, which instead reverse a reflected stochastic differential equation evolving on the support of the data. Our approach learns the perturbed score function through a generalized score matching loss and extends key components of standard diffusion models including diffusion guidance, likelihood-based training, and ODE sampling. We also bridge the theoretical gap with thresholding: such schemes are just discretizations of reflected SDEs. On standard image benchmarks, our method is competitive with or surpasses the state of the art without architectural modifications and, for classifier-free guidance, our approach enables fast exact sampling with ODEs and produces more faithful samples under high guidance weight.
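
The abstract contrasts thresholding, which projects a sample back into the data domain after each step, with reflection at the domain boundary. The sketch below shows both treatments for a single Euler-Maruyama step on the box [0, 1]^d, the same domain as the hypercube benchmarks on this page. It is illustrative only and not the authors' sampler. The function names (`reflect_into_unit_box`, `threshold_step`, `reflect_step`) are hypothetical, and the drift is a placeholder for what a trained score model would supply.

```python
import numpy as np


def reflect_into_unit_box(x: np.ndarray) -> np.ndarray:
    """Fold points back into [0, 1]^d by mirror reflection at the walls.

    Repeated reflection at 0 and 1 is periodic with period 2, so map each
    coordinate into [0, 2) and then mirror the upper half [1, 2) onto [0, 1].
    """
    y = np.mod(x, 2.0)  # lies in [0, 2) for any real x
    return np.where(y > 1.0, 2.0 - y, y)


def threshold_step(x, drift, g, dt, rng):
    """One Euler-Maruyama step followed by clipping (thresholding)."""
    noise = rng.standard_normal(x.shape)
    x_new = x + drift(x) * dt + g * np.sqrt(dt) * noise
    return np.clip(x_new, 0.0, 1.0)


def reflect_step(x, drift, g, dt, rng):
    """One Euler-Maruyama step followed by reflection at the boundary."""
    noise = rng.standard_normal(x.shape)
    x_new = x + drift(x) * dt + g * np.sqrt(dt) * noise
    return reflect_into_unit_box(x_new)


if __name__ == "__main__":
    zero_drift = lambda z: np.zeros_like(z)  # placeholder drift (hypothetical)
    x0 = np.full((10000, 2), 0.05)           # start close to the wall at 0
    xc = threshold_step(x0, zero_drift, 1.0, 0.1, np.random.default_rng(0))
    xr = reflect_step(x0, zero_drift, 1.0, 0.1, np.random.default_rng(0))
    print("fraction exactly at 0 (clipped):  ", (xc == 0).mean())
    print("fraction exactly at 0 (reflected):", (xr == 0).mean())
```

Both steps keep every sample inside the box, but clipping piles probability mass exactly onto the boundary, while reflection folds it back into the interior. The abstract states that thresholding-style schemes can be read as discretizations of reflected SDEs. This single-step comparison only illustrates how the two boundary treatments differ and does not reproduce that result.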

Aaron Lou, Stefano Ermon • 2023

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Density Estimation | CIFAR-10 (test) | Bits/dim | 2.68 | 134 |
| Image Generation | CIFAR-10 (train/test) | FID | 2.72 | 78 |
| Generative Modeling | CIFAR-10 8-bit color (test) | Bits/dim | 2.68 | 15 |
| Likelihood Estimation | ImageNet32 downsampled (test) | Bits/dim | 3.74 | 11 |
| Likelihood Estimation | CIFAR-10 original (test) | Bits/dim | 2.68 | 10 |
| Density Estimation | ImageNet (test) | Bits/dim | 3.74 | 5 |
| Constrained Generative Modeling | Hypercube [0, 1]^d, d=2 (test) | Sliced Wasserstein distance | 0.0375 | 4 |
| Constrained Generative Modeling | Hypercube [0, 1]^d, d=3 (test) | Sliced Wasserstein distance | 0.0658 | 4 |
| Constrained Generative Modeling | Hypercube [0, 1]^d, d=6 (test) | Sliced Wasserstein distance | 2.77 | 4 |
| Constrained Generative Modeling | Hypercube [0, 1]^d, d=8 (test) | Sliced Wasserstein distance | 0.035 | 4 |

Showing 10 of 11 rows.
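
The hypercube rows report a sliced Wasserstein distance (SWD) between generated and reference samples. The sketch below is a common Monte Carlo estimator. It projects both sample sets onto random unit directions, compares the sorted 1-D projections (the optimal 1-D coupling), and averages over directions. The function name `sliced_wasserstein` and its defaults (500 projections, p = 2, equal sample sizes) are assumptions for illustration. This page does not state the number of projections, the order p, or the sample counts behind the reported numbers, so values from this sketch are not directly comparable to the table.

```python
import numpy as np


def sliced_wasserstein(x, y, n_projections=500, p=2, seed=0):
    """Monte Carlo sliced Wasserstein-p distance between two equal-size samples.

    x, y: arrays of shape (n, d). Along each random unit direction the
    problem is 1-D, where optimal transport matches sorted values.
    """
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal-size samples of the same dimension")
    rng = np.random.default_rng(seed)
    d = x.shape[1]
    theta = rng.standard_normal((n_projections, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # unit directions
    px = np.sort(x @ theta.T, axis=0)  # shape (n, n_projections), sorted per direction
    py = np.sort(y @ theta.T, axis=0)
    # Mean over samples and directions gives the average 1-D W_p^p; take the p-th root.
    return float(np.mean(np.abs(px - py) ** p) ** (1.0 / p))


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    a = rng.uniform(size=(1000, 2))  # samples in [0, 1]^2
    b = a + np.array([0.5, 0.0])     # the same samples, shifted along the first axis
    print(sliced_wasserstein(a, a), sliced_wasserstein(a, b))
```

Shifting one sample set by a vector c shifts every projection onto direction θ by θ·c. For p = 2 the estimate therefore approaches |c| / sqrt(d), which is about 0.354 for the shift in the usage example.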
