Domain Generalization using Causal Matching

About

In the domain generalization literature, a common objective is to learn representations that are independent of the domain after conditioning on the class label. We show that this objective is not sufficient: there exist counter-examples where a model fails to generalize to unseen domains even after satisfying class-conditional domain invariance. We formalize this observation through a structural causal model and show the importance of modeling within-class variations for generalization. Specifically, classes contain objects that characterize specific causal features, and domains can be interpreted as interventions on these objects that change non-causal features. We highlight an alternative condition: inputs across domains should have the same representation if they are derived from the same object. Based on this objective, we propose matching-based algorithms for the case where base objects are observed (e.g., through data augmentation), and an algorithm, MatchDG, that approximates the objective when objects are not observed. Our simple matching-based algorithms are competitive with prior work on out-of-domain accuracy for the Rotated MNIST, Fashion-MNIST, PACS, and Chest X-ray datasets. MatchDG also recovers ground-truth object matches: on MNIST and Fashion-MNIST, the top-10 matches from MatchDG have over 50% overlap with the ground-truth matches.

Divyat Mahajan, Shruti Tople, Amit Sharma• 2020
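The object-matching condition above can be illustrated with a minimal pure-Python sketch. It assumes representations are already computed and that each input is tagged with a (domain, object_id) pair, as when matches come from data augmentation. The loss combines a classification term with a penalty on cross-domain distances between views of the same object. The function names, the squared-L2 distance, the weight `lam`, and the definition used for top-k overlap are illustrative assumptions, not the authors' implementation.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]


def squared_l2(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cross_entropy(logits: Vector, label: int) -> float:
    """Softmax cross-entropy for one example, computed stably via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]


def object_match_penalty(reps: Dict[Tuple[str, int], Vector]) -> float:
    """Sum of squared L2 distances between representations of the same object
    seen in different domains. Keys are hypothetical (domain, object_id) pairs."""
    by_object: Dict[int, List[Tuple[str, Vector]]] = {}
    for (domain, obj), vec in reps.items():
        by_object.setdefault(obj, []).append((domain, vec))
    total = 0.0
    for views in by_object.values():
        for i in range(len(views)):
            for j in range(i + 1, len(views)):
                if views[i][0] != views[j][0]:  # only cross-domain pairs
                    total += squared_l2(views[i][1], views[j][1])
    return total


def matching_objective(logits_and_labels, reps, lam: float) -> float:
    """Mean classification loss plus lam times the object-matching penalty.
    `lam` is a hypothetical trade-off weight."""
    ce = sum(cross_entropy(z, y) for z, y in logits_and_labels) / len(logits_and_labels)
    return ce + lam * object_match_penalty(reps)


def topk_overlap(anchors: List[Vector], candidates: List[Vector],
                 true_match: List[int], k: int) -> float:
    """Fraction of anchors whose ground-truth match is among their k nearest
    candidates (an assumed definition of top-k overlap)."""
    hits = 0
    for i, a in enumerate(anchors):
        ranked = sorted(range(len(candidates)),
                        key=lambda j: squared_l2(a, candidates[j]))
        if true_match[i] in ranked[:k]:
            hits += 1
    return hits / len(anchors)


if __name__ == "__main__":
    # Two hypothetical domains (e.g. 0 and 15 degree rotations), two objects.
    reps = {
        ("rot0", 0): [1.0, 0.0], ("rot15", 0): [1.0, 0.5],
        ("rot0", 1): [0.0, 1.0], ("rot15", 1): [0.0, 1.0],
    }
    print(object_match_penalty(reps))  # 0.25
    print(matching_objective([([0.0, 0.0], 0)], reps, lam=1.0))
    print(topk_overlap([[1.0, 0.0], [0.0, 1.0]],
                       [[0.0, 1.0], [1.0, 0.1]], [1, 0], k=1))  # 1.0
```

Only object 0 contributes to the penalty here, because object 1 already has identical representations in both domains. Note that class-conditional invariance alone would not distinguish these cases: it only constrains how whole classes are distributed across domains, while the matching penalty ties together the specific inputs that share an object.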

Related benchmarks

Task | Dataset | Metric | Result (%) | Rank
Image Classification | PACS (test) | Average Accuracy | 70.46 | 271
Image Classification | PACS | Overall Average Accuracy | 84.57 | 241
Domain Generalization | PACS | -- | -- | 231
Domain Generalization | PACS (leave-one-domain-out) | Art Accuracy | 85.61 | 152
Object Recognition | PACS (leave-one-domain-out) | Acc (Art painting) | 85.61 | 112
Image Classification | PACS v1 (test) | Average Accuracy | 87.52 | 92
Image Classification | PACS (out-of-domain) | Overall Accuracy | 85.53 | 63
Node Classification | Cora Covariate shift (degree split) | OOD Accuracy | 62.91 | 50
Image Classification | Rotated-MNIST | Mean Accuracy | 97.8 | 40
Node Classification | WebKB university split Covariate shift | OOD Test Accuracy | 35.45 | 30

First 10 of 27 rows shown.

