Adjoint Sampling: Highly Scalable Diffusion Samplers via Adjoint Matching
About
We introduce Adjoint Sampling, a highly scalable and efficient algorithm for learning diffusion processes that sample from unnormalized densities, or energy functions. It is the first on-policy approach that allows significantly more gradient updates than the number of energy evaluations and model samples, which lets us scale to much larger problem settings than previously explored by similar methods. Our framework is theoretically grounded in stochastic optimal control and shares the same theoretical guarantees as Adjoint Matching, so it can be trained without corrective measures that push samples towards the target distribution. We show how to incorporate key symmetries, as well as periodic boundary conditions, for modeling molecules in both Cartesian and torsional coordinates. We demonstrate the effectiveness of our approach through extensive experiments on classical energy functions, and further scale up to neural network-based energy models, where we perform amortized conformer generation across many molecular systems. To encourage further research in developing highly scalable sampling methods, we plan to open-source these challenging benchmarks, where successful methods can directly impact progress in computational chemistry.
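To make the problem setting concrete, the following is a minimal sketch of what "sampling from an unnormalized density with a diffusion process" means. It is not the Adjoint Sampling algorithm: the drift here is hand-coded Langevin dynamics rather than a learned control, and the double-well energy, the function names (`energy`, `grad_energy`, `simulate`, `langevin_drift`), and all step sizes are assumptions chosen for illustration only.

```python
import numpy as np


def energy(x):
    """Hypothetical double-well energy E(x) = sum_i (x_i^2 - 1)^2.

    The target density is proportional to exp(-E(x)); its normalizing
    constant is never needed.
    """
    return np.sum((x**2 - 1.0) ** 2, axis=-1)


def grad_energy(x):
    """Analytic gradient of the energy above, per coordinate."""
    return 4.0 * x * (x**2 - 1.0)


def simulate(drift, x0, n_steps=200, dt=0.01, sigma=1.0, seed=0):
    """Euler-Maruyama simulation of dX = drift(X, t) dt + sigma dW."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for k in range(n_steps):
        t = k * dt
        noise = rng.standard_normal(x.shape)
        x = x + drift(x, t) * dt + sigma * np.sqrt(dt) * noise
    return x


def langevin_drift(x, t):
    """Stand-in drift: with sigma = 1, -0.5 * grad E leaves exp(-E) stationary.

    A learned sampler would replace this function with a neural network.
    """
    return -0.5 * grad_energy(x)


# 1000 particles in 2 dimensions, all started at the origin.
x0 = np.zeros((1000, 2))
samples = simulate(langevin_drift, x0)
print("mean energy at start:", energy(x0).mean())
print("mean energy at end:  ", energy(samples).mean())
```

Each simulation step in this sketch calls the energy gradient once per particle. The abstract's point about allowing more gradient updates than energy evaluations refers to keeping that kind of cost from dominating training.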
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| n-body particle system sampling | DW-4 d = 8 | W2 Distance | 0.62 | 29 |
| n-body particle system sampling | LJ-13 (d = 39) | W2 Distance | 1.67 | 21 |
| Toy target distribution sampling | GMM40 d = 50 | W2 (Entropy Regulated, eps=0.05) | 1.90e+4 | 18 |
| n-body particle system sampling | LJ-55 d = 165 | W2 | 4.5 | 16 |
| Target Distribution Sampling | Many-Well 5D | Sinkhorn Distance | 0.32 | 11 |
| Learning Continuous Target Distributions | MoS d = 50 | Sinkhorn Cost | 2.18e+3 | 11 |
| Boltzmann Distribution Sampling | LJ-13 | E(·) W2 | 2.4 | 6 |
| Boltzmann Distribution Sampling | LJ-55 | Expected Value W2 | 30.83 | 5 |
| Molecular Boltzmann Distribution Sampling | Alanine Dipeptide | KL Divergence (phi) | 0.09 | 5 |