Fine-tuning can cripple your foundation model; preserving features may be the solution

About

Pre-trained foundation models, due to their enormous capacity and exposure to vast amounts of data during pre-training, are known to have learned a wide range of real-world concepts. An important step in making these pre-trained models effective on downstream tasks is to fine-tune them on related datasets. While various fine-tuning methods have been devised and shown to be highly effective, we observe that a fine-tuned model's ability to recognize concepts on tasks different from the downstream one is significantly reduced compared to its pre-trained counterpart. This is an undesirable side effect of fine-tuning, since substantial resources were spent learning these pre-trained concepts in the first place. We call this phenomenon "concept forgetting" and show experimentally that most end-to-end fine-tuning approaches suffer heavily from it. To address this, we propose a simple fix: a new fine-tuning method called LDIFS (short for ℓ2 distance in feature space) that, while learning new concepts related to the downstream task, allows a model to preserve its pre-trained knowledge as well. Through extensive experiments on 10 fine-tuning tasks, we show that LDIFS significantly reduces concept forgetting. Additionally, we show that LDIFS is also highly effective for continual fine-tuning on a sequence of tasks, compared with both fine-tuning and continual learning baselines.
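The abstract describes LDIFS only as an ℓ2 distance in feature space used during fine-tuning. A minimal sketch of what such an objective could look like follows, assuming the penalty is the squared ℓ2 distance between the fine-tuned model's features and those of a frozen copy of the pre-trained model on the same training inputs, added to the task's cross-entropy with a hypothetical weight `lam`. Which layers supply the features, whether the distance is squared, and how it is averaged are assumptions made here, not details stated in the abstract. The helper names (`cross_entropy`, `squared_l2`, `ldifs_style_loss`) are hypothetical, and plain Python lists stand in for tensors so the arithmetic is easy to follow.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Negative log-likelihood of `label` under softmax(logits), computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]


def squared_l2(u: Sequence[float], v: Sequence[float]) -> float:
    """Squared Euclidean distance between two feature vectors of equal length."""
    if len(u) != len(v):
        raise ValueError("feature vectors must have the same length")
    return sum((a - b) ** 2 for a, b in zip(u, v))


def ldifs_style_loss(
    logits: List[Sequence[float]],
    labels: List[int],
    feats_finetuned: List[Sequence[float]],
    feats_pretrained: List[Sequence[float]],
    lam: float = 1.0,
) -> float:
    """Batch-mean task loss plus lam * batch-mean feature-space distance.

    ASSUMPTION: feats_pretrained come from a frozen copy of the pre-trained
    model evaluated on the same inputs; in a real setup only the fine-tuned
    model would receive gradients.
    """
    n = len(labels)
    if not (len(logits) == len(feats_finetuned) == len(feats_pretrained) == n):
        raise ValueError("batch sizes must match")
    # Downstream-task term: standard cross-entropy averaged over the batch.
    task = sum(cross_entropy(z, y) for z, y in zip(logits, labels)) / n
    # Feature-preservation term: how far fine-tuned features drift from pre-trained ones.
    reg = sum(squared_l2(f, g) for f, g in zip(feats_finetuned, feats_pretrained)) / n
    return task + lam * reg


if __name__ == "__main__":
    # Toy batch: 2 examples, 2 classes, 3-dimensional (hypothetical) features.
    logits = [[2.0, 0.0], [0.0, 2.0]]
    labels = [0, 1]
    f_pretrained = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    f_finetuned = [[1.0, 0.5, 0.0], [0.0, 1.0, 0.0]]  # first example's features drifted
    print(ldifs_style_loss(logits, labels, f_finetuned, f_pretrained, lam=0.5))
```

In a framework such as PyTorch or JAX, the same computation would run on tensors, with the pre-trained features computed without gradients, so that minimizing the second term pulls the fine-tuned model's representation back toward the original one while the first term fits the downstream task. Setting `lam = 0` recovers plain fine-tuning in this sketch.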

Jishnu Mukhoti, Yarin Gal, Philip H.S. Torr, Puneet K. Dokania · 2023

Related benchmarks

Task | Dataset | Metric | Result | Rank
Safety Evaluation | BeaverTails (test) | Harmful Score | 16.4 | 110
Image Classification | DTD (Describable Textures Dataset) | Accuracy | 78.99 | 80
Topic Classification | AGNews | FA Score | 0.725 | 58
Safety and Utility Evaluation | Safety and Utility evaluation suite (test) | HS Score | 4.1 | 40
Image Classification | FMoW OOD (test) | Worst Group Accuracy | 42.8 | 10
Image Classification | FMoW ID (test) | Accuracy | 69.2 | 10
Image Classification | iWildCam ID (test) | F1-macro | 48.1 | 10
Image Classification | iWildCam OOD (test) | F1-macro | 34.7 | 10
Reasoning | VQA-RAD | Correctness | 46.31 | 6
Reasoning | Flowers | Correctness | 81.31 | 6

Showing 10 of 12 rows.