Git Re-Basin: Merging Models modulo Permutation Symmetries

About

The success of deep learning is due in large part to our ability to solve certain massive non-convex optimization problems with relative ease. Though non-convex optimization is NP-hard, simple algorithms -- often variants of stochastic gradient descent -- exhibit surprising effectiveness in fitting large neural networks in practice. We argue that neural network loss landscapes often contain (nearly) a single basin after accounting for all possible permutation symmetries of hidden units a la Entezari et al. 2021. We introduce three algorithms to permute the units of one model to bring them into alignment with a reference model in order to merge the two models in weight space. This transformation produces a functionally equivalent set of weights that lie in an approximately convex basin near the reference model. Experimentally, we demonstrate the single basin phenomenon across a variety of model architectures and datasets, including the first (to our knowledge) demonstration of zero-barrier linear mode connectivity between independently trained ResNet models on CIFAR-10. Additionally, we identify intriguing phenomena relating model width and training time to mode connectivity. Finally, we discuss shortcomings of the linear mode connectivity hypothesis, including a counterexample to the single basin theory.
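The alignment step described above can be illustrated for the simplest case. Below is a minimal numpy/scipy sketch of weight matching for a one-hidden-layer ReLU MLP: hidden units of model B are permuted to best match model A (via the Hungarian algorithm), which leaves B's function unchanged, and the two models are then averaged in weight space. All names, shapes, and the similarity score are illustrative assumptions, not the paper's actual code, which also handles deeper architectures.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
d_in, d_h, d_out = 4, 8, 3  # illustrative sizes

def random_mlp():
    # One hidden layer, no biases, for brevity.
    return {
        "W1": rng.normal(size=(d_h, d_in)),   # input -> hidden
        "W2": rng.normal(size=(d_out, d_h)),  # hidden -> output
    }

def forward(p, x):
    return p["W2"] @ np.maximum(p["W1"] @ x, 0.0)  # ReLU MLP

A, B = random_mlp(), random_mlp()

# score[i, j] = similarity of A's hidden unit i to B's hidden unit j,
# using both the incoming and outgoing weights of each unit.
score = A["W1"] @ B["W1"].T + A["W2"].T @ B["W2"]

# Hungarian algorithm: permutation of B's units maximizing total similarity.
_, perm = linear_sum_assignment(score, maximize=True)

# Permute rows of W1 and columns of W2 together, so the permuted
# network computes exactly the same function as B.
B_aligned = {"W1": B["W1"][perm], "W2": B["W2"][:, perm]}

x = rng.normal(size=d_in)
assert np.allclose(forward(B, x), forward(B_aligned, x))

# Merge in weight space: midpoint of A and the aligned B.
merged = {k: 0.5 * (A[k] + B_aligned[k]) for k in A}
```

With zero-barrier linear mode connectivity, any interpolation `(1 - t) * A + t * B_aligned` (not just the midpoint) would incur no increase in loss over the endpoints; the "interpolation barrier" reported in the benchmarks below measures exactly that worst-case increase along the path.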

Samuel K. Ainsworth, Jonathan Hayase, Siddhartha Srinivasa • 2022

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
| --- | --- | --- | --- | --- |
| Image Classification | CIFAR-100 50+50 | Joint Accuracy | 74.52 | 25 |
| Image Classification | CIFAR100 50+50 | Joint Accuracy | 41.33 | 14 |
| Image Classification | CIFAR-100 50+50 (Joint) | Accuracy | 75.41 | 12 |
| Image Classification | CIFAR-100 Task A 50 classes | Accuracy | 87.46 | 12 |
| Image Classification | CIFAR-100 Task B 50 classes | Accuracy | 84.99 | 12 |
| Node Classification Retention | Cora, CiteSeer, Actor, Amazon-Ratings, and Arxiv specialist subsets (held-out) | Retention A | 83 | 10 |
| Multi-task image classification | MNIST multi-task (test) | Accuracy | 71.71 | 9 |
| Image Classification | CIFAR-10 (test) | Test Loss Interpolation Barrier | 0.509 | 8 |
| Image Classification | DomainNet Same Label Space (avg across pairs) | Clipart Accuracy | 18.2 | 8 |
| Image Classification | CUB, Oxford-IIIT Pets, Stanford Dogs, NABirds Different Label Spaces (avg across pairs) | CUB Accuracy | 66.2 | 8 |

Showing 10 of 13 rows.
