Bypass Back-propagation: Optimization-based Structural Pruning for Large Language Models via Policy Gradient

About

Recent Large Language Model (LLM) pruning methods typically operate in the post-training phase, avoiding expensive weight finetuning; however, their pruning criteria often rely on heuristically hand-crafted metrics, potentially leading to suboptimal performance. We instead propose a novel optimization-based structural pruning method that learns the pruning masks in a probabilistic space directly by optimizing the loss of the pruned model. To preserve efficiency, our method eliminates back-propagation through the LLM itself during optimization, requiring only forward passes of the LLM. We achieve this by learning an underlying Bernoulli distribution from which binary pruning masks are sampled; because the Bernoulli parameters are decoupled from the LLM loss, they can be optimized efficiently via a policy gradient estimator without back-propagation. As a result, our method can 1) support global and heterogeneous pruning (i.e., automatically determine different redundancy levels for different layers), and 2) optionally initialize its Bernoulli distributions with a metric-based method. Extensive experiments conducted on LLaMA, LLaMA-2, LLaMA-3, Vicuna, and Mistral models using the C4 and WikiText2 datasets demonstrate the promising efficiency and effectiveness of our method. Code is available at https://github.com/ethanygao/backprop-free_LLM_pruning.

Yuan Gao, Zujing Liu, Weizhong Zhang, Bo Du, Gui-Song Xia• 2024
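The abstract's core idea, learning Bernoulli keep-probabilities for pruning masks with a policy gradient (REINFORCE-style) estimator so that only forward passes of the frozen model are needed, can be sketched on a toy objective. Everything below is an illustrative assumption, not the paper's implementation: `forward_loss` is a stand-in for evaluating the pruned LLM, and the learning rate, baseline, and sparsity penalty are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for evaluating a frozen model with a forward pass only:
# units 0-4 are "useful" (pruning them hurts the loss), units 5-9 are redundant.
useful = np.array([1.0] * 5 + [0.0] * 5)

def forward_loss(mask):
    # Loss rises by 1.0 per dropped useful unit, plus a small cost per kept unit
    # (a sparsity incentive). No gradients of this function are ever taken.
    return np.sum(useful * (1.0 - mask)) + 0.05 * mask.sum()

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

theta = np.full(10, 2.0)   # logits of the Bernoulli keep-probabilities
lr, baseline = 0.5, 0.0

for step in range(500):
    p = sigmoid(theta)
    mask = (rng.random(10) < p).astype(float)   # sample a binary pruning mask
    loss = forward_loss(mask)                    # forward pass only
    baseline = 0.9 * baseline + 0.1 * loss       # moving-average variance reduction
    # REINFORCE estimator: d/dtheta log Bernoulli(mask | p) = mask - p,
    # so (loss - baseline) * (mask - p) estimates the gradient of E[loss].
    theta -= lr * (loss - baseline) * (mask - p)

p = sigmoid(theta)
# Keep-probabilities of the useful units stay high; redundant units are pruned.
```

The decoupling in the abstract corresponds to the fact that `theta` only enters through the sampling distribution, so the update never back-propagates through `forward_loss`.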

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Language Modeling | WikiText2 | Perplexity | 18.73 | 2839 |
| Language Modeling | WikiText-2 (test) | PPL | 21.93 | 1949 |
| Multi-task Language Understanding | MMLU | Accuracy | 56.38 | 876 |
| Language Modeling | WikiText | PPL | 25.34 | 732 |
| Language Modeling | WikiText2 (val) | Perplexity (PPL) | 38.99 | 387 |
| Language Modeling | WikiText2 v1 (test) | Perplexity | 19.7 | 383 |
| Zero-shot Reasoning | Reasoning Suite Zero-shot (PIQA, HellaSwag, WinoGrande, ARC-e, ARC-c) (val test) | PIQA | 76.49 | 177 |
| Word Prediction | LAMBADA | Accuracy | 62.63 | 148 |
| Zero-shot Common Sense Reasoning | Zero-shot Suite (PIQA, HellaSwag, WinoGrande, ARC-e, ARC-c) (test) | PIQA | 72.25 | 95 |
| Zero-shot Accuracy | ARC Easy | Zero-shot Acc (ARC Easy) | 66.03 | 63 |
Showing 10 of 15 rows
