Fishr: Invariant Gradient Variances for Out-of-Distribution Generalization

About

Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been growing interest in learning simultaneously from multiple training domains while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under controlled evaluation protocols. In this paper, we introduce a new regularization, named Fishr, that enforces domain invariance in the space of the gradients of the loss: specifically, the domain-level variances of gradients are matched across training domains. Our approach is based on the close relations between the gradient covariance, the Fisher Information and the Hessian of the loss: in particular, we show that Fishr eventually aligns the domain-level loss landscapes locally around the final weights. Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. Notably, Fishr improves the state of the art on the DomainBed benchmark and performs consistently better than Empirical Risk Minimization. Our code is available at https://github.com/alexrame/fishr.
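
To make the idea of matching domain-level gradient variances concrete, here is a minimal sketch in plain Python. It is not the authors' implementation (see the repository linked above for that). It assumes a toy linear model with squared loss, and it assumes the penalty is the mean squared distance between each domain's per-parameter gradient variance and the average of those variances. The names `per_sample_grads`, `grad_variance`, and `fishr_penalty`, and the toy data, are hypothetical and chosen only for illustration.

```python
# Illustrative sketch only: toy linear model with squared loss,
# not the authors' implementation.

def per_sample_grads(w, xs, ys):
    """Gradient of 0.5 * (w.x - y)^2 with respect to w, one per sample."""
    grads = []
    for x, y in zip(xs, ys):
        residual = sum(wi * xi for wi, xi in zip(w, x)) - y
        grads.append([residual * xi for xi in x])
    return grads

def grad_variance(grads):
    """Per-parameter (unbiased) variance of the gradients within one domain."""
    n = len(grads)
    dim = len(grads[0])
    means = [sum(g[j] for g in grads) / n for j in range(dim)]
    return [
        sum((g[j] - means[j]) ** 2 for g in grads) / (n - 1)
        for j in range(dim)
    ]

def fishr_penalty(w, domains):
    """Assumed form: mean squared distance between each domain's
    gradient variance and the average variance across domains."""
    variances = [grad_variance(per_sample_grads(w, xs, ys)) for xs, ys in domains]
    dim = len(w)
    mean_var = [sum(v[j] for v in variances) / len(variances) for j in range(dim)]
    return sum(
        sum((v[j] - mean_var[j]) ** 2 for j in range(dim)) for v in variances
    ) / len(variances)

# Hypothetical toy data: domain_b has the same labels but rescaled inputs.
w = [0.5, -0.5]
domain_a = ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [1.0, 0.0, 2.0])
domain_b = ([[3.0, 0.0], [0.0, 3.0], [3.0, 3.0]], [1.0, 0.0, 2.0])

same = fishr_penalty(w, [domain_a, domain_a])     # identical domains
shifted = fishr_penalty(w, [domain_a, domain_b])  # shifted domain
```

In this sketch the penalty vanishes when the two domains are identical and becomes positive when one domain's inputs are rescaled. In training, a term like this would typically be added to the usual empirical risk with a weighting coefficient, which is a hyperparameter not specified here.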

Alexandre Rame, Corentin Dancette, Matthieu Cord • 2021

Related benchmarks

Task | Dataset | Result | Rank
Domain Generalization | VLCS | Accuracy: 77.8 | 238
Image Classification | PACS | Overall Average Accuracy: 66.5 | 230
Domain Generalization | PACS (test) | Average Accuracy: 66.1 | 225
Domain Generalization | PACS | -- | 221
Domain Generalization | OfficeHome | Accuracy: 67.8 | 182
Image Classification | OfficeHome | Average Accuracy: 68.6 | 131
Domain Generalization | DomainBed | Average Accuracy: 65.7 | 127
Domain Generalization | DomainNet | Accuracy: 41.7 | 113
object recognition | PACS (leave-one-domain-out) | Acc (Art painting): 88.4 | 112
Domain Generalization | DomainBed (test) | VLCS Accuracy: 77.8 | 110
Showing 10 of 51 rows
