
Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization

About

Overparameterized neural networks can be highly accurate on average on an i.i.d. test set yet consistently fail on atypical groups of the data (e.g., by learning spurious correlations that hold on average but not in such groups). Distributionally robust optimization (DRO) allows us to learn models that instead minimize the worst-case training loss over a set of pre-defined groups. However, we find that naively applying group DRO to overparameterized neural networks fails: these models can perfectly fit the training data, and any model with vanishing average training loss also already has vanishing worst-case training loss. Instead, the poor worst-case performance arises from poor generalization on some groups. By coupling group DRO models with increased regularization (a stronger-than-typical L2 penalty or early stopping), we achieve substantially higher worst-group accuracies, with 10-40 percentage point improvements on a natural language inference task and two image tasks, while maintaining high average accuracies. Our results suggest that regularization is important for worst-group generalization in the overparameterized regime, even if it is not needed for average generalization. Finally, we introduce a stochastic optimization algorithm, with convergence guarantees, to efficiently train group DRO models.
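The abstract does not spell out the training algorithm, so the following is only a minimal sketch of how a group DRO objective (minimize over the weights, maximize over a distribution q on groups) can be optimized. It assumes alternating updates: an exponentiated-gradient ascent step on the group weights q, then a gradient descent step on the q-weighted group losses plus an L2 penalty. The model is a tiny logistic regression in pure Python. The functions `group_dro_step` and `train`, the toy dataset, and the hyperparameters `lr`, `eta_q`, and `l2` are all hypothetical illustrations. `l2` merely stands in for the "stronger-than-typical L2 penalty" the abstract mentions, and no value from the paper is implied.

```python
import math
from typing import List, Tuple

# (features, label in {0, 1}, group id); the last feature is a constant 1.0 acting as a bias.
Example = Tuple[List[float], int, int]


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def loss_and_grad(w: List[float], x: List[float], y: int) -> Tuple[float, List[float]]:
    # Logistic loss for one example and its gradient with respect to w.
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    eps = 1e-12
    loss = -(y * math.log(p + eps) + (1 - y) * math.log(1.0 - p + eps))
    return loss, [(p - y) * xi for xi in x]


def group_dro_step(
    w: List[float], q: List[float], data: List[Example], n_groups: int,
    lr: float = 0.5, eta_q: float = 0.1, l2: float = 0.01,
) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical full-batch step: ascend on group weights q, then descend on w."""
    loss_sum = [0.0] * n_groups
    grad_sum = [[0.0] * len(w) for _ in range(n_groups)]
    count = [0] * n_groups
    for x, y, g in data:
        loss, grad = loss_and_grad(w, x, y)
        loss_sum[g] += loss
        count[g] += 1
        for j, gj in enumerate(grad):
            grad_sum[g][j] += gj
    group_loss = [s / max(c, 1) for s, c in zip(loss_sum, count)]
    # Exponentiated-gradient update: groups with higher loss gain relative weight.
    q = [qg * math.exp(eta_q * lg) for qg, lg in zip(q, group_loss)]
    total = sum(q)
    q = [qg / total for qg in q]
    # Gradient step on sum_g q_g * L_g(w) + (l2 / 2) * ||w||^2 (bias penalized too, for simplicity).
    new_w = []
    for j, wj in enumerate(w):
        gj = sum(q[g] * grad_sum[g][j] / max(count[g], 1) for g in range(n_groups))
        new_w.append(wj - lr * (gj + l2 * wj))
    return new_w, q, group_loss


def train(data: List[Example], q0: List[float], eta_q: float, steps: int = 200):
    w, q = [0.0, 0.0, 0.0], list(q0)
    for _ in range(steps):
        w, q, _losses = group_dro_step(w, q, data, n_groups=2, eta_q=eta_q)
    return w, q


# Toy data: feature 0 is the "core" signal, feature 1 is "spurious".
# Group 0 (90 examples): the spurious feature agrees with the label.
# Group 1 (10 examples): the spurious feature disagrees with the label.
data: List[Example] = (
    [([1.0, 1.0, 1.0], 1, 0)] * 45 + [([-1.0, -1.0, 1.0], 0, 0)] * 45
    + [([1.0, -1.0, 1.0], 1, 1)] * 5 + [([-1.0, 1.0, 1.0], 0, 1)] * 5
)

# Group DRO: start from uniform group weights and let them adapt.
w_dro, q_dro = train(data, q0=[0.5, 0.5], eta_q=0.1)
# ERM baseline: fixing q at the group frequencies with eta_q = 0 recovers the plain average loss.
w_erm, _ = train(data, q0=[0.9, 0.1], eta_q=0.0)
```

In this symmetric toy setup, the group DRO run keeps the spurious weight `w_dro[1]` near zero and classifies the minority group correctly, while the ERM baseline puts clearly positive weight on the spurious feature. A linear model on four distinct points is not overparameterized, so this only illustrates the mechanics of the update, not the paper's findings about regularization.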

Shiori Sagawa, Pang Wei Koh, Tatsunori B. Hashimoto, Percy Liang • 2019

Related benchmarks

Task                    | Dataset     | Metric                   | Result | Rank
------------------------+-------------+--------------------------+--------+-----
Node Classification     | Cora        | Accuracy                 | 90.54  | 1215
Node Classification     | Pubmed      | Accuracy                 | 85.17  | 363
Image Classification    | PACS (test) | Average Accuracy         | 83.5   | 279
Image Classification    | PACS        | Overall Average Accuracy | 84.4   | 270
Domain Generalization   | VLCS        | Accuracy                 | 77.4   | 270
Domain Generalization   | PACS        | Accuracy                 | 84.4   | 263
Domain Generalization   | OfficeHome  | Accuracy                 | 66     | 234
Time Series Forecasting | Exchange    | MSE                      | 0.821  | 227
Domain Generalization   | PACS (test) | Average Accuracy         | 64     | 225
Node Classification     | Citeseer    | Mean Accuracy            | 82.64  | 202
Showing 10 of 315 rows
