Spread Spurious Attribute: Improving Worst-group Accuracy with Spurious Attribute Estimation

About

The paradigm of worst-group loss minimization has shown promise in avoiding learning spurious correlations, but it requires costly additional supervision on spurious attributes. To resolve this, recent works focus on developing weaker forms of supervision (e.g., hyperparameters tuned on a small number of validation samples with spurious attribute annotations), but none of these methods retains performance comparable to methods that use full supervision on the spurious attribute. In this paper, instead of searching for weaker supervision, we ask: given access to a fixed number of samples with spurious attribute annotations, what is the best achievable worst-group loss if we "fully exploit" them? To this end, we propose a pseudo-attribute-based algorithm, coined Spread Spurious Attribute (SSA), for improving worst-group accuracy. In particular, we leverage samples both with and without spurious attribute annotations to train a model that predicts the spurious attribute, and then use the pseudo-attributes predicted by this model as supervision on the spurious attribute to train a new robust model with minimal worst-group loss. Our experiments on various benchmark datasets show that our algorithm consistently outperforms baseline methods that use the same number of validation samples with spurious attribute annotations. We also demonstrate that SSA achieves performance comparable to methods using full (100%) spurious attribute supervision while using a much smaller number of annotated samples (from 0.6% up to 1.5%, depending on the dataset).
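The two-stage procedure described above can be sketched as follows. This is a minimal, hypothetical illustration in plain Python, not the authors' implementation: the helper names (`train_logreg`, `spread_attribute`, `worst_group_loss`), the logistic-regression attribute predictor, and the confidence-thresholded self-training loop are all assumptions standing in for the paper's actual networks and semi-supervised training. The second stage is shown only as the objective it minimizes; training the robust model itself is omitted.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict_proba(w: Vector, x: Vector) -> float:
    # w holds one weight per feature plus a trailing bias term.
    return sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + w[-1])


def train_logreg(xs: Sequence[Vector], ys: Sequence[int],
                 epochs: int = 200, lr: float = 0.5) -> Vector:
    # Full-batch gradient descent on the logistic (cross-entropy) loss.
    dim = len(xs[0])
    w = [0.0] * (dim + 1)
    for _ in range(epochs):
        grad = [0.0] * (dim + 1)
        for x, y in zip(xs, ys):
            err = predict_proba(w, x) - y
            for j in range(dim):
                grad[j] += err * x[j]
            grad[dim] += err
        w = [wj - lr * gj / len(xs) for wj, gj in zip(w, grad)]
    return w


def spread_attribute(labeled: Sequence[Tuple[Vector, int]],
                     unlabeled: Sequence[Vector],
                     threshold: float = 0.9, rounds: int = 3) -> List[int]:
    # Stage 1 (hypothetical stand-in): fit an attribute predictor on the few
    # annotated samples, then repeatedly retrain it after adding unannotated
    # samples whose predicted attribute is confident (plain self-training).
    xs = [x for x, _ in labeled]
    attrs = [a for _, a in labeled]
    w = train_logreg(xs, attrs)
    for _ in range(rounds):
        extra_x, extra_a = [], []
        for x in unlabeled:
            p = predict_proba(w, x)
            if p >= threshold or p <= 1.0 - threshold:
                extra_x.append(x)
                extra_a.append(int(p >= 0.5))
        w = train_logreg(xs + extra_x, attrs + extra_a)
    # Pseudo-attribute for every unannotated sample.
    return [int(predict_proba(w, x) >= 0.5) for x in unlabeled]


def worst_group_loss(losses: Sequence[float], labels: Sequence[int],
                     attrs: Sequence[int]) -> float:
    # Stage 2 objective: groups are (label, pseudo-attribute) pairs, and the
    # robust model should minimize the largest per-group mean loss.
    sums: Dict[Tuple[int, int], float] = {}
    counts: Dict[Tuple[int, int], int] = {}
    for loss, y, a in zip(losses, labels, attrs):
        sums[(y, a)] = sums.get((y, a), 0.0) + loss
        counts[(y, a)] = counts.get((y, a), 0) + 1
    return max(sums[g] / counts[g] for g in sums)


# Toy data: the spurious attribute is determined by the sign of feature 0.
labeled = [([2.0, 0.1], 1), ([-2.0, -0.3], 0), ([1.5, 0.2], 1), ([-1.8, 0.4], 0)]
unlabeled = [[2.0, 0.0], [-2.0, 0.0], [1.5, 0.0], [-1.5, 0.0]]
print(spread_attribute(labeled, unlabeled))  # [1, 0, 1, 0]
print(worst_group_loss([0.25, 0.25, 1.0, 0.5], [0, 0, 1, 1], [0, 0, 1, 1]))  # 0.75
```

Because groups are formed from predicted rather than annotated attributes, the worst-group objective is only as reliable as the pseudo-attributes; in this sketch the confidence `threshold` trades the number of pseudo-labeled samples against their reliability. A robust training method such as group DRO could then minimize `worst_group_loss` over these pseudo-groups.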

Junhyun Nam, Jaehyung Kim, Jaeho Lee, Jinwoo Shin • 2022

Related benchmarks

| Task | Dataset | Metric | Result (%) | Rank |
|---|---|---|---|---|
| Attribute Classification | CelebA (test) | Worst-group Accuracy | 89.8 | 48 |
| Object Classification | Waterbirds (test) | Worst-group Accuracy | 89 | 22 |
| Natural Language Inference | MultiNLI (test) | Accuracy | 80.4 | 21 |
| Toxicity Detection | CivilComments-WILDS (test) | Average Accuracy | 89.1 | 19 |
| Image Classification | CelebA original unshifted | Worst Accuracy | 89.8 | 10 |
| Image Classification | Waterbirds original unshifted | Worst Accuracy | 89 | 10 |
| Classification | CelebA HairColor (test) | In-Distribution Accuracy | 92.8 | 9 |
| Image Classification | CMNIST original unshifted | Worst Accuracy | 71.1 | 9 |
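For reference, worst-group accuracy, the main metric in the benchmark list above, is conventionally the lowest accuracy among groups defined by (class label, spurious attribute) pairs, for example (bird type, background) in Waterbirds, whereas average accuracy is taken over all samples. The sketch below is a minimal plain-Python illustration; the function names and toy data are hypothetical, and the reported scores are presumably these fractions expressed as percentages.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence


def group_accuracies(preds: Sequence[int], labels: Sequence[int],
                     groups: Sequence[Hashable]) -> Dict[Hashable, float]:
    # Accuracy computed separately within each group.
    correct: Dict[Hashable, int] = defaultdict(int)
    total: Dict[Hashable, int] = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(preds: Sequence[int], labels: Sequence[int],
                         groups: Sequence[Hashable]) -> float:
    # Minimum of the per-group accuracies.
    return min(group_accuracies(preds, labels, groups).values())


def average_accuracy(preds: Sequence[int], labels: Sequence[int]) -> float:
    # Ordinary accuracy over all samples, ignoring group membership.
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)


# Toy Waterbirds-style groups: (bird label, background attribute).
preds = [1, 1, 0, 0, 1, 0]
labels = [1, 1, 0, 0, 0, 0]
groups = [(1, 1), (1, 1), (0, 0), (0, 0), (0, 1), (0, 1)]
print(worst_group_accuracy(preds, labels, groups))  # 0.5
print(average_accuracy(preds, labels))  # 0.8333333333333334
```

The toy example shows why the two metrics are reported separately: a model can score well on average while failing on a minority group whose spurious attribute disagrees with the majority pattern.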
