Mirror Diffusion Models for Constrained and Watermarked Generation
About
Modern successes of diffusion models in learning complex, high-dimensional data distributions are attributed, in part, to their capability to construct diffusion processes with analytic transition kernels and score functions. This tractability yields a simulation-free framework with stable regression losses, from which reversed, generative processes can be learned at scale. However, when data is confined to a constrained set rather than a standard Euclidean space, prior attempts suggest that these desirable characteristics are lost. In this work, we propose Mirror Diffusion Models (MDM), a new class of diffusion models that generate data on convex constrained sets without losing any tractability. This is achieved by learning diffusion processes in a dual space constructed from a mirror map, which, crucially, is a standard Euclidean space. We derive efficiently computable mirror maps for popular constrained sets, such as simplices and $\ell_2$-balls, and show that MDM significantly outperforms existing methods. For safety and privacy purposes, we also explore constrained sets as a new mechanism to embed invisible but quantitative information (i.e., watermarks) in generated data, for which MDM serves as a compelling approach. Our work brings new algorithmic opportunities for learning tractable diffusion on complex domains. Our code is available at https://github.com/ghliu/mdm
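The key idea is that a mirror map sends the interior of the constrained set bijectively onto all of $\mathbb{R}^d$, so a standard diffusion can run in the dual space and every dual sample maps back to a feasible point. The sketch below illustrates this with two common choices, assumed here rather than taken from the MDM repository: a log-barrier map $\phi(x) = -\gamma \log(1 - \|x\|^2 / R^2)$ for the $\ell_2$-ball of radius $R$, and the entropic map on the first $d-1$ coordinates of the simplex (the additive log-ratio transform). The function names and the parameter `gamma` are hypothetical.

```python
import math

# Hypothetical sketch of mirror maps (grad phi) and their inverses (grad phi*).
# The specific maps below are assumptions, not the MDM reference implementation.


def l2ball_forward(x, radius=1.0, gamma=1.0):
    """Primal -> dual for phi(x) = -gamma*log(1 - ||x||^2/R^2):
    grad phi(x) = 2*gamma*x / (R^2 - ||x||^2). Requires ||x|| < R."""
    sq = sum(v * v for v in x)
    if sq >= radius ** 2:
        raise ValueError("x must lie strictly inside the l2-ball")
    scale = 2.0 * gamma / (radius ** 2 - sq)
    return [scale * v for v in x]


def l2ball_inverse(y, radius=1.0, gamma=1.0):
    """Dual -> primal, closed form: x = R^2 y / (sqrt(gamma^2 + R^2 ||y||^2) + gamma).
    The result always satisfies ||x|| < R, for any y in R^d."""
    sq = sum(v * v for v in y)
    scale = radius ** 2 / (math.sqrt(gamma ** 2 + radius ** 2 * sq) + gamma)
    return [scale * v for v in y]


def simplex_forward(x):
    """Primal -> dual for the entropic map on reduced coordinates.
    x holds (x_1, ..., x_{d-1}); x_d = 1 - sum(x) must be positive.
    Returns y_i = log(x_i / x_d), i.e. the additive log-ratio transform."""
    last = 1.0 - sum(x)
    if last <= 0.0 or min(x) <= 0.0:
        raise ValueError("x must lie in the open simplex")
    return [math.log(v / last) for v in x]


def simplex_inverse(y):
    """Dual -> primal: x_i = exp(y_i) / (1 + sum_j exp(y_j)).
    Subtracting the max (including the implicit 0 for x_d) avoids overflow."""
    m = max([0.0] + list(y))
    exps = [math.exp(v - m) for v in y]
    denom = math.exp(-m) + sum(exps)
    return [e / denom for e in exps]


if __name__ == "__main__":
    # A large dual point still maps strictly inside the unit ball.
    x = l2ball_inverse([1e6, 0.0])
    print(math.hypot(*x) < 1.0)
    # Round trip on the simplex: (0.2, 0.5) implies x_3 = 0.3.
    print(simplex_inverse(simplex_forward([0.2, 0.5])))
```

In this sketch, a diffusion model would be trained on `l2ball_forward(x)` or `simplex_forward(x)` of the data, and its samples would be pushed back through the corresponding inverse, which guarantees feasibility by construction rather than by projection or rejection.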
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Constrained Generation | Simplices (test) | W1 Distance | 1 | 15 |
| Unconditional Watermarked Generation | FFHQ 64x64 (train) | FID (50k samples) | 2.54 | 7 |
| Unconditional Watermarked Generation | AFHQ 64x64 v2 (train) | FID (50k samples) | 2.1 | 7 |
| Constrained Generative Modeling | Hypercube [0, 1]^d, d=2 (test) | Sliced Wasserstein Distance (SWD) | 0.03 | 4 |
| Constrained Generative Modeling | Hypercube [0, 1]^d, d=3 (test) | Sliced Wasserstein Distance (SWD) | 0.0192 | 4 |
| Constrained Generative Modeling | Hypercube [0, 1]^d, d=6 (test) | Sliced Wasserstein Distance (SWD) | 1.75 | 4 |
| Constrained Generative Modeling | Hypercube [0, 1]^d, d=8 (test) | Sliced Wasserstein Distance (SWD) | 0.0185 | 4 |
| Constrained Generative Modeling | Hypercube [0, 1]^d, d=20 (test) | Sliced Wasserstein Distance (SWD) | 0.0335 | 4 |
| Constrained Generation | l2-ball constrained set (d=2, Gaussian Mixture) 1.0 (synthetic) | Constraint Violation Rate | 0.00e+0 | 2 |
| Constrained Generation | l2-ball constrained set (d=2, Spiral) 1.0 (synthetic) | Constraint Violation Rate | 0.00e+0 | 2 |
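Two of the metrics above can be sketched in a few lines. The code below is an illustrative assumption, not the benchmarks' evaluation code: it estimates a sliced 1-Wasserstein distance by projecting two equal-size sample sets onto random unit directions and comparing sorted projections, and it computes a constraint violation rate as the fraction of samples outside a closed $\ell_2$-ball. The leaderboards' exact choices (order $p$, number of projections, boundary handling) are unknown, and the function names and defaults are hypothetical.

```python
import math
import random


def sliced_wasserstein(xs, ys, n_proj=200, seed=0):
    """Monte Carlo sliced 1-Wasserstein distance between two equal-size sample sets.
    For each random unit direction, project both sets, sort the projections, and
    average the absolute gaps between matched order statistics (the 1-D W1 distance)."""
    if not xs or len(xs) != len(ys):
        raise ValueError("need two non-empty sample sets of equal size")
    dim = len(xs[0])
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_proj):
        theta = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(t * t for t in theta))
        theta = [t / norm for t in theta]  # uniform direction on the unit sphere
        px = sorted(sum(a * b for a, b in zip(p, theta)) for p in xs)
        py = sorted(sum(a * b for a, b in zip(p, theta)) for p in ys)
        total += sum(abs(a - b) for a, b in zip(px, py)) / len(px)
    return total / n_proj


def violation_rate(samples, radius=1.0):
    """Fraction of samples lying outside the closed l2-ball of the given radius."""
    outside = sum(1 for p in samples if math.sqrt(sum(v * v for v in p)) > radius)
    return outside / len(samples)


if __name__ == "__main__":
    pts = [[0.1, 0.2], [0.5, -0.3], [-0.4, 0.4]]
    shifted = [[a + 1.0, b] for a, b in pts]
    # A unit shift along one axis gives roughly E|cos(angle)| = 2/pi in 2-D.
    print(sliced_wasserstein(pts, shifted))
    print(violation_rate([[0.0, 0.0], [2.0, 0.0], [0.5, 0.5]]))
```

Under this sketch, a sampler that maps every dual point back through a mirror map inverse, as MDM does, would report a violation rate of exactly zero, which is consistent with the 0.00e+0 entries in the table.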