LiNeS: Post-training Layer Scaling Prevents Forgetting and Enhances Model Merging

About

Fine-tuning pre-trained models has become the standard approach to endow them with specialized knowledge, but it poses fundamental challenges. In particular, (i) fine-tuning often leads to catastrophic forgetting, where improvements on a target domain degrade generalization on other tasks, and (ii) merging fine-tuned checkpoints from disparate tasks can lead to significant performance loss. To address these challenges, we introduce LiNeS, Layer-increasing Network Scaling, a post-training editing technique designed to preserve pre-trained generalization while enhancing fine-tuned task performance. LiNeS scales parameter updates linearly based on their layer depth within the network, keeping shallow layers close to their pre-trained values to preserve general features while allowing deeper layers to retain task-specific representations. In multi-task model merging scenarios, layer-wise scaling of merged parameters reduces negative task interference. LiNeS demonstrates significant improvements in both single-task and multi-task settings across various benchmarks in vision and natural language processing. It mitigates forgetting, enhances out-of-distribution generalization, integrates seamlessly with existing multi-task model merging baselines, improving their performance across benchmarks and model sizes, and can boost generalization when merging LLM policies aligned with different rewards via RLHF. Our method is simple to implement, computationally efficient, and complementary to many existing techniques. Our source code is available at https://github.com/wang-kee/LiNeS
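The depth-dependent scaling described above can be illustrated with a small sketch. This is a hypothetical toy version, not the code from the linked repository. It assumes a linear schedule in which the update of layer l (0-indexed, shallowest first, L layers in total) is multiplied by alpha + beta * l / (L - 1); the names `alpha`, `beta`, `task_vector`, `layer_scale`, `lines_edit`, and `merge_task_vectors` are illustrative choices, and each layer is a plain list of floats rather than a real tensor. For the multi-task case it uses a simple sum of task vectors (task-arithmetic style) purely as an example merging baseline.

```python
from typing import Dict, List

# Hypothetical toy representation: layer index (0 = shallowest, closest to the
# input) -> flat list of that layer's weights. Real models would use tensors.
Params = Dict[int, List[float]]


def task_vector(finetuned: Params, pretrained: Params) -> Params:
    """Per-layer update: fine-tuned weights minus pre-trained weights."""
    return {
        layer: [f - p for f, p in zip(finetuned[layer], pretrained[layer])]
        for layer in pretrained
    }


def layer_scale(layer: int, num_layers: int, alpha: float, beta: float) -> float:
    """Assumed linear schedule: alpha at the first layer, alpha + beta at the
    last layer (for num_layers >= 2)."""
    denom = max(num_layers - 1, 1)  # avoid division by zero for a 1-layer model
    return alpha + beta * layer / denom


def lines_edit(pretrained: Params, tau: Params, alpha: float, beta: float) -> Params:
    """Add the update back onto the pre-trained weights, scaled by layer depth."""
    num_layers = len(pretrained)
    return {
        layer: [
            p + layer_scale(layer, num_layers, alpha, beta) * t
            for p, t in zip(pretrained[layer], tau[layer])
        ]
        for layer in pretrained
    }


def merge_task_vectors(taus: List[Params]) -> Params:
    """Example merging baseline (task-arithmetic style): sum task vectors per layer."""
    return {
        layer: [sum(vals) for vals in zip(*(t[layer] for t in taus))]
        for layer in taus[0]
    }


if __name__ == "__main__":
    pre = {layer: [1.0, 1.0] for layer in range(3)}
    ft = {layer: [2.0, 0.0] for layer in range(3)}

    # Single task: the shallowest layer keeps its pre-trained weights,
    # the deepest layer keeps the full fine-tuning update.
    print(lines_edit(pre, task_vector(ft, pre), alpha=0.0, beta=1.0))
    # {0: [1.0, 1.0], 1: [1.5, 0.5], 2: [2.0, 0.0]}

    # Multi-task: merge two task vectors first, then apply the same depth scaling.
    tau_a = {layer: [1.0, 0.0] for layer in range(3)}
    tau_b = {layer: [0.0, 1.0] for layer in range(3)}
    merged = merge_task_vectors([tau_a, tau_b])
    print(lines_edit(pre, merged, alpha=0.0, beta=1.0))
    # {0: [1.0, 1.0], 1: [1.5, 1.5], 2: [2.0, 2.0]}
```

With `alpha=0.0` the first layer stays exactly at its pre-trained values while the last layer receives the full update, mirroring the idea of keeping shallow layers close to pre-training. The values of `alpha` and `beta` here are arbitrary; in practice they would have to be chosen for each setting.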

Ke Wang, Nikolaos Dimitriadis, Alessandro Favero, Guillermo Ortiz-Jimenez, Francois Fleuret, Pascal Frossard• 2024

Related benchmarks

Task | Dataset | Metric | Result | Rank
Mathematical Reasoning | GSM8K | Accuracy | 56.5 | 1398
Depth Estimation | NYU Depth V2 | -- | -- | 209
Image Classification | 20 Vision Classification Tasks | Average Accuracy | 75.7 | 131
Multiple-choice Question Answering | MMLU-Pro | MMLU-Pro Overall Accuracy | 36.7 | 130
Image Classification | 14 Vision Tasks | Average Accuracy | 80.4 | 121
Multiple-choice Question Answering | SciQ | Accuracy | 95.2 | 91
Surface Normal Estimation | NYU V2 | Mean Angular Error | 29.1 | 65
Image Classification | 8 Vision Tasks | Average Accuracy | 86.9 | 53
Mathematical Reasoning | GSM8K Platinum | Accuracy | 58.9 | 37
Semantic Segmentation | NYU V2 | mIoU | 36.2 | 30

Showing 10 of 12 rows.
