Domain Generalization using Causal Matching

About

In the domain generalization literature, a common objective is to learn representations that are independent of the domain after conditioning on the class label. We show that this objective is not sufficient: there exist counter-examples where a model fails to generalize to unseen domains even after satisfying class-conditional domain invariance. We formalize this observation through a structural causal model and show the importance of modeling within-class variations for generalization. Specifically, classes contain objects that characterize specific causal features, and domains can be interpreted as interventions on these objects that change non-causal features. We highlight an alternative condition: inputs across domains should have the same representation if they are derived from the same object. Based on this objective, we propose matching-based algorithms for the case where base objects are observed (e.g., through data augmentation), and an algorithm, MatchDG, that approximates the objective when objects are not observed. Our simple matching-based algorithms are competitive with prior work on out-of-domain accuracy for the Rotated MNIST, Fashion-MNIST, PACS, and Chest X-ray datasets. MatchDG also recovers ground-truth object matches: on MNIST and Fashion-MNIST, the top-10 matches from MatchDG have over 50% overlap with the ground-truth matches.
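The object-matching condition above can be illustrated with a minimal sketch. This is not the authors' released code: it assumes the observed-object setting, where pairs of inputs known to come from the same base object (for example an image and its augmentation) are given, and it does not reproduce how MatchDG infers matches when objects are unobserved. The names `match_penalty`, `total_loss`, `topk_overlap` and the weight `lam` are hypothetical, and `topk_overlap` reflects an assumed reading of the "top-10 overlap" metric (the ground-truth match appears among an anchor's 10 nearest neighbours in representation space).

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def sq_dist(u: Vector, v: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def match_penalty(reps: Sequence[Vector], pairs: List[Tuple[int, int]]) -> float:
    """Mean squared distance between representations of matched inputs.

    Each (i, j) in pairs is assumed to index two inputs derived from the
    same base object but observed in different domains.
    """
    if not pairs:
        return 0.0
    return sum(sq_dist(reps[i], reps[j]) for i, j in pairs) / len(pairs)


def total_loss(task_loss: float, reps: Sequence[Vector],
               pairs: List[Tuple[int, int]], lam: float = 1.0) -> float:
    """Task loss plus a lam-weighted matching penalty (lam is a hypothetical knob)."""
    return task_loss + lam * match_penalty(reps, pairs)


def topk_overlap(reps_a: Sequence[Vector], reps_b: Sequence[Vector],
                 true_match: Sequence[int], k: int = 10) -> float:
    """Fraction of anchors in domain A whose ground-truth match in domain B
    is among their k nearest neighbours by Euclidean distance."""
    hits = 0
    for i, anchor in enumerate(reps_a):
        # Rank every domain-B candidate by its distance to this anchor.
        ranked = sorted(range(len(reps_b)), key=lambda j: sq_dist(anchor, reps_b[j]))
        if true_match[i] in ranked[:k]:
            hits += 1
    return hits / len(reps_a)


# Toy usage: four "objects" as one-hot vectors, seen in two domains whose
# inputs are listed in a different order.
eye = [[1.0 if r == c else 0.0 for c in range(4)] for r in range(4)]
shuffled = [eye[1], eye[0], eye[3], eye[2]]
print(topk_overlap(eye, shuffled, [1, 0, 3, 2], k=1))  # 1.0
print(match_penalty(eye, [(0, 1)]))  # 2.0
```

Driving `match_penalty` to zero forces inputs from the same object to share a representation, which is stricter than only requiring the class-conditional distributions to match across domains.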

Divyat Mahajan, Shruti Tople, Amit Sharma • 2020

Related benchmarks

| Task | Dataset | Metric | Result (%) | Rank |
|---|---|---|---|---|
| Image Classification | PACS (test) | Average Accuracy | 70.46 | 254 |
| Image Classification | PACS | Overall Average Accuracy | 84.57 | 230 |
| Domain Generalization | PACS | Accuracy (Art) | 81.32 | 221 |
| Domain Generalization | PACS (leave-one-domain-out) | Art Accuracy | 85.61 | 146 |
| Object Recognition | PACS (leave-one-domain-out) | Accuracy (Art painting) | 85.61 | 112 |
| Image Classification | PACS v1 (test) | Average Accuracy | 87.52 | 92 |
| Image Classification | PACS (out-of-domain) | Overall Accuracy | 85.53 | 63 |
| Node Classification | Cora Covariate shift (degree split) | OOD Accuracy | 62.91 | 50 |
| Image Classification | Rotated-MNIST | Mean Accuracy | 97.8 | 40 |
| Node Classification | WebKB university split Covariate shift | OOD Test Accuracy | 35.45 | 30 |

Showing 10 of 27 rows.
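Several rows above refer to the leave-one-domain-out protocol: train on all but one domain, evaluate on the held-out domain, and repeat for each domain. A minimal sketch of the split logic follows, assuming PACS's four domains (photo, art painting, cartoon, sketch) and assuming that an "average accuracy" is the unweighted mean over held-out domains; the listing does not define how each leaderboard computes it. The function names are hypothetical.

```python
from typing import Dict, Iterator, List, Tuple

# The four PACS domains.
PACS_DOMAINS = ["photo", "art_painting", "cartoon", "sketch"]


def leave_one_domain_out(domains: List[str]) -> Iterator[Tuple[List[str], str]]:
    """Yield (training domains, held-out test domain), one split per domain."""
    for held_out in domains:
        train = [d for d in domains if d != held_out]
        yield train, held_out


def average_accuracy(per_domain_acc: Dict[str, float]) -> float:
    """Unweighted mean of held-out accuracies, one entry per test domain."""
    return sum(per_domain_acc.values()) / len(per_domain_acc)


for train, test in leave_one_domain_out(PACS_DOMAINS):
    print(f"train on {train}, test on {test}")
```

Under this protocol, a per-domain figure such as "Art Accuracy" is the score from the single split whose held-out domain is art painting.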

Other info

Code