Merging Models with Fisher-Weighted Averaging

About

Averaging the parameters of models that have the same architecture and initialization can provide a means of combining their respective capabilities. In this paper, we take the perspective that this "merging" operation can be seen as choosing parameters that approximately maximize the joint likelihood of the posteriors of the models' parameters. Computing a simple average of the models' parameters therefore corresponds to making an isotropic Gaussian approximation to their posteriors. We develop an alternative merging procedure based on the Laplace approximation where we approximate each model's posterior as a Gaussian distribution whose precision matrix corresponds to its Fisher information. We first show that our "Fisher merging" technique provides a performance boost in settings where simple parameter averaging is currently used, specifically robust fine-tuning and model ensembling. Then, we compare merging to standard gradient-based transfer learning and demonstrate that merging enables a fundamentally different method for transferring capabilities across models. Specifically, we show that Fisher merging is competitive with gradient-based transfer learning approaches (while being significantly cheaper) in intermediate-task training and domain-adaptive pre-training. We also show that our merging procedure makes it possible to combine models in previously unexplored ways. We release our code to facilitate future research into methods for merging models.
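The merging rule described in the abstract can be illustrated with a short sketch. This is a minimal, non-authoritative example, assuming a diagonal Fisher approximation (each parameter's precision is its own Fisher value, with no cross-terms), per-model weights lambda_i that sum to 1, and models stored as dictionaries of flat float lists. Under those assumptions, maximizing the weighted sum of the Gaussian log-densities gives an elementwise weighted average, theta_j = sum_i lambda_i F_ij theta_ij / sum_i lambda_i F_ij, which collapses to ordinary parameter averaging when every model has the same Fisher (the isotropic case). The names `fisher_merge`, `logistic_diag_fisher`, and the `fisher_floor` clamp are hypothetical and are not taken from the authors' released code; the logistic-regression Fisher is only a toy stand-in for a real network.

```python
import math
from typing import Dict, List, Optional, Sequence

# Hypothetical representation: parameter name -> flat list of floats.
Params = Dict[str, List[float]]


def logistic_diag_fisher(w: Sequence[float], xs: Sequence[Sequence[float]]) -> List[float]:
    """Diagonal Fisher of a toy logistic-regression model, averaged over inputs xs.

    With p = sigmoid(w . x), the score is d/dw_j log p(y|x) = (y - p) * x_j.
    Taking the expectation over y ~ Bernoulli(p) gives p * (1 - p) * x_j**2.
    """
    fisher = [0.0] * len(w)
    for x in xs:
        z = sum(wi * xi for wi, xi in zip(w, x))
        p = 1.0 / (1.0 + math.exp(-z))
        for j, xj in enumerate(x):
            fisher[j] += p * (1.0 - p) * xj * xj
    return [f / len(xs) for f in fisher]


def fisher_merge(
    params: Sequence[Params],
    fishers: Sequence[Params],
    weights: Optional[Sequence[float]] = None,
    fisher_floor: float = 1e-8,
) -> Params:
    """Merge models elementwise, weighting each parameter by lambda_i * F_ij.

    Fisher values are clamped to at least `fisher_floor` (an assumption of
    this sketch) so the denominator can never be zero.
    """
    if weights is None:
        weights = [1.0 / len(params)] * len(params)  # uniform lambda_i
    merged: Params = {}
    for name in params[0]:
        values = []
        for j in range(len(params[0][name])):
            num = den = 0.0
            for lam, p, f in zip(weights, params, fishers):
                fij = max(f[name][j], fisher_floor)
                num += lam * fij * p[name][j]
                den += lam * fij
            values.append(num / den)
        merged[name] = values
    return merged


if __name__ == "__main__":
    a = {"w": [1.0, 0.0]}
    b = {"w": [0.0, 1.0]}
    # Each model has a higher Fisher value on a different coordinate.
    fa = {"w": [3.0, 1.0]}
    fb = {"w": [1.0, 3.0]}
    print(fisher_merge([a, b], [fa, fb]))  # {'w': [0.75, 0.75]}
    # Identical Fishers reduce to plain parameter averaging.
    ones = {"w": [1.0, 1.0]}
    print(fisher_merge([a, b], [ones, ones]))  # {'w': [0.5, 0.5]}
```

In the toy run, each coordinate of the merged model leans toward whichever model has the larger Fisher value there (0.75 instead of the plain average 0.5), while equal Fishers recover simple averaging.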

Michael Matena, Colin Raffel · 2021

Related benchmarks

Task                            Dataset         Metric     Result   Rank
Image Classification            Stanford Cars   Accuracy   69.2     635
Image Classification            EuroSAT         Accuracy   66.4     569
Natural Language Understanding  GLUE            SST-2      64.7     531
Image Classification            DTD             Accuracy   59.9     485
Natural Language Inference      RTE             Accuracy   83.3     448
Image Classification            SUN397          Accuracy   68.6     441
Image Classification            MNIST           Accuracy   87.9     398
Image Classification            SVHN            Accuracy   84.2     395
Commonsense Reasoning           WinoGrande      Accuracy   56.7     372
Image Classification            RESISC45        Accuracy   78.2     349

(Showing 10 of 49 rows.)
