
Robust Learning with Progressive Data Expansion Against Spurious Correlation

About

While deep learning models have shown remarkable performance in various tasks, they are susceptible to learning non-generalizable spurious features rather than the core features that are genuinely correlated with the true label. In this paper, beyond existing analyses of linear models, we theoretically examine the learning process of a two-layer nonlinear convolutional neural network in the presence of spurious features. Our analysis suggests that imbalanced data groups and easily learnable spurious features can lead to the dominance of spurious features during the learning process. In light of this, we propose a new training algorithm called PDE that efficiently enhances the model's robustness for better worst-group performance. PDE begins with a group-balanced subset of the training data and progressively expands it to facilitate the learning of the core features. Experiments on synthetic and real-world benchmark datasets confirm the superior performance of our method on models such as ResNets and Transformers. On average, our method achieves a 2.8% improvement in worst-group accuracy compared with the state-of-the-art method, while training up to 10x faster. Code is available at https://github.com/uclaml/PDE.

Yihe Deng, Yu Yang, Baharan Mirzasoleiman, Quanquan Gu (2023)
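
The abstract describes PDE as starting from a group-balanced subset and then progressively expanding the training data. The following is a minimal sketch of that data schedule only, not the authors' implementation. The function names, the number of stages, the number of examples added per stage, and the choice to add random examples are all assumptions made for illustration; see the linked repository for the actual procedure.

```python
import random
from collections import defaultdict


def balanced_subset(groups, rng):
    """Pick the same number of indices from every group.

    The per-group count is the size of the smallest group (an assumption).
    """
    by_group = defaultdict(list)
    for idx, g in enumerate(groups):
        by_group[g].append(idx)
    n = min(len(members) for members in by_group.values())
    chosen = []
    for g in sorted(by_group):
        chosen.extend(rng.sample(by_group[g], n))
    return chosen


def expansion_schedule(groups, num_stages, add_per_stage, seed=0):
    """Return a list of index lists, one per training stage.

    Stage 0 is group-balanced; each later stage adds up to `add_per_stage`
    randomly chosen indices from the rest of the data (hypothetical rule).
    """
    rng = random.Random(seed)
    active = balanced_subset(groups, rng)
    active_set = set(active)
    remaining = [i for i in range(len(groups)) if i not in active_set]
    rng.shuffle(remaining)

    schedule = [list(active)]
    for _ in range(num_stages):
        new, remaining = remaining[:add_per_stage], remaining[add_per_stage:]
        active = active + new
        schedule.append(list(active))
    return schedule


if __name__ == "__main__":
    # Toy data: group 0 is the majority, group 1 the minority.
    toy_groups = [0] * 8 + [1] * 2
    for stage, idxs in enumerate(expansion_schedule(toy_groups, 3, 2)):
        print(stage, len(idxs))
```

In a training loop, each stage's index list would select the examples the model sees during that stage, so early updates come from balanced data and later updates from a gradually larger set.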

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Classification | CelebA | Avg Accuracy 92.4 | 185 |
| Image Classification | Waterbirds | Average Accuracy 92.4 | 157 |
| Image Classification | Waterbirds (test) | Worst-Group Accuracy 90.3 | 112 |
| Classification | CelebA (test) | Average Accuracy 55 | 92 |
| Image Classification | CelebA (test) | Accuracy 92.4 | 57 |
| Image Classification | CMNIST (test) | Test Accuracy 1.3 | 55 |
| Image Classification | MetaShift | Average Accuracy 87.4 | 33 |
| Image Classification | MetaShift (test) | Average Accuracy 87.4 | 27 |
| Natural Language Inference | MultiNLI | Accuracy 69.1 | 23 |
| Image Classification | ColorMNIST (ρ = 80%) (test) | Average Accuracy 76.6 | 20 |
Showing 10 of 21 rows
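
The benchmark results report both average accuracy and worst-group accuracy. As a rough illustration of how the two differ, the sketch below computes each from predictions and group labels. The function names and the toy data are hypothetical. In group-robustness benchmarks a group is typically a combination of class label and spurious attribute, which is assumed here.

```python
from collections import defaultdict


def group_accuracies(y_true, y_pred, groups):
    """Return a dict mapping each group to its accuracy."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for t, p, g in zip(y_true, y_pred, groups):
        total[g] += 1
        correct[g] += int(t == p)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(y_true, y_pred, groups):
    """Accuracy of the group on which the model performs worst."""
    return min(group_accuracies(y_true, y_pred, groups).values())


def average_accuracy(y_true, y_pred):
    """Plain accuracy over all examples, ignoring groups."""
    return sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)


if __name__ == "__main__":
    # Toy example: the model is perfect on majority group "a"
    # and gets half of minority group "b" wrong.
    y_true = [1, 1, 1, 1, 0, 0]
    y_pred = [1, 1, 1, 1, 1, 0]
    groups = ["a", "a", "a", "a", "b", "b"]
    print(average_accuracy(y_true, y_pred))
    print(worst_group_accuracy(y_true, y_pred, groups))
```

A model that relies on a spurious feature can score well on average while failing on a small group, which is why the two metrics are reported separately.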
