
Fine-Tuning Masked Diffusion for Provable Self-Correction

About

A natural desideratum for generative models is self-correction--detecting and revising low-quality tokens at inference. While Masked Diffusion Models (MDMs) have emerged as a promising approach for generative modeling in discrete spaces, their capacity for self-correction remains poorly understood. Prior attempts to incorporate self-correction into MDMs either require overhauling MDM architectures/training or rely on imprecise proxies for token quality, limiting their applicability. Motivated by this, we introduce PRISM--Plug-in Remasking for Inference-time Self-correction of Masked Diffusions--a lightweight, model-agnostic approach that applies to any pretrained MDM. Theoretically, PRISM defines a self-correction loss that provably learns per-token quality scores, without RL or a verifier. These quality scores are computed in the same forward pass as the MDM's prediction and are used to detect low-quality tokens. Empirically, PRISM advances MDM inference across domains and scales: Sudoku; unconditional text (170M); and code with LLaDA (8B).
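The core inference-time step described above, once per-token quality scores are available, is to send the lowest-quality tokens back to the mask state so the MDM can resample them. The sketch below illustrates that remasking step only; the function name, the `MASK` sentinel, and the score-thresholding-by-count policy are illustrative assumptions, not the paper's exact implementation (PRISM's quality scores come from its learned self-correction loss).

```python
MASK = -1  # hypothetical mask-token id; real MDMs use a vocabulary-specific id


def prism_remask(tokens, quality_scores, k):
    """Remask the k lowest-quality positions for resampling.

    tokens: token ids from the current denoising step.
    quality_scores: per-token scores (higher = better quality),
        assumed to be produced alongside the MDM's forward pass.
    k: number of tokens to return to the mask state.
    """
    # Indices of the k lowest-scoring positions.
    worst = set(sorted(range(len(tokens)), key=lambda i: quality_scores[i])[:k])
    # Replace those positions with the mask token; keep the rest.
    return [MASK if i in worst else t for i, t in enumerate(tokens)]
```

A caller would alternate this with ordinary MDM denoising steps: denoise, score, remask the worst `k` tokens, then denoise again.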

Jaeyeon Kim, Seunggeun Kim, Taekyun Lee, David Z. Pan, Hyeji Kim, Sham Kakade, Sitan Chen • 2025

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Unconditional Text Generation | OpenWebText | Gen. PPL | 15.3 | 56 |
| Coding | HumanEval | Pass@1 | 42.7 | 52 |
| Unconditional Generation | OpenWebText (OWT) L=1024 (held-out) | MAUVE | 0.527 | 45 |
| Code | MBPP | Pass@1 | 32.3 | 43 |
