
Robust Fine-tuning of Zero-shot Models via Variance Reduction

About

When fine-tuning zero-shot models like CLIP, our desideratum is for the fine-tuned model to excel on both in-distribution (ID) and out-of-distribution (OOD) data. Recently, ensemble-based models (ESMs) have been shown to offer significant robustness improvements while preserving high ID accuracy. However, our study finds that ESMs do not resolve the ID-OOD trade-off: they reach peak ID accuracy and peak OOD accuracy at different mixing coefficients. When the ensemble is optimized for OOD accuracy, its ID accuracy drops noticeably, and vice versa. In contrast, we propose a sample-wise ensembling technique that simultaneously attains the best ID and OOD accuracy without this trade-off. Specifically, we construct a Zero-Shot Failure (ZSF) set containing the training samples that the zero-shot model predicts incorrectly. For each test sample, we calculate its distance to the ZSF set and assign a higher weight to the fine-tuned model in the ensemble when the distance is small. We term our method Variance Reduction Fine-tuning (VRF), as it effectively reduces the variance of the ensemble predictions, thereby decreasing the residual error. On ImageNet and five derived distribution shifts, VRF improves OOD accuracy by 1.5-2.0 percentage points (pp) over the ensemble baselines while maintaining or increasing ID accuracy. VRF achieves similarly large robustness gains (0.9-3.1 pp) on other distribution shift benchmarks. Code is available at https://github.com/BeierZhu/VRF.
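The sample-wise weighting described above can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the authors' implementation (the linked repository holds that): features are L2-normalised, the distance to the ZSF set is taken as the Euclidean distance to the k-th nearest ZSF neighbour, the two models are mixed in probability space, and the distance-to-weight map is a hypothetical sigmoid with hand-picked parameters `a` and `b`. All function names are illustrative.

```python
import numpy as np


def build_zsf_set(train_feats, train_labels, zs_train_logits):
    """Return the Zero-Shot Failure (ZSF) set: features of training samples
    that the zero-shot model misclassifies."""
    zs_pred = zs_train_logits.argmax(axis=1)
    return train_feats[zs_pred != train_labels]


def knn_distance(test_feats, zsf_feats, k=1):
    """Distance from each test feature to its k-th nearest ZSF neighbour.

    Assumption: features are L2-normalised and compared with Euclidean
    distance; for unit vectors ||a - b||^2 = 2 - 2 * a.b.
    """
    t = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    z = zsf_feats / np.linalg.norm(zsf_feats, axis=1, keepdims=True)
    sq_dist = np.clip(2.0 - 2.0 * (t @ z.T), 0.0, None)  # clip tiny negatives from rounding
    return np.sort(np.sqrt(sq_dist), axis=1)[:, k - 1]


def finetuned_weight(dist, a=0.5, b=0.1):
    """Hypothetical sigmoid map from distance to fine-tuned-model weight:
    distances below `a` push the weight toward 1, distances above toward 0."""
    return 1.0 / (1.0 + np.exp(np.clip((dist - a) / b, -50.0, 50.0)))


def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def sample_wise_ensemble(zs_logits, ft_logits, weights):
    """Per-sample convex combination of the two models' class probabilities.
    A constant `weights` array recovers a fixed-coefficient ensemble."""
    w = weights[:, None]
    return w * softmax(ft_logits) + (1.0 - w) * softmax(zs_logits)


# Toy usage: the third training sample is a zero-shot failure.
train_feats = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
train_labels = np.array([0, 1, 1])
zs_train_logits = np.array([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]])
zsf = build_zsf_set(train_feats, train_labels, zs_train_logits)

test_feats = np.array([[1.0, 1.0], [1.0, -1.0]])  # near / far from the ZSF sample
dist = knn_distance(test_feats, zsf, k=1)
weights = finetuned_weight(dist)
probs = sample_wise_ensemble(
    zs_logits=np.array([[2.0, 0.0], [2.0, 0.0]]),
    ft_logits=np.array([[0.0, 2.0], [0.0, 2.0]]),
    weights=weights,
)
print(dist, weights, probs.argmax(axis=1))
```

In the toy run, the test point that coincides with the zero-shot failure receives a weight near 1 and follows the fine-tuned model, while the distant point falls back to the zero-shot model. Replacing `weights` with a constant array gives a fixed-coefficient ensemble, whose single mixing coefficient, per the abstract, cannot be optimal for ID and OOD accuracy at once.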

Beier Zhu, Jiequan Cui, Hanwang Zhang • 2024

Related benchmarks

Task | Dataset | Metric | Result | Rank
Image Classification | ImageNet V2 | -- | -- | 487
Image Classification | ImageNet-1k (val) | Accuracy | 82.3 | 189
Image Classification | ObjectNet | -- | -- | 177
Image Classification | ImageNet Rendition | Top-1 Accuracy | 78.72 | 77
Image Classification | ImageNet-Sketch | Accuracy | 52.93 | 77
Image Classification | CIFAR-10 | Accuracy | 98.6 | 74
Image Classification | ImageNet OOD | ImageNet Acc | 61.8 | 55
Image Classification | ImageNet and Distribution Shifts | ImageNet-V2 Accuracy | 72.3 | 49
Image Classification | ImageNet-Adversarial | Top-1 Acc | 48.41 | 33
Image Classification | ImageNet Distribution Shifts Summary | -- | -- | 30
(10 of 14 benchmark entries shown.)
