
Chroma-VAE: Mitigating Shortcut Learning with Generative Classifiers

About

Deep neural networks are susceptible to shortcut learning, using simple features to achieve low training loss without discovering essential semantic structure. Contrary to prior belief, we show that generative models alone are not sufficient to prevent shortcut learning, despite an incentive to recover a more comprehensive representation of the data than discriminative approaches. However, we observe that shortcuts are preferentially encoded with minimal information, a fact that generative models can exploit to mitigate shortcut learning. In particular, we propose Chroma-VAE, a two-pronged approach where a VAE classifier is initially trained to isolate the shortcut in a small latent subspace, allowing a secondary classifier to be trained on the complementary, shortcut-free latent subspace. In addition to demonstrating the efficacy of Chroma-VAE on benchmark and real-world shortcut learning tasks, our work highlights the potential for manipulating the latent space of generative classifiers to isolate or interpret specific correlations.
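The two-stage idea in the abstract can be illustrated with a small sketch. It assumes a trained VAE classifier has already produced latent codes, and that the first `shortcut_dims` coordinates are the small subspace holding the shortcut. The names `split_latent`, `fit_centroids`, `predict`, and `shortcut_dims` are hypothetical, and the nearest-centroid rule is only a stand-in for whatever secondary classifier is actually used. The encoder itself is not shown.

```python
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]


def split_latent(z: Vector, shortcut_dims: int) -> Tuple[List[float], List[float]]:
    """Partition a latent code into (shortcut part, complementary part).

    Assumption: the first `shortcut_dims` coordinates are the small
    subspace that the first-stage classifier was trained on.
    """
    if not 0 < shortcut_dims < len(z):
        raise ValueError("shortcut_dims must leave both subspaces non-empty")
    return list(z[:shortcut_dims]), list(z[shortcut_dims:])


def fit_centroids(
    latents: Sequence[Vector], labels: Sequence[int], shortcut_dims: int
) -> Dict[int, List[float]]:
    """Fit a nearest-centroid classifier on the complementary subspace only."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for z, y in zip(latents, labels):
        _, z_comp = split_latent(z, shortcut_dims)
        if y not in sums:
            sums[y] = [0.0] * len(z_comp)
            counts[y] = 0
        sums[y] = [s + v for s, v in zip(sums[y], z_comp)]
        counts[y] += 1
    # Average the accumulated coordinates to get one centroid per class.
    return {y: [s / counts[y] for s in sums[y]] for y in sums}


def predict(centroids: Dict[int, List[float]], z: Vector, shortcut_dims: int) -> int:
    """Predict the class whose centroid is closest in the complementary subspace."""
    _, z_comp = split_latent(z, shortcut_dims)

    def squared_distance(center: List[float]) -> float:
        return sum((a - b) ** 2 for a, b in zip(z_comp, center))

    return min(centroids, key=lambda y: squared_distance(centroids[y]))


# Toy latent codes (hand-written, not real model outputs): coordinate 0 plays
# the role of the shortcut and agrees with the label in the training set.
train_latents = [
    [5.0, 1.0, 1.0],
    [5.0, 0.9, 1.1],
    [-5.0, -1.0, -1.0],
    [-5.0, -1.1, -0.9],
]
train_labels = [1, 1, 0, 0]
centroids = fit_centroids(train_latents, train_labels, shortcut_dims=1)

# At test time the shortcut coordinate is flipped, but the remaining
# coordinates still look like class 1.
shifted_example = [-5.0, 1.0, 1.0]
prediction = predict(centroids, shifted_example, shortcut_dims=1)
```

Because the secondary classifier never reads coordinate 0, flipping the shortcut in `shifted_example` does not change its prediction in this toy setup.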

Wanqian Yang, Polina Kirichenko, Micah Goldblum, Andrew Gordon Wilson • 2022

Related benchmarks

Task | Dataset | Metric | Result | Rank
Blond Hair classification | CelebA (test) | Average Group Accuracy | 82 | 30
Image Classification | ColoredMNIST 1.0 (Dout) | Accuracy | 72.4 | 8
Image Classification | CelebA Synthetic Patch Adversarial Distribution shortcut (blond hair) (test) | Accuracy | 87.24 | 8
Image Classification | ColoredMNIST 1.0 (Din) | Accuracy | 89 | 8
Image Classification | CelebA Synthetic Patch Neutral Distribution (Dneut) shortcut (blond hair) (test) | Accuracy | 92.33 | 8
Image Classification | CelebA Synthetic Patch Training Distribution (Dtr) (test) | Accuracy | 99.64 | 8
Image Classification | Chest X-Ray (test) | Average Accuracy | 59.9 | 7
Image Classification | CelebA Attractive Smiling (test) | Accuracy | 66.9 | 3
Image Classification | MF-Dominoes (test) | Avg Accuracy | 78.5 | 3
