
Proximal Diffusion Neural Sampler

About

Learning a diffusion-based neural sampler that draws samples from an unnormalized target distribution can be viewed as a stochastic optimal control problem on path measures. Training such samplers is challenging, however, when the target distribution is multimodal with high barriers separating the modes, which can lead to mode collapse. We propose the Proximal Diffusion Neural Sampler (PDNS), a framework that addresses this challenge by solving the stochastic optimal control problem with the proximal point method on the space of path measures. PDNS decomposes the learning process into a sequence of simpler subproblems whose solutions form a progressively refined path toward the target distribution, and this staged procedure promotes thorough exploration across modes. For a practical and efficient realization, we instantiate each proximal step with a proximal weighted denoising cross-entropy (WDCE) objective. Extensive experiments on both continuous and discrete sampling tasks, including challenging problems in molecular dynamics and statistical physics, demonstrate the effectiveness and robustness of PDNS. Our code is available at https://github.com/AlexandreGUO2001/PDNS.
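The idea of "a sequence of simpler subproblems" can be illustrated with a toy analogue of the proximal point method. The sketch below is not PDNS: it works on a finite state space instead of path measures, uses no neural network and no WDCE loss, and assumes KL divergence both as the objective and as the proximity term. Under those assumptions each step solves p_{k+1} = argmin_p KL(p || target) + (1/eta) KL(p || p_k), whose minimizer is the normalized geometric mixture target^(eta/(1+eta)) * p_k^(1/(1+eta)). The helper names (`proximal_kl_step`, `proximal_path`) and the example distributions are hypothetical.

```python
import numpy as np


def kl(p, q):
    """KL(p || q) for strictly positive probability vectors."""
    return float(np.sum(p * np.log(p / q)))


def proximal_kl_step(p_k, target, eta):
    """One proximal point step on the probability simplex (toy assumption):
        p_{k+1} = argmin_p KL(p || target) + (1/eta) * KL(p || p_k).
    The minimizer is p_{k+1} proportional to
        target**(eta/(1+eta)) * p_k**(1/(1+eta)),
    evaluated here in log space for numerical stability."""
    w = eta / (1.0 + eta)
    log_p = w * np.log(target) + (1.0 - w) * np.log(p_k)
    log_p -= log_p.max()  # avoid overflow before exponentiating
    p = np.exp(log_p)
    return p / p.sum()


def proximal_path(p0, target, eta, n_steps):
    """Return [p_0, p_1, ..., p_n]: the sequence of proximal iterates."""
    path = [p0]
    for _ in range(n_steps):
        path.append(proximal_kl_step(path[-1], target, eta))
    return path


if __name__ == "__main__":
    # Hypothetical bimodal target on 5 states; the start puts most mass on one mode.
    target = np.array([0.45, 0.04, 0.02, 0.04, 0.45])
    p0 = np.array([0.90, 0.05, 0.02, 0.02, 0.01])
    path = proximal_path(p0, target, eta=0.5, n_steps=20)
    for k in (0, 1, 2, 5, 10, 20):
        print(f"step {k:2d}: KL = {kl(path[k], target):.4f}, p = {np.round(path[k], 3)}")
```

A smaller `eta` keeps each subproblem's solution closer to the current iterate, so the second mode is filled in gradually instead of in one jump; the KL to the target decreases at every step.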

Wei Guo, Jaemoo Choi, Yuchen Zhu, Molei Tao, Yongxin Chen • 2025

Related benchmarks

Task | Dataset | Metric | Result | Rank
--- | --- | --- | --- | ---
n-body particle system sampling | DW-4 (d = 8) | W2 Distance | 0.51 | 29
Target Distribution Sampling | Funnel 10D | Sinkhorn Distance | 129.5 | 29
n-body particle system sampling | LJ-13 (d = 39) | W2 Distance | 1.57 | 21
Toy target distribution sampling | GMM40 (d = 50) | W2 (entropy-regularized, eps = 0.05) | 327.8 | 18
n-body particle system sampling | LJ-55 (d = 165) | W2 | 3.95 | 16
Learning Continuous Target Distributions | MoS (d = 50) | Sinkhorn Cost | 353.1 | 11
Target Distribution Sampling | Many-Well 5D | Sinkhorn Distance | 0.08 | 11
Sampling from lattice Ising models | Lattice Ising model, beta = 0.28, L = 24 (high-temperature) | Delta Mag. | 0.0039 | 6
Molecular Boltzmann Distribution Sampling | Alanine Dipeptide | KL Divergence (phi) | 0.02 | 5
Discrete sampling | Lattice Potts model, L = 16, N = 4, beta = 1.3 | Delta Mag. | 8.40e-4 | 4
Showing 10 of 26 rows
