
Contrastive Adapters for Foundation Model Group Robustness

About

While large pretrained foundation models (FMs) have shown remarkable zero-shot classification robustness to dataset-level distribution shifts, their robustness to subpopulation or group shifts is relatively underexplored. We study this problem and find that FMs such as CLIP may not be robust to various group shifts. Across 9 robustness benchmarks, zero-shot classification with their embeddings results in gaps of up to 80.7 percentage points (pp) between average and worst-group accuracy. Unfortunately, existing methods to improve robustness require retraining, which can be prohibitively expensive for large foundation models. We also find that efficient ways to improve model inference (e.g., via adapters, lightweight networks that take FM embeddings as inputs) do not consistently improve group robustness compared to zero-shot classification, and can sometimes hurt it (e.g., increasing the accuracy gap by 50.1 pp on CelebA). We thus develop an adapter training strategy to effectively and efficiently improve FM group robustness. Our motivating observation is that while poor robustness results from groups in the same class being embedded far apart in the foundation model "embedding space," standard adapter training may not bring these points closer together. We thus propose contrastive adapting, which trains adapters with contrastive learning to bring sample embeddings close to both their ground-truth class embeddings and other sample embeddings in the same class. Across the 9 benchmarks, our approach consistently improves group robustness, raising worst-group accuracy by 8.5 to 56.0 pp over zero-shot. Our approach is also efficient, doing so without any FM finetuning, using only a fixed set of frozen FM embeddings. On benchmarks such as Waterbirds and CelebA, this leads to worst-group accuracy comparable to state-of-the-art methods that retrain entire models, while only training $\leq$1% of the model parameters.
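The contrastive objective described above can be sketched as follows. This is a hypothetical, simplified illustration (not the authors' implementation): for each adapted sample embedding, the positives are its ground-truth class embedding plus same-class sample embeddings, and the negatives are the other class embeddings; an InfoNCE-style loss then pulls positives together and pushes negatives apart. The function name, the numpy formulation, and the temperature value `tau` are all assumptions for illustration.

```python
import numpy as np

def normalize(x):
    # Unit-normalize embeddings so dot products are cosine similarities.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def contrastive_adapter_loss(z, labels, class_emb, tau=0.1):
    """Illustrative sketch of the contrastive-adapting objective.

    z         : (n, d) adapted sample embeddings (adapter outputs)
    labels    : length-n list of class indices
    class_emb : (num_classes, d) class (e.g., text prompt) embeddings
    tau       : temperature (hypothetical default)
    """
    z = normalize(np.asarray(z, dtype=float))
    c = normalize(np.asarray(class_emb, dtype=float))
    n = len(z)
    total = 0.0
    for i in range(n):
        # Positives: ground-truth class embedding + same-class samples.
        pos = [c[labels[i]]] + [z[j] for j in range(n)
                                if j != i and labels[j] == labels[i]]
        # Negatives: embeddings of the other classes.
        neg = [c[k] for k in range(len(c)) if k != labels[i]]
        sims = np.array([p @ z[i] for p in pos + neg]) / tau
        log_denom = np.log(np.exp(sims).sum())
        # Average the InfoNCE term over all positives for sample i.
        total += np.mean([log_denom - s for s in sims[:len(pos)]])
    return total / n
```

Minimizing this loss over the adapter's parameters (with the FM embeddings frozen) drives same-class samples toward each other and toward their class embedding, which is exactly the geometric fix the abstract motivates: groups of the same class that the FM embedded far apart get pulled together.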

Michael Zhang, Christopher Ré · 2022

Related benchmarks

Task                            | Dataset                    | Metric               | Result | Rank
Image Classification            | Waterbirds                 | WG Accuracy          | 86.9   | 79
Attribute Classification        | CelebA (test)              | Worst-group Accuracy | 90     | 48
Group Robustness                | CivilComments-WILDS (test) | WG Accuracy          | 50.1   | 40
Image Classification            | CelebA                     | WG Score             | 90     | 28
Object Classification           | Waterbirds (test)          | Worst-Group Accuracy | 86.9   | 22
Group Robustness Classification | BREEDS Living-17           | WG Score             | 80     | 16
Group Robustness Classification | CIFAR-10.02                | WG                   | 82.2   | 14
Image Classification            | BREEDS Living-17           | WG Score             | 62     | 8
Image Classification            | BREEDS Nonliving-26        | WG Score             | 55.3   | 8
Group Robustness                | AMAZON-WILDS (test)        | WG Accuracy          | 87.9   | 7

Showing 10 of 13 rows.
