
Invariance Principle Meets Information Bottleneck for Out-of-Distribution Generalization

About

The invariance principle from causality is at the heart of notable approaches such as invariant risk minimization (IRM) that seek to address out-of-distribution (OOD) generalization failures. Despite the promising theory, invariance principle-based approaches fail in common classification tasks, where invariant (causal) features capture all the information about the label. Are these failures due to the methods failing to capture the invariance? Or is the invariance principle itself insufficient? To answer these questions, we revisit the fundamental assumptions in linear regression tasks, where invariance-based approaches were shown to provably generalize OOD. In contrast to the linear regression tasks, we show that for linear classification tasks we need much stronger restrictions on the distribution shifts, or otherwise OOD generalization is impossible. Furthermore, even with appropriate restrictions on distribution shifts in place, we show that the invariance principle alone is insufficient. We prove that a form of the information bottleneck constraint along with invariance helps address key failures when invariant features capture all the information about the label and also retains the existing success when they do not. We propose an approach that incorporates both of these principles and demonstrate its effectiveness in several experiments.
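As a rough illustration of how an invariance penalty and an information bottleneck penalty can be combined in one objective, the sketch below adds two regularizers to the average per-environment risk. It is a minimal sketch under stated assumptions, not the authors' implementation: the invariance term is assumed to be an IRMv1-style squared gradient with respect to a dummy scale on the logits, the bottleneck term is assumed to be the variance of the learned features (a simple surrogate for an entropy constraint), and the names `ib_irm_objective`, `lam_irm`, and `lam_ib` are hypothetical.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce(logits, labels):
    """Mean binary cross-entropy for labels in {0, 1}."""
    total = 0.0
    for z, y in zip(logits, labels):
        # softplus(z) - y*z equals the usual cross-entropy on sigmoid(z).
        total += max(z, 0.0) + math.log1p(math.exp(-abs(z))) - y * z
    return total / len(logits)


def irmv1_penalty(logits, labels):
    """Assumed invariance term: squared d/dw of bce(w * logits) at w = 1."""
    grad = sum((sigmoid(z) - y) * z for z, y in zip(logits, labels)) / len(logits)
    return grad ** 2


def feature_variance(features):
    """Assumed bottleneck surrogate: mean per-dimension variance of features."""
    n, d = len(features), len(features[0])
    total = 0.0
    for j in range(d):
        col = [f[j] for f in features]
        mu = sum(col) / n
        total += sum((v - mu) ** 2 for v in col) / n
    return total / d


def ib_irm_objective(envs, lam_irm=1.0, lam_ib=0.1):
    """Average over environments of risk + invariance + bottleneck terms.

    envs: list of (features, logits, labels) tuples, one per environment.
    lam_irm, lam_ib: hypothetical penalty weights.
    """
    total = 0.0
    for features, logits, labels in envs:
        total += (bce(logits, labels)
                  + lam_irm * irmv1_penalty(logits, labels)
                  + lam_ib * feature_variance(features))
    return total / len(envs)


# Toy usage with two hand-written environments (illustrative values only).
env_a = ([[0.0, 1.0], [2.0, 1.0]], [1.5, -0.5], [1, 0])
env_b = ([[1.0, 1.0], [1.0, 3.0]], [0.2, -1.0], [1, 0])
objective = ib_irm_objective([env_a, env_b], lam_irm=1.0, lam_ib=0.1)
```

Setting `lam_ib=0` leaves only the invariance-penalized risk, which makes it easy to compare the two settings on the same data.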

Kartik Ahuja, Ethan Caballero, Dinghuai Zhang, Jean-Christophe Gagnon-Audet, Yoshua Bengio, Ioannis Mitliagkas, Irina Rish • 2021

Related benchmarks

Task                  Dataset                                    Result                              Rank
Graph Classification  Twitter                                    Accuracy 61.26                      57
Image Classification  Terra Incognita (TerraInc)                 Accuracy 56.4                       46
Regression            PovertyMap (test)                          Worst-U/R Pearson Correlation 0.43  43
Graph Classification  DrugOOD Ki-Sca (Scaffold-based OOD shift)  ROC-AUC 69.55                       36
Graph Classification  DrugOOD EC50 (Scaffold-based OOD shift)    ROC-AUC 62.62                       36
Regression            ACSIncome (test)                           RMSE 0.438                          34
Regression            RCF-MNIST                                  RMSE (Avg) 0.167                    24
Graph Classification  PROTEINS size shift (test)                 MCC 0.21                            17
Graph Classification  NCI109 size shift (test)                   MCC 0.15                            17
Graph Classification  DD size shift (test)                       MCC 0.15                            17
Showing 10 of 63 rows
