Robust Generalization against Photon-Limited Corruptions via Worst-Case Sharpness Minimization
About
Robust generalization aims to handle the most challenging data distributions, which are rare in the training set and contain severe noise, i.e., photon-limited corruptions. Common solutions such as distributionally robust optimization (DRO) focus on the worst-case empirical risk to ensure low training error on the uncommon noisy distributions. However, because an over-parameterized model is optimized on scarce worst-case data, DRO fails to produce a smooth loss landscape and thus struggles to generalize to the test set. Therefore, instead of focusing on worst-case risk minimization, we propose SharpDRO, which penalizes the sharpness of the worst-case distribution, where sharpness measures how much the loss changes in the neighborhood of the learned parameters. Through worst-case sharpness minimization, the proposed method produces a flat loss curve on the corrupted distributions and thereby achieves robust generalization. Moreover, depending on whether the distribution annotation is available, we apply SharpDRO to two problem settings and design a worst-case selection process for robust generalization. Theoretically, we show that SharpDRO has a convergence guarantee. Experimentally, we simulate photon-limited corruptions using the CIFAR10/100 and ImageNet30 datasets and show that SharpDRO generalizes strongly under severe corruptions and outperforms well-known baseline methods by large margins.
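The idea of worst-case sharpness minimization can be illustrated with a minimal sketch. It is not the paper's implementation: it assumes a SAM-style first-order approximation of sharpness, a toy linear model, and the setting where distribution annotations are available. All names (`sharpness`, `worst_group`, `sharp_dro_step`, `rho`, `lr`) and the toy data are hypothetical.

```python
import math

def loss(theta, data):
    """Mean squared error of the linear model y ~ w*x + b."""
    w, b = theta
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)

def grad(theta, data):
    """Analytic gradient of `loss` with respect to (w, b)."""
    w, b = theta
    n = len(data)
    gw = sum(2 * (w * x + b - y) * x for x, y in data) / n
    gb = sum(2 * (w * x + b - y) for x, y in data) / n
    return (gw, gb)

def perturb(theta, data, rho):
    """First-order ascent step of radius rho (assumed SAM-style approximation)."""
    g = grad(theta, data)
    norm = math.hypot(*g) or 1.0  # avoid division by zero at a stationary point
    return tuple(t + rho * gi / norm for t, gi in zip(theta, g))

def sharpness(theta, data, rho=0.05):
    """Approximate loss increase at distance rho from theta along the gradient."""
    return loss(perturb(theta, data, rho), data) - loss(theta, data)

def worst_group(theta, groups):
    """Select the distribution with the highest current loss (annotations known)."""
    return max(groups, key=lambda name: loss(theta, groups[name]))

def sharp_dro_step(theta, groups, lr=0.05, rho=0.05):
    """Descend along the gradient taken at the perturbed point of the worst group."""
    data = groups[worst_group(theta, groups)]
    g = grad(perturb(theta, data, rho), data)
    return tuple(t - lr * gi for t, gi in zip(theta, g))

# Toy data (hypothetical): a large clean group and a small, noisy group.
groups = {
    "clean": [(x / 10, 2 * x / 10 + 1) for x in range(20)],
    "noisy": [(0.0, 1.6), (0.5, 1.5), (1.0, 3.8), (1.5, 3.4)],
}
theta = (0.0, 0.0)
initial_worst = max(loss(theta, d) for d in groups.values())
for _ in range(200):
    theta = sharp_dro_step(theta, groups)
final_worst = max(loss(theta, d) for d in groups.values())
```

Each step first selects the worst-case group, then evaluates the gradient at a perturbed copy of the parameters, so the update favors regions where the worst-case loss is both low and flat. When annotations are unavailable, the abstract indicates that a different worst-case selection process is used, which this sketch does not cover.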
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Out-of-Distribution Detection and Generalization | CIFAR-10 ID LSUN-C semantic OOD & CIFAR-10-C covariate OOD | OOD Accuracy | 79.03 | 74 |
| Out-of-Distribution Detection and Generalization | CIFAR-10 ID SVHN semantic OOD CIFAR-10-C covariate OOD | OOD Accuracy | 79.03 | 38 |
| Generalized OOD Detection | CIFAR-10 with Places365 (semantic OOD) and CIFAR-10-C (covariate OOD) (test) | OOD Accuracy | 79.03 | 38 |
| Out-of-Distribution Detection and Generalization | CIFAR-10 ID Textures semantic OOD CIFAR-10-C covariate OOD | OOD Accuracy | 79.03 | 38 |
| OOD Detection | CIFAR-10 LSUN-Resize semantic OOD + CIFAR-10-C covariate OOD (test) | FPR | 13.27 | 22 |