
MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models

About

Large Language Models (LLMs) are distinguished by their massive parameter counts, which typically result in significant redundancy. This work introduces MaskLLM, a learnable pruning method that establishes Semi-structured (or "N:M") Sparsity in LLMs, aimed at reducing computational overhead during inference. Instead of developing a new importance criterion, MaskLLM explicitly models N:M patterns as a learnable distribution through Gumbel Softmax sampling. This approach facilitates end-to-end training on large-scale datasets and offers two notable advantages: 1) High-quality Masks - our method effectively scales to large datasets and learns accurate masks; 2) Transferability - the probabilistic modeling of mask distribution enables the transfer learning of sparsity across domains or tasks. We assessed MaskLLM using 2:4 sparsity on various LLMs, including LLaMA-2, Nemotron-4, and GPT-3, with sizes ranging from 843M to 15B parameters, and our empirical results show substantial improvements over state-of-the-art methods. For instance, leading approaches achieve a perplexity (PPL) of 10 or greater on Wikitext compared to the dense model's 5.12 PPL, but MaskLLM achieves a significantly lower 6.72 PPL solely by learning the masks with frozen weights. Furthermore, MaskLLM's learnable nature allows customized masks for lossless application of 2:4 sparsity to downstream tasks or domains. Code is available at https://github.com/NVlabs/MaskLLM.

Gongfan Fang, Hongxu Yin, Saurav Muralidharan, Greg Heinrich, Jeff Pool, Jan Kautz, Pavlo Molchanov, Xinchao Wang • 2024
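
The abstract describes modeling N:M patterns as a learnable distribution sampled with Gumbel Softmax. The sketch below illustrates that idea for 2:4 sparsity using NumPy. It is not the authors' implementation: the function names (`candidate_masks`, `gumbel_softmax`, `sample_soft_mask`, `hard_mask`), the temperature value, and the choice to keep one logit vector per group of four weights are all assumptions made for illustration. A real training setup would use an autodiff framework so that gradients reach the logits.

```python
import itertools
import numpy as np

def candidate_masks(n=2, m=4):
    """Enumerate every binary mask of length m with exactly n ones."""
    masks = []
    for keep in itertools.combinations(range(m), n):
        mask = np.zeros(m)
        mask[list(keep)] = 1.0
        masks.append(mask)
    return np.stack(masks)  # shape (C(m, n), m), i.e. (6, 4) for 2:4

def gumbel_softmax(logits, tau, rng):
    """Relaxed categorical sample: softmax((logits + Gumbel noise) / tau)."""
    u = rng.uniform(low=1e-9, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))
    z = (logits + gumbel) / tau
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sample_soft_mask(logits, tau=1.0, seed=0):
    """Soft mask per group: probability-weighted mix of the candidate masks."""
    rng = np.random.default_rng(seed)
    probs = gumbel_softmax(logits, tau, rng)  # (groups, 6)
    return probs @ candidate_masks()          # (groups, 4)

def hard_mask(logits):
    """Deterministic mask for inference: the most likely candidate per group."""
    return candidate_masks()[np.argmax(logits, axis=-1)]

# Hypothetical example: two groups of four frozen weights, untrained logits.
weights = np.arange(1.0, 9.0).reshape(2, 4)
logits = np.zeros((2, 6))  # assumption: one logit per candidate mask per group

soft = sample_soft_mask(logits, tau=0.5)
hard = hard_mask(logits)
sparse_weights = weights * hard  # exactly two nonzeros survive in each group
```

Because each candidate mask keeps exactly two entries and the sampled probabilities sum to one, every row of the soft mask sums to 2, which keeps the relaxation consistent with the 2:4 constraint while the weights themselves stay frozen.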

Related benchmarks

Task | Dataset | Result | Rank
Language Modeling | WikiText2 | Perplexity 6.78 | 3785
Language Modeling | WikiText-2 (test) | PPL 5.85 | 2333
Language Modeling | WikiText-2 | Perplexity (PPL) 5.85 | 2320
Commonsense Reasoning | WinoGrande | Accuracy 69.14 | 1442
Language Modeling | C4 (val) | PPL 11.15 | 737
Question Answering | OpenBookQA | Accuracy 30.6 | 465
Physical Interaction Question Answering | PIQA | Accuracy 76.22 | 415
Science Question Answering | ARC Challenge | Accuracy 43.94 | 354
Question Answering | OpenBookQA | Accuracy 26.8 | 305
Science Question Answering | ARC-E | Accuracy 75.93 | 240
Showing 10 of 18 rows
