Domain Generalization using Causal Matching
About
In the domain generalization literature, a common objective is to learn representations that are independent of the domain after conditioning on the class label. We show that this objective is not sufficient: there exist counter-examples where a model fails to generalize to unseen domains even after satisfying class-conditional domain invariance. We formalize this observation through a structural causal model and show the importance of modeling within-class variations for generalization. Specifically, classes contain objects that characterize specific causal features, and domains can be interpreted as interventions on these objects that change non-causal features. We highlight an alternative condition: inputs across domains should have the same representation if they are derived from the same object. Based on this objective, we propose matching-based algorithms for when base objects are observed (e.g., through data augmentation), and an approximation of the objective for when objects are not observed (MatchDG). Our simple matching-based algorithms are competitive with prior work on out-of-domain accuracy for the Rotated MNIST, Fashion-MNIST, PACS, and Chest-Xray datasets. Our method MatchDG also recovers ground-truth object matches: on MNIST and Fashion-MNIST, the top-10 matches from MatchDG have over 50% overlap with ground-truth matches.
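The central condition above says that inputs derived from the same object should share a representation across domains. One way to picture it is as a penalty on cross-domain pairs with the same object. The sketch below is a minimal, framework-free illustration under stated assumptions, not the authors' MatchDG implementation. It assumes object ids are observed (the augmentation setting) and uses squared Euclidean distance as the matching penalty. The function names `matching_penalty` and `topk_hit_rate` are hypothetical. During training such a penalty would typically be added to a classification loss with some weight, which is not specified here. The second function is one plausible reading of the "top-k overlap" evaluation: the fraction of inputs whose ground-truth match appears among their k nearest cross-domain candidates.

```python
from typing import Hashable, List, Sequence

Vector = Sequence[float]


def sq_dist(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two representation vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def matching_penalty(
    reps: List[Vector], object_ids: List[Hashable], domains: List[Hashable]
) -> float:
    """Mean squared distance over pairs that share an object id but differ in domain.

    Assumption: object ids are known (e.g. augmentations of one base image),
    i.e. the setting where base objects are observed.
    """
    total, pairs = 0.0, 0
    for i in range(len(reps)):
        for j in range(i + 1, len(reps)):
            if object_ids[i] == object_ids[j] and domains[i] != domains[j]:
                total += sq_dist(reps[i], reps[j])
                pairs += 1
    return total / pairs if pairs else 0.0  # no matched pairs -> no penalty


def topk_hit_rate(
    anchors: List[Vector], candidates: List[Vector], true_match: List[int], k: int
) -> float:
    """Fraction of anchors whose ground-truth match is among their k nearest candidates.

    true_match[a] is the index in `candidates` of anchor a's ground-truth match
    (a hypothetical encoding chosen for this sketch).
    """
    hits = 0
    for a, rep in enumerate(anchors):
        ranked = sorted(range(len(candidates)), key=lambda c: sq_dist(rep, candidates[c]))
        if true_match[a] in ranked[:k]:
            hits += 1
    return hits / len(anchors)


# Toy usage: two objects, each seen in two domains ("d0", "d1").
reps = [[0.0, 1.0], [0.0, 1.5], [2.0, 0.0], [2.5, 0.0]]
object_ids = [0, 0, 1, 1]
domains = ["d0", "d1", "d0", "d1"]
print(matching_penalty(reps, object_ids, domains))  # 0.25

anchors = [[0.0, 1.0], [2.0, 0.0]]     # domain d0
candidates = [[2.5, 0.0], [0.0, 1.5]]  # domain d1
true_match = [1, 0]
print(topk_hit_rate(anchors, candidates, true_match, k=1))  # 1.0
```

When object ids are not observed, as in the MatchDG setting, the pairs fed to a penalty like this would have to be inferred rather than read from labels; this sketch does not attempt that step.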
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | PACS (test) | Average Accuracy | 70.46 | 254 |
| Image Classification | PACS | Overall Average Accuracy | 84.57 | 230 |
| Domain Generalization | PACS | Accuracy (Art) | 81.32 | 221 |
| Domain Generalization | PACS (leave-one-domain-out) | Art Accuracy | 85.61 | 146 |
| Object Recognition | PACS (leave-one-domain-out) | Acc (Art painting) | 85.61 | 112 |
| Image Classification | PACS v1 (test) | Average Accuracy | 87.52 | 92 |
| Image Classification | PACS (out-of-domain) | Overall Accuracy | 85.53 | 63 |
| Node Classification | Cora Covariate shift (degree split) | OOD Accuracy | 62.91 | 50 |
| Image Classification | Rotated-MNIST | Mean Accuracy | 97.8 | 40 |
| Node Classification | WebKB university split Covariate shift | OOD Test Accuracy | 35.45 | 30 |