S$^{2}$FT: Efficient, Scalable and Generalizable LLM Fine-tuning by Structured Sparsity
About
Current PEFT methods for LLMs can achieve high quality, efficient training, or scalable serving, but not all three simultaneously. To address this limitation, we investigate sparse fine-tuning and observe a remarkable improvement in generalization ability. Building on this key insight, we propose a family of Structured Sparse Fine-Tuning (S$^{2}$FT) methods for LLMs that concurrently achieve state-of-the-art fine-tuning performance, training efficiency, and inference scalability. S$^{2}$FT accomplishes this by "selecting sparsely and computing densely". It selects a few heads in the MHA module and a few channels in the FFN module of each Transformer block. Next, it co-permutes the weight matrices on both sides of the coupled structures in LLMs so that the selected components in each layer form a dense submatrix. Finally, S$^{2}$FT performs in-place gradient updates on all submatrices. Through theoretical analysis and empirical results, we show that our method prevents forgetting while simplifying optimization, delivers SOTA performance on both commonsense and arithmetic reasoning with 4.6% and 1.3% average improvements over LoRA, and surpasses full FT by 11.5% when generalizing to various domains after instruction tuning. Using our partial backpropagation algorithm, S$^{2}$FT reduces training memory by up to 3$\times$ and improves latency by 1.5-2.7$\times$ compared to full FT, while delivering an average 10% improvement over LoRA on both metrics. We further demonstrate that the weight updates in S$^{2}$FT can be decoupled into adapters, enabling effective fusion, fast switching, and efficient parallelism for serving multiple fine-tuned models.
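The co-permutation step works because the hidden channels of an FFN appear as the rows of the up projection and as the columns of the down projection, so applying one permutation to both sides leaves the block's output unchanged. The sketch below illustrates this on a toy FFN, then takes one in-place gradient step on the resulting dense block. It is a hypothetical, pure-Python illustration under stated assumptions (a two-layer ReLU FFN, hand-picked channel indices, a squared-error loss, and only the selected columns of `W_down` being trained); every helper name is invented for this example and none of it is the authors' implementation.

```python
# Minimal sketch of "selecting sparsely and computing densely" on a toy FFN.
# Assumptions (illustrative, not the authors' code): y = W_down @ relu(W_up @ x),
# matrices as lists of rows, hand-picked channels, only W_down[:, :k] trained.

def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def relu(v):
    return [max(0.0, a) for a in v]


def ffn(w_up, w_down, x):
    """Toy FFN: hidden = relu(W_up x), output = W_down hidden."""
    return matvec(w_down, relu(matvec(w_up, x)))


def co_permute(w_up, w_down, selected):
    """Move the selected hidden channels to the front, returning copies.

    Rows of W_up and columns of W_down index the same hidden channels, so
    applying one permutation to both sides leaves the FFN output unchanged.
    """
    rest = [c for c in range(len(w_up)) if c not in selected]
    perm = list(selected) + rest
    new_up = [list(w_up[c]) for c in perm]
    new_down = [[row[c] for c in perm] for row in w_down]
    return new_up, new_down


def loss(w_up, w_down, x, target):
    """Squared error 0.5 * ||y - target||^2."""
    y = ffn(w_up, w_down, x)
    return 0.5 * sum((yi - ti) ** 2 for yi, ti in zip(y, target))


def grad_dense_block(w_up, w_down, x, target, k):
    """Gradient of the loss w.r.t. the dense block W_down[:, :k] only.

    Only the first k hidden activations enter this gradient; the frozen
    columns of W_down receive no gradient at all.
    """
    h = relu(matvec(w_up, x))
    y = matvec(w_down, h)
    err = [yi - ti for yi, ti in zip(y, target)]
    return [[e * h[j] for j in range(k)] for e in err]


def sgd_step_in_place(w_down, grad_block, lr):
    """In-place SGD update applied to the leading dense block of W_down."""
    for row, grow in zip(w_down, grad_block):
        for j, g in enumerate(grow):
            row[j] -= lr * g


# Usage: 4 hidden channels; select channels 1 and 3 for fine-tuning.
w_up = [[1.0, 0.0], [0.5, -1.0], [0.0, 2.0], [-1.0, 1.0]]
w_down = [[1.0, 2.0, 3.0, 4.0], [0.0, -1.0, 1.0, 0.5]]
x, target, selected = [1.0, 2.0], [16.0, 4.0], [1, 3]
k = len(selected)

before = ffn(w_up, w_down, x)
p_up, p_down = co_permute(w_up, w_down, selected)
after = ffn(p_up, p_down, x)  # identical to `before`

frozen = [row[k:] for row in p_down]  # copies of the untrained columns
loss_before = loss(p_up, p_down, x, target)
sgd_step_in_place(p_down, grad_dense_block(p_up, p_down, x, target, k), lr=0.1)
loss_after = loss(p_up, p_down, x, target)
print(before, after, loss_before, loss_after)
```

Because the trained weights sit in one contiguous `k`-column block after permutation, the gradient only needs the first `k` hidden activations and the remaining columns stay untouched, which loosely mirrors the dense-submatrix updates the abstract describes.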
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | Common Sense Reasoning Tasks | Avg Score | 86.6 | 321 |
| Commonsense Reasoning | Commonsense Reasoning (BoolQ, PIQA, SIQA, HellaS., WinoG., ARC-e, ARC-c, OBQA) | BoolQ Accuracy | 73.3 | 223 |
| Code-Specific Instruction Tuning Evaluation | Magicoder Evaluation Suite | ARC-C Accuracy | 51.74 | 48 |
| Instruction Following | MT-Bench (test) | Overall Score | 5.27 | 35 |
| Instruction Fine-tuning | MetaMathQA Fine-tuning Evaluation Suite (ARC-C, PIQA, MMLU, HE, GSM8K) (test) | ARC-C Accuracy | 49.51 | 32 |
| Arithmetic Reasoning | Arithmetic Reasoning Benchmarks (MultiArith, GSM8K, AddSub, AQuA, SingleEq, SVAMP, MAWPS), fine-tuned on MATH-10K (test) | MultiArith Accuracy | 99.67 | 24 |
| Math Reasoning | Math Reasoning Tasks (MultiArith, GSM8K, AddSub, AQuA, SingleEq, SVAMP, MAWPS) (test) | MultiArith | 99.7 | 23 |
| Safety | T3 | T3 Score | 83.4 | 21 |
| Instruction following and reasoning | Chat and Instruction-following Suite (IFEval, AE2, MTB, GSM8K) | IFEval | 0.695 | 5 |
| General language understanding and generation | Source language | MT Score | 36.7 | 5 |