
Towards Last-layer Retraining for Group Robustness with Fewer Annotations

About

Empirical risk minimization (ERM) of neural networks is prone to over-reliance on spurious correlations and poor generalization on minority groups. The recent deep feature reweighting (DFR) technique achieves state-of-the-art group robustness via simple last-layer retraining, but it requires held-out group and class annotations to construct a group-balanced reweighting dataset. In this work, we examine this impractical requirement and find that last-layer retraining can be surprisingly effective with no group annotations (other than for model selection) and only a handful of class annotations. We first show that last-layer retraining can greatly improve worst-group accuracy even when the reweighting dataset has only a small proportion of worst-group data. This implies a "free lunch" where holding out a subset of training data to retrain the last layer can substantially outperform ERM on the entire dataset with no additional data or annotations. To further improve group robustness, we introduce a lightweight method called selective last-layer finetuning (SELF), which constructs the reweighting dataset using misclassifications or disagreements. Our empirical and theoretical results present the first evidence that model disagreement upsamples worst-group data, enabling SELF to nearly match DFR on four well-established benchmarks across vision and language tasks with no group annotations and less than 3% of the held-out class annotations. Our code is available at https://github.com/tmlabonte/last-layer-retraining.
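The selection-then-retrain idea behind SELF can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation from the linked repository. It assumes binary classification, precomputed (frozen) penultimate-layer features, and predicted classes from two models on the held-out set. Which two models are compared is an assumption here. All function names are hypothetical, and plain logistic regression trained by gradient descent stands in for the retrained last layer.

```python
# Hypothetical sketch of disagreement-based last-layer retraining.
# Assumptions: binary labels (0/1), frozen features given as lists of floats,
# and predicted classes from two models on the same held-out points.
import math
from collections import defaultdict


def disagreement_indices(preds_a, preds_b):
    """Indices of held-out points where two models predict different classes.

    The selection step itself reads no labels.
    """
    return [i for i, (a, b) in enumerate(zip(preds_a, preds_b)) if a != b]


def class_balanced(indices, labels, per_class):
    """Keep at most `per_class` selected points per class, in index order.

    Only the labels of selected points are read, which mimics annotating a
    small subset rather than the whole held-out set.
    """
    by_class = defaultdict(list)
    for i in indices:
        if len(by_class[labels[i]]) < per_class:
            by_class[labels[i]].append(i)
    return sorted(i for idxs in by_class.values() for i in idxs)


def _sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_last_layer(feats, ys, epochs=500, lr=0.5):
    """Binary logistic regression on frozen features (stand-in for the last layer)."""
    dim = len(feats[0])
    w, b, n = [0.0] * dim, 0.0, len(feats)
    for _ in range(epochs):
        gw, gb = [0.0] * dim, 0.0
        for x, y in zip(feats, ys):
            err = _sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + b) - y
            for j in range(dim):
                gw[j] += err * x[j]
            gb += err
        # Full-batch gradient step on the mean cross-entropy loss.
        w = [wj - lr * g / n for wj, g in zip(w, gw)]
        b -= lr * gb / n
    return w, b


def predict(w, b, x):
    # Class 1 if the linear score is positive, else class 0.
    return int(sum(wj * xj for wj, xj in zip(w, x)) + b > 0)


def self_retrain(feats, preds_a, preds_b, labels, per_class):
    """Select disagreements, class-balance them, and refit only the last layer."""
    chosen = class_balanced(disagreement_indices(preds_a, preds_b), labels, per_class)
    w, b = train_last_layer([feats[i] for i in chosen], [labels[i] for i in chosen])
    return chosen, w, b
```

Because disagreement selection needs no labels, class annotations are only read for the small balanced subset that is kept. A misclassification-based variant, which the abstract also mentions, would instead select points where one model's prediction differs from the held-out class label.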

Tyler LaBonte, Vidya Muthukumar, Abhishek Kumar • 2023

Related benchmarks

Task | Dataset | Result | Rank
Image Classification | Waterbirds (test) | Worst-Group Accuracy: 93 | 92
Classification | CelebA (test) | -- | 92
Classification | CivilComments (test) | Worst-case Accuracy: 80.4 | 47
Natural Language Inference | MultiNLI (test) | -- | 21

Other info

Code
