Fast Trainable Projection for Robust Fine-Tuning

About

Robust fine-tuning aims to achieve competitive in-distribution (ID) performance while maintaining the out-of-distribution (OOD) robustness of a pre-trained model when transferring it to a downstream task. Recently, projected gradient descent has been used successfully in robust fine-tuning by explicitly constraining, through projection, how far the fine-tuned model can deviate from its initialization. However, two algorithmic limitations, scalability and efficiency, prevent this method from being adopted more widely. In this paper, we propose a new projection-based fine-tuning algorithm, Fast Trainable Projection (FTP), for computationally efficient learning of per-layer projection constraints, resulting in an average $35\%$ speedup on our benchmarks compared to prior work. FTP can be combined with existing optimizers such as AdamW and used in a plug-and-play fashion. Finally, we show that FTP is a special instance of hyper-optimizers, which tune the hyper-parameters of optimizers in a learnable manner through nested differentiation. Empirically, we show superior robustness on OOD datasets, including domain shifts and natural corruptions, across four different vision tasks with five different pre-trained models. Additionally, we demonstrate that FTP is broadly applicable and beneficial to other learning scenarios, such as low-label and continual learning settings, thanks to its easy adaptability. The code will be available at https://github.com/GT-RIPL/FTP.git.

Junjiao Tian, Yen-Cheng Liu, James Seale Smith, Zsolt Kira • 2023
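
The abstract's starting point, projected gradient descent that keeps the fine-tuned weights close to their pre-trained initialization, can be illustrated with a minimal sketch. This is not the authors' FTP algorithm: FTP learns a projection constraint per layer, whereas the sketch below assumes a single hand-chosen L2 radius and a plain SGD update. The function names `project_to_ball` and `projected_sgd_step` and all numeric values are hypothetical, chosen only for illustration.

```python
import math


def project_to_ball(weights, init_weights, radius):
    """Project `weights` onto the L2 ball of `radius` centered at `init_weights`.

    Assumption: a fixed radius; FTP instead learns per-layer constraints.
    """
    diff = [w - w0 for w, w0 in zip(weights, init_weights)]
    norm = math.sqrt(sum(d * d for d in diff))
    if norm <= radius:
        return list(weights)  # already inside the constraint set
    scale = radius / norm  # shrink the deviation back onto the ball surface
    return [w0 + scale * d for w0, d in zip(init_weights, diff)]


def projected_sgd_step(weights, grads, init_weights, lr, radius):
    """One plain SGD update followed by projection toward the initialization."""
    updated = [w - lr * g for w, g in zip(weights, grads)]
    return project_to_ball(updated, init_weights, radius)


# Toy example with made-up numbers: a large step gets pulled back to radius 1.
init = [1.0, 0.0]
grads = [-3.0, -4.0]
new_weights = projected_sgd_step(init, grads, init, lr=1.0, radius=1.0)
deviation = math.sqrt(sum((w - w0) ** 2 for w, w0 in zip(new_weights, init)))
```

In this toy run the unconstrained update would move the weights a distance of 5 from the initialization; after projection, `deviation` equals the chosen radius of 1 up to floating-point error. Picking such a radius by hand for every layer is the kind of cost that a learned, per-layer constraint is meant to avoid.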

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
| --- | --- | --- | --- | --- |
| Image Classification | DomainNet Source: Real 100% data (test) | Accuracy (Real) | 84.22 | 15 |
| Continual Learning | ImageNet-R 10 sequential tasks 200 classes | A1:N | 77.26 | 14 |
| Image Classification | ImageNet Robustness Variants (Adversarial, Rendition, Sketch) V2 (test) | Accuracy (ID) | 84.19 | 10 |
| Semantic segmentation | Pascal Semantic Segmentation ID Clean (test) | mIoU (Clean) | 73.79 | 9 |
| Semantic segmentation | Pascal Semantic Segmentation OOD Corrupted (test) | mIoU (Fog) | 0.711 | 9 |
| Human Parts Segmentation | PASCAL Human Parts ID Clean (test) | mIoU | 65.5 | 8 |
| Human Parts Segmentation | PASCAL Human Parts OOD Corruptions (test) | Fog Acc | 61.73 | 8 |
| Image Classification | ImageNet and OOD variants (ImV2, Im-A, Im-R, Im-S) 1.0 (val) | ImNet Acc | 0.8419 | 8 |
| Semantic segmentation | PASCAL-Context (Clean) | mIoU | 73.79 | 8 |
| Semantic segmentation | PASCAL-Context (OOD) | mIoU (Fog) | 71.1 | 8 |
Showing 10 of 13 rows
