
Scaling up Masked Diffusion Models on Text

About

Masked diffusion models (MDMs) have shown promise in language modeling, yet their scalability and effectiveness in core language tasks, such as text generation and language understanding, remain underexplored. This paper establishes the first scaling law for MDMs, demonstrating a scaling rate comparable to autoregressive models (ARMs) and a relatively small compute gap. Motivated by their scalability, we train a family of MDMs with up to 1.1 billion (B) parameters to systematically evaluate their performance against ARMs of comparable or larger sizes. Fully leveraging the probabilistic formulation of MDMs, we propose a simple yet effective unsupervised classifier-free guidance that effectively exploits large-scale unpaired data, boosting performance for conditional inference. In language understanding, the 1.1B MDM outperforms the 1.1B TinyLlama model trained on the same data across four of eight zero-shot benchmarks. Notably, it achieves competitive math reasoning ability with the 7B Llama-2 model on the GSM8K dataset. In text generation, MDMs with 16 times more pre-training time offer a flexible trade-off against ARMs with the accelerated sampling technique KV-Cache: MDMs match ARMs in performance while being 1.4 times faster during sampling. Moreover, MDMs address challenging tasks for ARMs by effectively handling bidirectional reasoning and adapting to temporal shifts in data. Notably, a 1.1B MDM breaks the reverse curse encountered by much larger ARMs with significantly more data and computation, such as 13B Llama-2 and 175B GPT-3. Our code is available at https://github.com/ML-GSAI/SMDM.

Shen Nie, Fengqi Zhu, Chao Du, Tianyu Pang, Qian Liu, Guangtao Zeng, Min Lin, Chongxuan Li • 2024
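
The abstract mentions classifier-free guidance for conditional inference without giving the formula. As a rough illustration only, the sketch below applies the commonly used guidance form, p_guided ∝ p_cond^(1+w) / p_uncond^w, to a single token position. The function names, the guidance scale `w`, and the toy logits are assumptions made for this example; the paper's exact unsupervised variant, including how the unconditional prediction is obtained, may differ and should be checked against the released code.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def guided_distribution(cond_logits, uncond_logits, w):
    """Hypothetical helper: combine conditional and unconditional predictions.

    Computes p_guided proportional to p_cond^(1 + w) / p_uncond^w in log space.
    With w = 0 this reduces to the plain conditional distribution.
    """
    log_cond = log_softmax(cond_logits)
    log_uncond = log_softmax(uncond_logits)
    mixed = [(1.0 + w) * c - w * u for c, u in zip(log_cond, log_uncond)]
    # Renormalize so the result is a valid probability distribution.
    return [math.exp(x) for x in log_softmax(mixed)]


# Toy logits over a 3-token vocabulary (made-up numbers, not from the paper).
cond = [2.0, 1.0, 0.0]    # prediction given the prompt
uncond = [1.0, 1.0, 1.0]  # prediction without the prompt

plain = guided_distribution(cond, uncond, w=0.0)
guided = guided_distribution(cond, uncond, w=1.0)
```

In this toy case the unconditional prediction is uniform, so increasing `w` simply sharpens the conditional distribution toward its most likely token.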

Related benchmarks

Task                    Dataset                                    Metric               Result  Rank
Commonsense Reasoning   PIQA                                       Accuracy             60.3    647
Question Answering      OBQA                                       Accuracy             27      276
Question Answering      ARC-E                                      Accuracy             37.4    242
Question Answering      BoolQ                                      Accuracy             61.5    240
Commonsense Reasoning   SIQA                                       Accuracy             37.9    96
Reading Comprehension   RACE                                       Accuracy             29.3    34
Scaling Law Estimation  Scaling Law Analysis Scaling coefficients  Alpha M Coefficient  0.644   11
Mathematical Reasoning  GSM8K (test)                               Accuracy             58.5    6
