
Learning to Parallel: Accelerating Diffusion Large Language Models via Learnable Parallel Decoding

About

Autoregressive decoding in large language models (LLMs) requires $\mathcal{O}(n)$ sequential steps for $n$ tokens, fundamentally limiting inference throughput. Recent diffusion-based LLMs (dLLMs) enable parallel token generation through iterative denoising. However, current parallel decoding strategies rely on fixed, input-agnostic heuristics (e.g., confidence thresholds), which fail to adapt to input-specific characteristics, resulting in suboptimal speed-quality trade-offs across diverse NLP tasks. In this work, we explore a more flexible and dynamic approach to parallel decoding. We propose Learning to Parallel Decode (Learn2PD), a framework that trains a lightweight and adaptive filter model to predict, for each token position, whether the current prediction matches the final output. This learned filter approximates an oracle parallel decoding strategy that unmasks tokens only when correctly predicted. Importantly, the filter model is learned in a post-training manner, requiring only a small amount of computation to optimize it (minute-level GPU time). Additionally, we introduce End-of-Text Prediction (EoTP) to detect decoding completion at the end of sequence, avoiding redundant decoding of padding tokens. Experiments on the LLaDA benchmark demonstrate that our method achieves up to 22.58$\times$ speedup without any performance drop, and up to 57.51$\times$ when combined with KV-Cache.
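The decoding loop described in the abstract can be illustrated with a toy sketch. This is not the authors' implementation: `parallel_decode`, `toy_predict`, `threshold_filter`, the `<eot>` token, and the confidence values are all hypothetical names and numbers chosen for illustration. The `accept` callback marks the slot where a learned filter would go; a fixed threshold stands in for it here only so the example runs. The end-of-text shortcut is likewise a simplified stand-in for EoTP: once an end-of-text token is unmasked, later positions are treated as padding and skipped.

```python
from typing import Callable, List, Optional, Tuple

MASK = None      # marker for a still-masked position (assumption of this sketch)
EOT = "<eot>"    # hypothetical end-of-text token

Prediction = Tuple[str, float]  # (predicted token, confidence)


def parallel_decode(
    predict: Callable[[List[Optional[str]]], List[Prediction]],
    accept: Callable[[str, float, int], bool],
    length: int,
    max_steps: int = 64,
) -> Tuple[List[str], int]:
    """At each step, unmask every masked position that the filter accepts."""
    seq: List[Optional[str]] = [MASK] * length
    end = length  # positions >= end are treated as padding and never decoded
    steps = 0
    while steps < max_steps and any(tok is MASK for tok in seq[:end]):
        steps += 1
        preds = predict(seq)
        masked = [i for i in range(end) if seq[i] is MASK]
        chosen = [i for i in masked if accept(preds[i][0], preds[i][1], i)]
        if not chosen:
            # Guarantee progress: fall back to the most confident masked position.
            chosen = [max(masked, key=lambda i: preds[i][1])]
        for i in chosen:
            seq[i] = preds[i][0]
        # End-of-text shortcut: stop decoding everything after the first EOT.
        for i in range(end):
            if seq[i] == EOT:
                end = i + 1
                break
    return [tok for tok in seq[:end] if tok is not MASK], steps


# Toy stand-ins for the denoising model and the filter (illustrative only).
TARGET = ["a", "b", "c", EOT, EOT, EOT, EOT, EOT]
BASE_CONF = [0.9, 0.4, 0.8, 0.95, 0.9, 0.9, 0.9, 0.9]


def toy_predict(seq: List[Optional[str]]) -> List[Prediction]:
    """Always predicts the target; confidence rises once the left neighbour is known."""
    preds = []
    for i, token in enumerate(TARGET):
        bonus = 0.3 if i > 0 and seq[i - 1] is not MASK else 0.0
        preds.append((token, min(1.0, BASE_CONF[i] + bonus)))
    return preds


def threshold_filter(token: str, confidence: float, position: int) -> bool:
    """Placeholder for a learned filter: accept when confidence is at least 0.5."""
    return confidence >= 0.5


tokens, steps = parallel_decode(toy_predict, threshold_filter, len(TARGET))
print(tokens, steps)
```

In this toy run, the first step unmasks every position except the low-confidence one and finds the end-of-text token at position 3, so the four trailing padding positions are never revisited; the second step fills the remaining position. A trained filter would replace `threshold_filter` without changing the loop.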

Wenrui Bao, Zhiben Chen, Dan Xu, Yuzhang Shang • 2025

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
| --- | --- | --- | --- | --- |
| Mathematical Reasoning | GSM8K | Accuracy (Acc) | 78.3 | 337 |
| Mathematical Reasoning | MATH | Accuracy (%) | 78.3 | 52 |
| Mathematical Reasoning | GSM8K 5-shot (test) | Strict Match Accuracy | 79.1 | 47 |
| Generation | Math Domain | Average Generation Time (s) | 2.86 | 40 |
| Generation | QA domain | Average Generation Time (s) | 2.47 | 40 |
| Generation | Coding domain | Average Wall-Clock Time (s) | 4.14 | 40 |
| Code Generation | Coding | Pass@1 | 48.8 | 40 |
| Question Answering | QA | Average Rank | 3.02 | 40 |
| Code Generation | HumanEval 0-shot (test) | Accuracy | 40.3 | 23 |
| Mathematical Reasoning | MATH 4-shot (test) | Accuracy | 32.3 | 22 |
Showing 10 of 15 rows
