LoRA-GA: Low-Rank Adaptation with Gradient Approximation

About

Fine-tuning large-scale pretrained models is prohibitively expensive in both computation and memory. LoRA, one of the most popular Parameter-Efficient Fine-Tuning (PEFT) methods, offers a cost-effective alternative by fine-tuning an auxiliary low-rank model with significantly fewer parameters. Although LoRA substantially reduces the compute and memory required per iteration, extensive empirical evidence indicates that it converges considerably more slowly than full fine-tuning, which ultimately increases overall compute and often yields worse test performance.

In our paper, we investigate the initialization of LoRA in depth and show that careful initialization alone (without any change to the architecture or the training algorithm) can significantly improve both efficiency and performance. In particular, we introduce a novel initialization method, LoRA-GA (Low-Rank Adaptation with Gradient Approximation), which aligns the gradients of the low-rank matrix product with those of full fine-tuning at the first step.

Our extensive experiments demonstrate that LoRA-GA achieves a convergence rate comparable to that of full fine-tuning (and is hence significantly faster than vanilla LoRA and various recent improvements to it) while attaining comparable or even better performance. For example, on a subset of the GLUE benchmark with T5-Base, LoRA-GA outperforms LoRA by 5.69% on average. On larger models such as Llama 2-7B, LoRA-GA improves performance by 0.34 on MT-Bench, 11.52% on GSM8K, and 5.05% on HumanEval. Additionally, we observe up to 2-4x faster convergence than vanilla LoRA, validating its effectiveness in both accelerating convergence and improving model performance. Code is available at https://github.com/Outsider565/LoRA-GA.
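The gradient-alignment idea in the abstract can be illustrated with a toy sketch. This is not the code from the linked repository: it assumes plain SGD, omits LoRA's alpha/r scaling, and uses a hand-picked 4x4 diagonal gradient G whose singular vectors are simply the standard basis vectors, so no linear-algebra library is needed. Initializing A from the top right singular vector of G and B from the next left singular vector (a disjoint index choice assumed here) makes the first-order change of B @ A after one step equal to -lr times the best rank-2 approximation of G, which is the closest any rank-2r update can get to the full fine-tuning step -lr * G. Because B @ A is then nonzero at initialization, the sketch also assumes the frozen weight is offset by B @ A so the model's initial output is unchanged.

```python
"""Toy sketch of the gradient-approximation idea behind LoRA-GA.

Assumptions (not taken from the paper's code): plain SGD, no alpha/r scaling,
and a diagonal 4x4 gradient so its SVD is known without a linear-algebra library.
"""

Matrix = list[list[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Naive matrix product x @ y."""
    return [
        [sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
        for i in range(len(x))
    ]


def transpose(x: Matrix) -> Matrix:
    return [list(row) for row in zip(*x)]


def add(x: Matrix, y: Matrix) -> Matrix:
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def scale(c: float, x: Matrix) -> Matrix:
    return [[c * v for v in row] for row in x]


def first_order_lora_update(grad_w: Matrix, a: Matrix, b: Matrix, lr: float) -> Matrix:
    """First-order change of B @ A after one SGD step on both factors.

    With W = W0 + B @ A, dL/dB = G @ A^T and dL/dA = B^T @ G, so
    (B - lr*G A^T)(A - lr*B^T G) = B A - lr*(G A^T A + B B^T G) + O(lr^2).
    """
    term_a = matmul(grad_w, matmul(transpose(a), a))  # G A^T A
    term_b = matmul(matmul(b, transpose(b)), grad_w)  # B B^T G
    return scale(-lr, add(term_a, term_b))


def cosine(x: Matrix, y: Matrix) -> float:
    """Cosine similarity of two matrices viewed as flat vectors."""
    dot = sum(p * q for rx, ry in zip(x, y) for p, q in zip(rx, ry))
    nx = sum(p * p for row in x for p in row) ** 0.5
    ny = sum(q * q for row in y for q in row) ** 0.5
    return dot / (nx * ny)


# Full-weight gradient G = diag(3, 2, 1, 0.5): the singular values are the
# diagonal entries and the left/right singular vectors are e1..e4.
G = [[3.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.5]]
lr = 0.5

# Rank r = 1: A takes the top right singular vector v1, B takes the next
# left singular vector u2, so the two index sets are disjoint.
A = [[1.0, 0.0, 0.0, 0.0]]            # shape (r, n) = (1, 4)
B = [[0.0], [1.0], [0.0], [0.0]]      # shape (m, r) = (4, 1)

lora_step = first_order_lora_update(G, A, B, lr)
full_step = scale(-lr, G)

# B @ A is nonzero at init, so (assumption) the frozen weight is offset by it
# to keep the effective weight W0_offset + B @ A equal to the pretrained W0.
W0 = [[0.5, 0.25, 0.0, 0.0], [0.25, 0.5, 0.0, 0.0], [0.0, 0.0, 0.5, 0.0], [0.0, 0.0, 0.0, 0.5]]
W0_offset = add(W0, scale(-1.0, matmul(B, A)))
W_init = add(W0_offset, matmul(B, A))

print("diagonal of first LoRA step:", [lora_step[i][i] for i in range(4)])
print("cosine with full fine-tuning step:", round(cosine(lora_step, full_step), 4))
print("initial weight unchanged:", W_init == W0)
```

With r = 1, the LoRA step captures the two largest singular directions of G and reaches a cosine similarity of about 0.955 with the full fine-tuning step. By contrast, vanilla LoRA initializes B to zero, so the B B^T G term vanishes and the first update is limited to G projected onto the row space of a randomly initialized A.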

Shaowen Wang, Linxi Yu, Jian Li • 2024

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Code Generation | HumanEval | Pass@1 19.44 | 1036 |
| Mathematical Reasoning | GSM8K (test) | Accuracy 55.12 | 770 |
| Code Generation | HumanEval (test) | Pass@1 23.05 | 506 |
| Multi-turn Dialogue Evaluation | MT-Bench | -- | 447 |
| Language Modeling | WikiText2 (val) | Perplexity (PPL) 21.44 | 387 |
| Natural Language Understanding | GLUE (val) | SST-2 94.11 | 191 |
| Mathematical Reasoning | GSM8K (val) | Accuracy 50.47 | 81 |
| General Language Understanding | GLUE | Accuracy 91.9 | 66 |
| Code Generation | MBPP | Pass@1 Accuracy 23.05 | 59 |
| Mathematical Reasoning | MATH (val) | Accuracy 7.13 | 48 |
Showing 10 of 17 rows