Flat-LoRA: Low-Rank Adaptation over a Flat Loss Landscape
About
Fine-tuning large-scale pre-trained models is prohibitively expensive in terms of computation and memory costs. Low-Rank Adaptation (LoRA), a popular Parameter-Efficient Fine-Tuning (PEFT) method, offers an efficient solution by optimizing only low-rank matrices. Despite recent progress in improving LoRA's performance, the relationship between the LoRA optimization space and the full parameter space is often overlooked. A solution that appears flat in the loss landscape of the LoRA space may still exhibit sharp directions in the full parameter space, potentially compromising generalization. We introduce Flat-LoRA, which aims to identify a low-rank adaptation situated in a flat region of the full parameter space. Instead of adopting the well-established sharpness-aware minimization approach, which incurs significant computation and memory overheads, we employ a Bayesian expectation loss objective to preserve training efficiency. Further, we design a refined random perturbation generation strategy for improved performance and carefully manage memory overhead using random seeds. Experiments across diverse tasks (mathematical reasoning, coding, dialogue generation, instruction following, and text-to-image generation) demonstrate that Flat-LoRA improves both in-domain and out-of-domain generalization. Code is available at https://github.com/nblt/Flat-LoRA.
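The objective described above can be illustrated with a small Monte Carlo sketch. It estimates the expected loss E[L(W0 + s·BA + ε)] over random perturbations ε applied to the full merged weight, not just to the LoRA factors. Noise is regenerated from an integer seed rather than kept in memory between perturbing and restoring the weight. Several details are assumptions, not the authors' code. The isotropic Gaussian ε ~ N(0, σ²I) stands in for the paper's refined perturbation strategy. The scale `s`, the toy squared-error loss, and every helper name (`merged_weight`, `gaussian_noise`, `expected_loss`, ...) are hypothetical. The sketch only evaluates the objective; it does not compute gradients or train A and B.

```python
import random


def matmul(a, b):
    # Plain-list matrix product: (m x k) @ (k x n) -> (m x n)
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def merged_weight(w0, b, a, scale):
    # Full-parameter-space weight seen by the forward pass: W = W0 + scale * B @ A
    ba = matmul(b, a)
    return [[w0[i][j] + scale * ba[i][j] for j in range(len(w0[0]))] for i in range(len(w0))]


def gaussian_noise(shape, sigma, seed):
    # Hypothetical helper: isotropic Gaussian noise that is fully determined by `seed`,
    # so it can be regenerated on demand instead of being stored.
    rng = random.Random(seed)
    rows, cols = shape
    return [[rng.gauss(0.0, sigma) for _ in range(cols)] for _ in range(rows)]


def add_inplace(w, eps, sign):
    # w += sign * eps, elementwise and in place
    for i in range(len(w)):
        for j in range(len(w[0])):
            w[i][j] += sign * eps[i][j]


def squared_loss(w, x, y):
    # Toy loss (assumption): mean squared error of the linear map w @ x against y
    pred = matmul(w, x)
    n = len(y) * len(y[0])
    return sum((pred[i][j] - y[i][j]) ** 2
               for i in range(len(y)) for j in range(len(y[0]))) / n


def expected_loss(w0, b, a, scale, x, y, sigma, seeds):
    """Monte Carlo estimate of E_eps[L(W0 + scale*B@A + eps)], eps ~ N(0, sigma^2 I)."""
    shape = (len(w0), len(w0[0]))
    total = 0.0
    for seed in seeds:
        # Perturb the frozen base weight in place; the noise tensor is not kept afterwards
        add_inplace(w0, gaussian_noise(shape, sigma, seed), +1.0)
        total += squared_loss(merged_weight(w0, b, a, scale), x, y)
        # Regenerate the identical noise from the seed and subtract it to restore w0
        add_inplace(w0, gaussian_noise(shape, sigma, seed), -1.0)
    return total / len(seeds)
```

Because adding and subtracting the same noise is done in floating point, `w0` is restored only up to rounding error. The trade-off is one extra noise generation per step in exchange for storing a single integer seed instead of a full-size noise tensor.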
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---:|---:|
| Mathematical Reasoning | GSM8K | Accuracy | 59.44 | 1398 |
| Dialogue | MT-Bench | MT-Bench Score | 5.98 | 41 |
| Instruction Following | BBH | -- | -- | 40 |
| Code Generation | HumanEval | Pass@1 | 26.67 | 36 |
| Instruction Following | MMLU | MMLU Accuracy | 63.67 | 20 |
| Instruction Following | DROP | DROP Score | 50.44 | 20 |
| Instruction Following | HEval | Pass@1 | 44.31 | 12 |
| Instruction Following | Instruction-following Evaluation Suite (MMLU, DROP, HEval, BBH) (test) | MMLU | 79.51 | 11 |
| Code Synthesis | HumanEval | Pass@1 | 24.56 | 11 |
| Mathematical Reasoning | GSM8K | Accuracy | 56.25 | 7 |