
MaskTune: Mitigating Spurious Correlations by Forcing to Explore

About

A fundamental challenge of over-parameterized deep learning models is learning meaningful data representations that yield good performance on a downstream task without over-fitting spurious input features. This work proposes MaskTune, a masking strategy that prevents over-reliance on spurious (or a limited number of) features. MaskTune forces the trained model to explore new features during a single epoch of finetuning by masking previously discovered features. Unlike earlier approaches for mitigating shortcut learning, MaskTune does not require any supervision, such as annotations of spurious features or subgroup labels for samples in a dataset. Our empirical results on biased MNIST, CelebA, Waterbirds, and ImageNet-9L datasets show that MaskTune is effective on tasks that often suffer from spurious correlations. Finally, we show that MaskTune outperforms or matches competing methods when applied to the selective classification (classification with rejection option) task. Code for MaskTune is available at https://github.com/aliasgharkhani/Masktune.
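As a rough illustration of the mask-then-finetune idea in the abstract, here is a minimal pure-Python sketch on a toy logistic-regression model. It is not the authors' implementation (that lives in the linked repository). Several parts are assumptions chosen for illustration: the saliency rule (gradient times input, |w_j * x_j|), the masking threshold (per-sample mean + k * std of the saliency scores), and all function names.

```python
import math
import statistics
from typing import List, Sequence, Tuple

Vector = List[float]
Sample = Tuple[Vector, int]


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(w: Sequence[float], b: float, x: Sequence[float]) -> float:
    # Probability of class 1 under a linear logistic model.
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def saliency(w: Sequence[float], x: Sequence[float]) -> Vector:
    # Gradient-times-input attribution: for a linear model the logit gradient
    # w.r.t. x_j is w_j, so the score is |w_j * x_j| (assumed stand-in for an
    # image saliency method).
    return [abs(wi * xi) for wi, xi in zip(w, x)]


def mask_input(w: Sequence[float], x: Sequence[float], k: float = 1.0) -> Vector:
    # Zero out features whose saliency exceeds mean + k * std of this
    # sample's scores (hypothetical threshold rule).
    s = saliency(w, x)
    threshold = statistics.fmean(s) + k * statistics.pstdev(s)
    return [0.0 if sj > threshold else xj for sj, xj in zip(s, x)]


def masktune(w: Sequence[float], b: float, data: Sequence[Sample],
             lr: float = 0.1, k: float = 1.0) -> Tuple[Vector, float]:
    # Step 1: build the masks once, from the already-trained model.
    masked = [(mask_input(w, x, k), y) for x, y in data]
    # Step 2: a single epoch of SGD on the masked inputs (log-loss gradient).
    w_new, b_new = list(w), b
    for x, y in masked:
        err = predict(w_new, b_new, x) - y
        w_new = [wi - lr * err * xi for wi, xi in zip(w_new, x)]
        b_new -= lr * err
    return w_new, b_new


if __name__ == "__main__":
    # Toy model that leans heavily on feature 0 (playing the "spurious" feature).
    w0 = [5.0, 0.5, 0.5, 0.5]
    print(mask_input(w0, [1.0, 1.0, 1.0, 1.0]))  # [0.0, 1.0, 1.0, 1.0]
    w1, b1 = masktune(w0, 0.0, [([1.0] * 4, 1), ([-1.0] * 4, 0)])
    print(w1, b1)
```

Masked features are set to zero, so in this linear toy the finetuning epoch gives their weights no gradient. Any correction has to come from the remaining features. In a deep network the effect is less direct, because masking changes the input seen by every layer.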

Saeid Asgari Taghanaki, Aliasghar Khani, Fereshte Khani, Ali Gholami, Linh Tran, Ali Mahdavi-Amiri, Ghassan Hamarneh• 2022

Related benchmarks

| Task | Dataset | Metric | Result (%) | Rank |
| --- | --- | --- | --- | --- |
| Image Classification | Waterbirds | Average Accuracy | 93 | 209 |
| Image Classification | Waterbirds (test) | Worst-Group Accuracy | 86.4 | 127 |
| Classification | CelebA (test) | Average Accuracy | 91.3 | 92 |
| Image Classification | ISIC (test) | -- | -- | 24 |
| Image Classification | CelebA WILDS (test) | I.I.D. Accuracy | 91.3 | 19 |
| Image Classification | Waterbirds 95% (test) | Worst-Group Accuracy | 66.6 | 18 |
| Classification | ImageNet-9 Backgrounds Challenge | Accuracy (Original IN-9) | 95.6 | 17 |
| Image Classification | Knee (test) | Worst-Group Accuracy (WGA) | 54.3 | 16 |
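Several benchmark rows report worst-group accuracy (WGA) alongside average accuracy. The sketch below shows how the two metrics differ. It assumes each test sample carries a group label formed from its class and a spurious attribute (for Waterbirds, bird type × background). The function names and the toy predictions are hypothetical and are not data from the paper.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence


def average_accuracy(preds: Sequence[int], labels: Sequence[int]) -> float:
    # Fraction of all samples predicted correctly, regardless of group.
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def group_accuracies(preds: Sequence[int], labels: Sequence[int],
                     groups: Sequence[Hashable]) -> Dict[Hashable, float]:
    # Per-group accuracy, where a group is e.g. a (class, spurious attribute) pair.
    correct: Dict[Hashable, int] = defaultdict(int)
    total: Dict[Hashable, int] = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(preds: Sequence[int], labels: Sequence[int],
                         groups: Sequence[Hashable]) -> float:
    # Accuracy on the group where the model does worst.
    return min(group_accuracies(preds, labels, groups).values())


if __name__ == "__main__":
    # Hypothetical toy predictions; groups are (bird type, background).
    groups = ([("waterbird", "water")] * 4 + [("landbird", "land")] * 4
              + [("waterbird", "land")] * 2)
    labels = [1] * 4 + [0] * 4 + [1] * 2
    preds = [1] * 4 + [0] * 4 + [1, 0]
    print(average_accuracy(preds, labels))              # 0.9
    print(worst_group_accuracy(preds, labels, groups))  # 0.5
```

A model can reach high average accuracy while performing poorly on a small minority group, such as waterbirds on land backgrounds. WGA exposes that gap.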
