Sparse Layer Sharpness-Aware Minimization for Efficient Fine-Tuning
About
Sharpness-aware minimization (SAM) seeks minima that lie in flat regions of the loss landscape in order to improve generalization in machine learning tasks, including fine-tuning. However, its extra parameter-perturbation step doubles the computational cost, which is the main bottleneck of SAM in practice. In this work, we propose SL-SAM, which breaks this bottleneck by applying sparsity at the layer level. Our key innovation is to frame the dynamic selection of layers for both the gradient ascent (perturbation) and descent (update) steps as a multi-armed bandit problem. At the beginning of each iteration, SL-SAM samples a subset of the model's layers according to their gradient norms. Only these layers participate in backpropagation during the subsequent perturbation and update steps, which reduces the computational complexity. We then provide an analysis that guarantees the convergence of SL-SAM. In experiments fine-tuning models on several tasks, SL-SAM achieves performance comparable to state-of-the-art baselines, including a #1 rank on LLM fine-tuning. Meanwhile, SL-SAM significantly reduces the ratio of active parameters in backpropagation compared to vanilla SAM: SL-SAM activates 47%, 22% and 21% of the parameters on the vision model, the moderate-sized language model and the large language model, respectively, whereas vanilla SAM always activates 100%. These results verify the efficiency of the proposed algorithm.
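To make the layer-sparse SAM step concrete, here is a minimal Python sketch under stated assumptions; it is not the authors' implementation. Each iteration samples `k` layers with probability proportional to a per-layer score (the most recently observed gradient norm), perturbs only those layers along the normalized gradient (ascent, radius `rho`), and then updates only those layers with the gradient taken at the perturbed point (descent). The proportional sampling rule is a simple stand-in for the paper's multi-armed bandit policy, and all names (`sample_layers`, `sl_sam_step`, `toy_grad`, `demo`) and the quadratic toy loss are hypothetical. In this toy, `grad_fn` still computes every layer's gradient, so it shows the selection logic rather than the compute savings.

```python
import math
import random
from typing import Callable, Dict, List, Tuple

Params = Dict[str, List[float]]      # layer name -> flat parameter list
GradFn = Callable[[Params], Params]  # returns per-layer gradients


def norm(v: List[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def sample_layers(scores: Dict[str, float], k: int,
                  rng: random.Random) -> List[str]:
    """Draw k distinct layers without replacement, each draw proportional to score.

    Assumption: a simple proportional rule standing in for SL-SAM's bandit policy.
    """
    names = list(scores)
    weights = [scores[n] + 1e-12 for n in names]
    chosen: List[str] = []
    for _ in range(min(k, len(names))):
        i = rng.choices(range(len(names)), weights=weights)[0]
        chosen.append(names.pop(i))
        weights.pop(i)
    return chosen


def sl_sam_step(params: Params, grad_fn: GradFn, scores: Dict[str, float],
                k: int, lr: float, rho: float,
                rng: random.Random) -> Tuple[Params, List[str]]:
    """One SAM step in which only k sampled layers are perturbed and updated."""
    active = sample_layers(scores, k, rng)
    g = grad_fn(params)  # a real implementation would backprop into `active` only
    for n in active:     # refresh score estimates only for the layers observed
        scores[n] = norm(g[n])
    # Ascent: eps = rho * g / ||g||, with norm and eps restricted to active layers.
    g_norm = math.sqrt(sum(norm(g[n]) ** 2 for n in active)) + 1e-12
    perturbed = {n: list(w) for n, w in params.items()}
    for n in active:
        perturbed[n] = [w + rho * gi / g_norm for w, gi in zip(params[n], g[n])]
    # Descent: update active layers with the gradient taken at the perturbed point.
    g_adv = grad_fn(perturbed)
    new = {n: list(w) for n, w in params.items()}
    for n in active:
        new[n] = [w - lr * gi for w, gi in zip(params[n], g_adv[n])]
    return new, active


def toy_grad(p: Params) -> Params:
    """Gradient of the hypothetical toy loss sum_l sum_i (w_li - 1)^2."""
    return {n: [2.0 * (w - 1.0) for w in ws] for n, ws in p.items()}


def toy_loss(p: Params) -> float:
    return sum((w - 1.0) ** 2 for ws in p.values() for w in ws)


def demo(steps: int = 200, k: int = 2, seed: int = 0) -> Tuple[float, float]:
    """Run the sketch on the toy loss; return (final loss, mean active-parameter ratio)."""
    sizes = {"layer0": 1, "layer1": 2, "layer2": 3, "layer3": 4}
    params = {n: [0.0] * s for n, s in sizes.items()}
    scores = {n: 1.0 for n in params}  # uniform initial scores
    total = sum(sizes.values())
    rng = random.Random(seed)
    ratios = []
    for _ in range(steps):
        params, active = sl_sam_step(params, toy_grad, scores, k,
                                     lr=0.1, rho=0.05, rng=rng)
        ratios.append(sum(sizes[n] for n in active) / total)
    return toy_loss(params), sum(ratios) / len(ratios)


if __name__ == "__main__":
    loss, ratio = demo()
    print(f"final loss {loss:.4f}, mean active-parameter ratio {ratio:.2f}")
```

The mean active-parameter ratio returned by `demo` mirrors, on a toy scale, the "ratio of active parameters" quantity reported above: vanilla SAM corresponds to `k` equal to the number of layers (ratio 1.0). Keeping per-layer scores that are refreshed only for sampled layers is meant to echo the bandit view, where feedback is observed only for the arms that were pulled.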
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | HellaSwag | Accuracy | 70.33 | 1460 |
| Question Answering | OpenBookQA | Accuracy | 37.2 | 465 |
| Natural Language Inference | RTE | Accuracy | 73.29 | 367 |
| Boolean Question Answering | BoolQ | Accuracy | 79.39 | 307 |
| Science Question Answering | ARC Challenge | Accuracy | 44.11 | 234 |
| Natural Language Understanding | GLUE (test dev) | MRPC Accuracy | 93.45 | 81 |
| Multiple-choice Question Answering | MMLU | STEM Accuracy | 49.83 | 13 |
| Linguistic Acceptability | CoLA | Max Memory (MB) | 3.08e+3 | 5 |
| Natural Language Inference | MNLI | Max Memory (MB) | 6.46e+3 | 5 |
| Fine-tuning | Open-Platypus | Max Memory (MB) | 4.74e+4 | 4 |