
Just Train Twice: Improving Group Robustness without Training Group Information

About

Standard training via empirical risk minimization (ERM) can produce models that achieve high accuracy on average but low accuracy on certain groups, especially in the presence of spurious correlations between the input and label. Prior approaches that achieve high worst-group accuracy, like group distributionally robust optimization (group DRO), require expensive group annotations for each training point, whereas approaches that do not use such group annotations typically achieve unsatisfactory worst-group accuracy. In this paper, we propose a simple two-stage approach, JTT, that first trains a standard ERM model for several epochs, and then trains a second model that upweights the training examples that the first model misclassified. Intuitively, this upweights examples from groups on which standard ERM models perform poorly, leading to improved worst-group performance. Averaged over four image classification and natural language processing tasks with spurious correlations, JTT closes 75% of the gap in worst-group accuracy between standard ERM and group DRO, while only requiring group annotations on a small validation set in order to tune hyperparameters.

Evan Zheran Liu, Behzad Haghgoo, Annie S. Chen, Aditi Raghunathan, Pang Wei Koh, Shiori Sagawa, Percy Liang, Chelsea Finn • 2021
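
The two-stage procedure in the abstract can be illustrated with a small NumPy sketch. This is not the authors' implementation. It assumes a logistic-regression model trained by full-batch gradient descent in place of the deep networks used in the paper, and it implements upweighting as a per-example loss weight (upsampling the error set would be an alternative). It also uses hypothetical toy data in which a clean spurious attribute agrees with the label 90% of the time. The names `train_logreg`, `just_train_twice`, `t_epochs`, `lambda_up` and `final_epochs` are illustrative. The hyperparameter values are fixed by hand here, whereas the abstract describes tuning hyperparameters with group labels on a small validation set.

```python
import numpy as np


def train_logreg(X, y, weights, epochs, lr=0.1):
    """Full-batch gradient descent on a per-example weighted logistic loss.

    Stand-in model (assumption): the paper trains deep networks, not logistic regression.
    """
    w = np.zeros(X.shape[1])
    b = 0.0
    total = weights.sum()
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # predicted P(y = 1)
        g = weights * (p - y) / total           # weighted dLoss/dlogit per example
        w -= lr * (X.T @ g)
        b -= lr * g.sum()
    return w, b


def predict(params, X):
    w, b = params
    return (X @ w + b > 0).astype(int)


def just_train_twice(X, y, t_epochs, lambda_up, final_epochs):
    """Two-stage JTT sketch with hand-picked (hypothetical) hyperparameters."""
    # Stage 1: standard ERM (uniform weights), stopped after only t_epochs.
    erm = train_logreg(X, y, np.ones(len(y)), epochs=t_epochs)
    # Error set: training examples the stage-1 model misclassifies.
    error_set = predict(erm, X) != y
    # Stage 2: train a fresh model with the error set upweighted by lambda_up.
    weights = np.where(error_set, lambda_up, 1.0)
    final = train_logreg(X, y, weights, epochs=final_epochs)
    return final, error_set


# Hypothetical toy data: label y, spurious attribute a that matches y 90% of the time.
rng = np.random.default_rng(0)
n = 2000
y = rng.integers(0, 2, n)
a = np.where(rng.random(n) < 0.9, y, 1 - y)
core = (2 * y - 1) + rng.normal(0.0, 1.0, n)        # noisy but genuinely predictive
spur = 3.0 * (2 * a - 1) + rng.normal(0.0, 0.1, n)  # clean but only spuriously predictive
X = np.column_stack([core, spur])

final, error_set = just_train_twice(X, y, t_epochs=5, lambda_up=6.0, final_epochs=1000)
minority = a != y  # groups where label and spurious attribute disagree
print("error set size:", error_set.sum())
print("fraction of error set from minority groups:", minority[error_set].mean())
```

Because the briefly trained stage-1 model leans on the spurious feature, its mistakes should concentrate in the minority groups, so upweighting the error set acts as a rough proxy for the missing training group labels.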

Related benchmarks

Task | Dataset | Metric | Result | Rank
--- | --- | --- | --- | ---
Natural Language Inference | SNLI (test) | Accuracy | 76.25 | 681
Classification | CelebA | Average Accuracy | 88 | 137
Sentiment Analysis | SST-2 (test) | Accuracy | 80.82 | 136
Natural Language Inference | MNLI (matched) | Accuracy | 52 | 110
Classification | CelebA (test) | Average Accuracy | 95.9 | 92
Image Classification | Waterbirds (test) | Worst-Group Accuracy | 86.7 | 92
Image Classification | Waterbirds | Worst-Group Accuracy | 86 | 79
Natural Language Inference | SNLI (dev) | Accuracy | 76.96 | 71
Image Classification | CMNIST (test) | Test Accuracy | 85.03 | 55
Attribute Classification | CelebA (test) | Worst-Group Accuracy | 81.1 | 48

Showing 10 of 79 rows.
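
The table reports both average and worst-group accuracy. In this line of work, groups are typically combinations of the label and a spurious attribute (for example, bird type and background in Waterbirds). The sketch below shows the difference between the two metrics. It assumes per-example integer group ids are available at evaluation time and treats "average accuracy" as plain accuracy over all examples; some benchmarks reweight groups instead. The helper names and the example arrays are hypothetical.

```python
import numpy as np


def group_accuracies(y_true, y_pred, groups):
    """Accuracy computed separately within each group id."""
    return {g: float(np.mean(y_pred[groups == g] == y_true[groups == g]))
            for g in np.unique(groups)}


def worst_group_accuracy(y_true, y_pred, groups):
    """Minimum of the per-group accuracies."""
    return min(group_accuracies(y_true, y_pred, groups).values())


def average_accuracy(y_true, y_pred):
    """Plain accuracy over all examples (assumption: no group reweighting)."""
    return float(np.mean(y_true == y_pred))


# Hypothetical predictions: the large group 0 is all correct, the small group 1 mostly wrong.
y_true = np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0])
y_pred = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 1])
groups = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])

print(average_accuracy(y_true, y_pred))              # 0.7
print(worst_group_accuracy(y_true, y_pred, groups))  # 0.25
```

A model can therefore look reasonable on average while failing on a small group, which is the gap that worst-group accuracy exposes.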