Bypass Back-propagation: Optimization-based Structural Pruning for Large Language Models via Policy Gradient
About
Recent pruning methods for Large Language Models (LLMs) typically operate in the post-training phase without expensive weight finetuning; however, their pruning criteria often rely on heuristically hand-crafted metrics, potentially leading to suboptimal performance. We instead propose a novel optimization-based structural pruning method that learns the pruning masks in a probabilistic space directly by optimizing the loss of the pruned model. To preserve efficiency, our method eliminates back-propagation through the LLM itself during optimization, requiring only forward passes of the LLM. We achieve this by learning an underlying Bernoulli distribution from which binary pruning masks are sampled; because the Bernoulli parameters are decoupled from the LLM loss, they can be optimized efficiently via a policy gradient estimator without back-propagation. As a result, our method can 1) support global and heterogeneous pruning (i.e., automatically determine different redundancy for different layers), and 2) optionally initialize the Bernoulli distributions with a metric-based method. Extensive experiments on LLaMA, LLaMA-2, LLaMA-3, Vicuna, and Mistral models using the C4 and WikiText2 datasets demonstrate the efficiency and effectiveness of our method. Code is available at https://github.com/ethanygao/backprop-free_LLM_pruning.
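The core idea, sampling binary masks from a learned Bernoulli distribution and updating its parameters with a REINFORCE-style policy gradient that needs only forward-pass loss evaluations, can be sketched on a toy problem. Everything below (the linear model, `toy_loss`, the hyperparameters) is an illustrative stand-in, not the paper's actual implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the frozen LLM: a linear map whose columns we prune.
# Columns 4..7 are redundant by construction, so the optimal mask drops them.
W = rng.normal(size=(8, 8))
x = rng.normal(size=(8,))
true_mask = np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=float)
target = (W * true_mask) @ x  # output of the "ideal" pruned model

def toy_loss(mask):
    """Forward-only evaluation: no gradients through the model are needed."""
    return float(np.sum(((W * mask) @ x - target) ** 2))

logits = np.zeros(8)  # Bernoulli keep-probabilities, in logit space
lr, baseline, losses = 0.1, None, []

for step in range(2000):
    p = 1.0 / (1.0 + np.exp(-logits))         # keep probability per column
    mask = (rng.random(8) < p).astype(float)  # sample a binary pruning mask
    loss = toy_loss(mask)
    losses.append(loss)
    # A moving-average baseline reduces the variance of the REINFORCE estimator.
    baseline = loss if baseline is None else 0.9 * baseline + 0.1 * loss
    # Policy gradient: for a Bernoulli parameterized in logit space,
    # d log P(mask) / d logits = mask - p.
    logits -= lr * (loss - baseline) * (mask - p)

keep = (1.0 / (1.0 + np.exp(-logits)) > 0.5).astype(float)
```

Because the Bernoulli parameters sit outside the model, the update touches only the 8 logits; the "model" is queried purely through `toy_loss`, mirroring how the method avoids back-propagation through the LLM.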
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Language Modeling | WikiText2 | Perplexity | 18.73 | 1875 |
| Language Modeling | WikiText-2 (test) | PPL | 21.93 | 1541 |
| Multi-task Language Understanding | MMLU | Accuracy | 56.38 | 842 |
| Language Modeling | WikiText | PPL | 25.34 | 479 |
| Language Modeling | WikiText2 v1 (test) | Perplexity | 19.7 | 341 |
| Language Modeling | WikiText2 (val) | Perplexity (PPL) | 38.99 | 277 |
| Zero-shot Reasoning | Reasoning Suite Zero-shot (PIQA, HellaSwag, WinoGrande, ARC-e, ARC-c) (val test) | PIQA | 76.49 | 119 |
| Word Prediction | LAMBADA | Accuracy | 62.63 | 112 |
| Zero-shot Common Sense Reasoning | Zero-shot Suite (PIQA, HellaSwag, WinoGrande, ARC-e, ARC-c) (test) | PIQA | 72.25 | 95 |
| Zero-shot Accuracy | ARC Easy | Zero-shot Acc (ARC Easy) | 66.03 | 63 |