
Simple and Fast Group Robustness by Automatic Feature Reweighting

About

A major challenge to out-of-distribution generalization is reliance on spurious features: patterns that are predictive of the class label in the training distribution but not causally related to the target. Standard methods for reducing reliance on spurious features typically assume that the spurious feature is known, which is rarely the case in the real world. Methods that attempt to lift this assumption are complex, hard to tune, and incur significant computational overhead compared to standard training. In this paper, we propose Automatic Feature Reweighting (AFR), an extremely simple and fast method for updating the model to reduce its reliance on spurious features. AFR retrains the last layer of a standard ERM-trained base model with a weighted loss that emphasizes the examples on which the ERM model predicts poorly, automatically upweighting the minority group without using group labels. With this simple procedure, we improve upon the best reported results among competing methods trained without spurious-attribute labels on several vision and natural language classification benchmarks, using only a fraction of their compute.
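The reweighting step described above can be illustrated with a minimal, dependency-free sketch. This is not the authors' implementation; it rests on stated assumptions. It assumes per-example weights of the form w_i ∝ exp(-gamma * p_i), where p_i is the frozen ERM model's predicted probability of the true label, so that poorly predicted examples receive larger weight. It also assumes an optional L2 penalty of strength `lam` that pulls the retrained head back toward the ERM head. The names `afr_weights`, `retrain_last_layer`, `head_probs`, `gamma`, and `lam` are hypothetical. `feats` stands in for the base model's penultimate-layer embeddings, computed once and kept fixed, so only the linear head is trained.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def head_probs(W, b, x):
    """Class probabilities of a linear head (W: classes x dim, b: classes) on features x."""
    logits = [sum(w * xj for w, xj in zip(row, x)) + bk for row, bk in zip(W, b)]
    return softmax(logits)


def afr_weights(erm_probs, labels, gamma):
    """Assumed weighting: w_i proportional to exp(-gamma * p_i), where p_i is the
    ERM model's probability on the true label. Normalized to sum to 1."""
    raw = [math.exp(-gamma * p[y]) for p, y in zip(erm_probs, labels)]
    total = sum(raw)
    return [r / total for r in raw]


def retrain_last_layer(feats, labels, weights, W0, b0, lam=0.0, lr=0.5, steps=200):
    """Full-batch gradient descent on
        sum_i w_i * CE(head(x_i), y_i) + lam * ||W - W0||^2,
    starting from the ERM head (W0, b0). Only the linear head is updated;
    feats are fixed, precomputed embeddings. The bias is not regularized here."""
    W = [row[:] for row in W0]
    b = list(b0)
    n_cls, dim = len(W), len(W[0])
    for _ in range(steps):
        # Gradient of the L2 penalty that keeps W close to the ERM head.
        grad_W = [[2.0 * lam * (W[k][j] - W0[k][j]) for j in range(dim)] for k in range(n_cls)]
        grad_b = [0.0] * n_cls
        for x, y, w in zip(feats, labels, weights):
            p = head_probs(W, b, x)
            for k in range(n_cls):
                # d(w * CE)/d(logit_k) = w * (p_k - 1[k == y])
                err = w * (p[k] - (1.0 if k == y else 0.0))
                grad_b[k] += err
                for j in range(dim):
                    grad_W[k][j] += err * x[j]
        W = [[W[k][j] - lr * grad_W[k][j] for j in range(dim)] for k in range(n_cls)]
        b = [b[k] - lr * grad_b[k] for k in range(n_cls)]
    return W, b


if __name__ == "__main__":
    # Hypothetical ERM outputs: the third example is misclassified (p_true = 0.3).
    erm_probs = [[0.95, 0.05], [0.9, 0.1], [0.3, 0.7]]
    labels = [0, 0, 0]
    w = afr_weights(erm_probs, labels, gamma=2.0)
    # The misclassified example receives the largest weight.
    print([round(v, 3) for v in w])
```

Because the embeddings are cached and only a linear head is optimized, each step costs on the order of n * dim * n_classes operations. That matches the abstract's point that AFR needs only a fraction of the compute of methods that retrain the full network.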

Shikai Qiu, Andres Potapczynski, Pavel Izmailov, Andrew Gordon Wilson · 2023

Related benchmarks

Task | Dataset | Metric | Result | Rank
--- | --- | --- | --- | ---
Image Classification | Waterbirds (test) | Worst-group accuracy | 90.4 | 127
Classification | CelebA (test) | Average accuracy | 91.3 | 92
Classification | CivilComments (test) | Average accuracy | 89.8 | 51
Group Robustness | CivilComments-WILDS (test) | Worst-group accuracy | 38.4 | 40
Gender Classification | COCO 95% spurious correlation | Average score | 73.3 | 24
Image Classification | Waterbirds 95% correlation (test) | Worst-group accuracy | 88.9 | 23
Group Robustness | CheXpert (test) | Worst-group accuracy | 63.9 | 22
Image Classification | CelebA WILDS (test) | I.I.D. accuracy | 91.3 | 19
Text Classification | MultiNLI (test) | Worst-group accuracy | 73.4 | 18
Text Classification | CivilComments-WILDS (test) | Accuracy | 89.8 | 13
The table lists 10 of the 11 related benchmark results.
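Several rows report worst-group accuracy (also written WG accuracy or WGA): accuracy computed separately within each group, typically a (class label, spurious attribute) pair, with the minimum taken over groups. The sketch below assumes that per-example group IDs are available at evaluation time. `worst_group_accuracy` is a hypothetical helper name, and the toy predictions are made up for illustration.

```python
from collections import defaultdict


def worst_group_accuracy(preds, labels, groups):
    """Return (worst-group accuracy, per-group accuracy dict).

    groups[i] is any hashable group ID for example i, e.g. a
    (label, spurious_attribute) pair.
    """
    correct = defaultdict(int)
    total = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    per_group = {g: correct[g] / total[g] for g in total}
    return min(per_group.values()), per_group


if __name__ == "__main__":
    # Hypothetical Waterbirds-style toy data: groups are (label, background) pairs,
    # with label 0 = landbird and 1 = waterbird.
    preds = [0, 0, 1, 1, 1, 0, 1, 0]
    labels = [0, 0, 1, 1, 1, 1, 0, 0]
    groups = [(0, "land"), (0, "land"), (1, "water"), (1, "water"),
              (1, "land"), (1, "land"), (0, "water"), (0, "water")]
    worst, per_group = worst_group_accuracy(preds, labels, groups)
    average = sum(p == y for p, y in zip(preds, labels)) / len(labels)
    print(f"average accuracy = {average:.2f}, worst-group accuracy = {worst:.2f}")
    # prints: average accuracy = 0.75, worst-group accuracy = 0.50
```

Average accuracy can remain high while a small group is largely misclassified. For this reason, the group-robustness rows report the minimum over groups rather than the mean.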
