Fishr: Invariant Gradient Variances for Out-of-Distribution Generalization
About
Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been growing interest in learning simultaneously from multiple training domains while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under controlled evaluation protocols. In this paper, we introduce a new regularization, named Fishr, that enforces domain invariance in the space of the gradients of the loss: specifically, the domain-level variances of gradients are matched across training domains. Our approach is based on the close relations between the gradient covariance, the Fisher Information and the Hessian of the loss: in particular, we show that Fishr eventually aligns the domain-level loss landscapes locally around the final weights. Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. Notably, Fishr improves the state of the art on the DomainBed benchmark and performs consistently better than Empirical Risk Minimization. Our code is available at https://github.com/alexrame/fishr.
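To make the idea of matching domain-level gradient variances concrete, here is a minimal pure-Python sketch. It is not the authors' implementation (see the repository linked above for that). It assumes a toy linear model with squared loss so that per-sample gradients can be written by hand, and it assumes the penalty is the mean squared distance between each domain's element-wise gradient variance and the average of those variances across domains. The function names `per_sample_grads`, `grad_variance`, and `fishr_penalty` are hypothetical and chosen only for illustration.

```python
from statistics import fmean


def per_sample_grads(w, xs, ys):
    """Per-sample gradients of 0.5 * (w . x - y)^2 with respect to w (toy model, an assumption)."""
    grads = []
    for x, y in zip(xs, ys):
        residual = sum(wi * xi for wi, xi in zip(w, x)) - y
        grads.append([residual * xi for xi in x])
    return grads


def grad_variance(grads):
    """Element-wise population variance of a list of per-sample gradients."""
    dims = range(len(grads[0]))
    means = [fmean(g[d] for g in grads) for d in dims]
    return [fmean((g[d] - means[d]) ** 2 for g in grads) for d in dims]


def fishr_penalty(w, domains):
    """Mean squared distance between each domain's gradient variance and the cross-domain mean.

    `domains` is a list of (xs, ys) pairs, one per training domain.
    """
    variances = [grad_variance(per_sample_grads(w, xs, ys)) for xs, ys in domains]
    dims = range(len(w))
    mean_var = [fmean(v[d] for v in variances) for d in dims]
    return fmean(sum((v[d] - mean_var[d]) ** 2 for d in dims) for v in variances)


# Two toy domains that differ only in the scale of their first feature.
domain_a = ([[1.0, 0.0], [0.0, 1.0]], [1.0, 0.0])
domain_b = ([[2.0, 0.0], [0.0, 1.0]], [2.0, 0.0])
penalty = fishr_penalty([0.0, 0.0], [domain_a, domain_b])
```

In training, a term like this would be added to the usual empirical risk with a weighting coefficient. The penalty is zero when all domains share the same gradient variance and grows as the domains' gradient statistics diverge.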
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Domain Generalization | VLCS | Accuracy | 77.8 | 238 |
| Image Classification | PACS | Overall Average Accuracy | 66.5 | 230 |
| Domain Generalization | PACS (test) | Average Accuracy | 66.1 | 225 |
| Domain Generalization | PACS | -- | -- | 221 |
| Domain Generalization | OfficeHome | Accuracy | 67.8 | 182 |
| Image Classification | OfficeHome | Average Accuracy | 68.6 | 131 |
| Domain Generalization | DomainBed | Average Accuracy | 65.7 | 127 |
| Domain Generalization | DomainNet | Accuracy | 41.7 | 113 |
| object recognition | PACS (leave-one-domain-out) | Acc (Art painting) | 88.4 | 112 |
| Domain Generalization | DomainBed (test) | VLCS Accuracy | 77.8 | 110 |