
Improving Sampling for Masked Diffusion Models via Information Gain

About

Masked Diffusion Models (MDMs) offer greater flexibility in decoding order than autoregressive models but require careful planning to achieve high-quality generation. Existing samplers typically rely on greedy heuristics, decoding at each step the positions with the highest local certainty. Through failure case analysis, we identify a fundamental limitation of this approach: it neglects the downstream impact of current decoding choices on subsequent steps and fails to minimize cumulative uncertainty. In particular, these methods do not fully exploit the non-causal nature of MDMs, which makes it possible to evaluate how a decoding decision reshapes token probabilities and uncertainty across all remaining masked positions. To bridge this gap, we propose the Info-Gain Sampler, a principled decoding framework that balances immediate uncertainty with information gain over future masked tokens. Extensive evaluations across diverse architectures and tasks (reasoning, coding, creative writing, and image generation) demonstrate that the Info-Gain Sampler consistently outperforms existing samplers for MDMs. For instance, it achieves a 3.6% improvement in average accuracy on reasoning tasks and a 63.1% win rate in creative writing. Notably, on reasoning tasks it reduces cumulative uncertainty from 78.4 to 48.6, outperforming the best baseline by a large margin. The code will be available at https://github.com/yks23/Information-Gain-Sampler.
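To make the contrast above concrete, here is a minimal sketch of a single decoding step that compares a greedy confidence-based choice with an information-gain-style choice. Everything in it is an assumption for illustration: the toy `predict` function stands in for an MDM forward pass; `PRIORS`, `COUPLING`, and `lam` are hypothetical; and the score `-H(x_i) + lam * IG_i` (where `IG_i` is the expected entropy reduction, in nats, over the other masked positions once position `i` is decoded) is just one simple way to trade off immediate uncertainty against information gain. It is not the Info-Gain Sampler's actual objective or code, which the abstract says will be released at the linked repository.

```python
import math
from typing import Dict, List, Sequence

# Hypothetical toy "model": per-position priors over a 2-token vocabulary,
# plus a coupling table. COUPLING[j][k] > 1 means masked position j is pulled
# toward whatever token was decoded at position k. This stands in for a
# non-causal MDM forward pass; it is NOT a real model.
PRIORS: List[List[float]] = [
    [0.9, 0.1],  # position 0: confident, independent of the others
    [0.7, 0.3],  # position 1: less confident, but it determines position 2
    [0.5, 0.5],  # position 2: uncertain on its own
]
COUPLING: List[List[float]] = [
    [1.0, 1.0, 1.0],
    [1.0, 1.0, 1.0],
    [1.0, 20.0, 1.0],  # position 2 strongly copies position 1
]


def predict(decoded: Dict[int, int]) -> Dict[int, List[float]]:
    """Return a distribution for every still-masked position given decoded tokens."""
    out: Dict[int, List[float]] = {}
    for j, prior in enumerate(PRIORS):
        if j in decoded:
            continue
        weights = list(prior)
        for k, tok in decoded.items():
            weights[tok] *= COUPLING[j][k]
        total = sum(weights)
        out[j] = [w / total for w in weights]
    return out


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy in nats."""
    return -sum(q * math.log(q) for q in p if q > 0)


def greedy_choice(decoded: Dict[int, int]) -> int:
    """Baseline: pick the masked position with the highest max-probability."""
    dists = predict(decoded)
    return max(dists, key=lambda j: max(dists[j]))


def info_gain_scores(decoded: Dict[int, int], lam: float = 1.0) -> Dict[int, float]:
    """Score = -H(position) + lam * expected entropy reduction on the other masked positions."""
    dists = predict(decoded)
    scores: Dict[int, float] = {}
    for i, p_i in dists.items():
        rest_before = sum(entropy(p) for j, p in dists.items() if j != i)
        rest_after = 0.0
        for tok, prob in enumerate(p_i):
            # Hypothetically commit token `tok` at position i and re-query the model.
            after = predict({**decoded, i: tok})
            rest_after += prob * sum(entropy(p) for p in after.values())
        scores[i] = -entropy(p_i) + lam * (rest_before - rest_after)
    return scores


def info_gain_choice(decoded: Dict[int, int], lam: float = 1.0) -> int:
    scores = info_gain_scores(decoded, lam)
    return max(scores, key=scores.get)


if __name__ == "__main__":
    print("greedy picks:", greedy_choice({}))        # position 0
    print("info-gain picks:", info_gain_choice({}))  # position 1
```

In this toy setup, greedy selection decodes position 0 first because it is the most confident, while the information-gain score prefers position 1, whose value largely determines position 2. Note the cost of this naive version: it re-queries the model once per candidate position and per candidate token.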

Kaisen Yang, Jayden Teoh, Kaicheng Yang, Yitong Zhang, Alex Lamb • 2026

Related benchmarks

Task | Dataset | Metric | Value | Rank
Mathematical Reasoning | GSM8K | Accuracy | 83.3 | 983
Mathematical Reasoning | GSM8K (test) | Accuracy | 88.9 | 797
Code Generation | MBPP | Accuracy | 48.4 | 120
Text-to-Image Generation | GenEval | Two Objects | 68.7 | 87
Code Generation | HumanEval | Accuracy (%) | 63.8 | 77
Mathematical Reasoning | MATH 500 | Accuracy | 59.6 | 73
Planning | Sudoku | Accuracy | 84.4 | 68
Planning | Countdown | Accuracy | 45.2 | 68
Mathematical Reasoning | MATH500 | Accuracy | 51.3 | 57
Reasoning | Reasoning Tasks Suite (GSM8K, MATH500, HumanEval, MBPP) | Average Accuracy | 60.3 | 20
Showing 10 of 13 rows
