
Enhancing Domain Adaptation through Prompt Gradient Alignment

About

Prior Unsupervised Domain Adaptation (UDA) methods often aim to train a domain-invariant feature extractor, which may hinder the model from learning sufficiently discriminative features. To tackle this, a line of work based on prompt learning leverages the power of large-scale pre-trained vision-language models to learn both domain-invariant and domain-specific features through a set of domain-agnostic and domain-specific learnable prompts. Those studies typically enforce invariance constraints on the representation, output, or prompt space to learn such prompts. In contrast, we cast UDA as a multi-objective optimization problem in which each objective is represented by a domain loss. Under this new framework, we propose to align per-objective gradients to foster consensus between them. Additionally, to prevent potential overfitting when fine-tuning this deep learning architecture, we penalize the norm of these gradients. To achieve these goals, we devise a practical gradient update procedure that works under both single-source and multi-source UDA. Empirically, our method consistently outperforms other vision-language model adaptation methods. The implementation is available at https://github.com/VietHoang1512/PGA.
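To make the two ingredients named in the abstract concrete, the sketch below measures how well two per-objective gradients agree (cosine similarity) and adds a gradient-norm penalty on a toy problem. It is an illustration only, not the paper's update procedure: the quadratic "domain losses", the names `quadratic_loss`, `quadratic_grad`, and `penalized_objective`, and the weights `lam_align` and `lam_norm` are all assumptions introduced here.

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def norm(u):
    """Euclidean norm of a vector."""
    return math.sqrt(dot(u, u))


def cosine(u, v):
    """Cosine similarity; returns 0.0 if either vector is zero."""
    nu, nv = norm(u), norm(v)
    if nu == 0.0 or nv == 0.0:
        return 0.0
    return dot(u, v) / (nu * nv)


def quadratic_loss(w, center):
    """Toy stand-in for a domain loss: 0.5 * ||w - center||^2."""
    return 0.5 * sum((wi - ci) ** 2 for wi, ci in zip(w, center))


def quadratic_grad(w, center):
    """Gradient of the toy loss with respect to w, i.e. w - center."""
    return [wi - ci for wi, ci in zip(w, center)]


def penalized_objective(w, src_center, tgt_center, lam_align=0.1, lam_norm=0.01):
    """Hypothetical combined objective (an assumption, not the paper's):
    sum of losses, minus a reward for gradient agreement (inner product),
    plus a penalty on the gradient norms."""
    g_src = quadratic_grad(w, src_center)
    g_tgt = quadratic_grad(w, tgt_center)
    total_loss = quadratic_loss(w, src_center) + quadratic_loss(w, tgt_center)
    alignment = dot(g_src, g_tgt)
    norm_penalty = norm(g_src) + norm(g_tgt)
    return total_loss - lam_align * alignment + lam_norm * norm_penalty


if __name__ == "__main__":
    w = [0.0, 0.0]
    # Objectives that pull in the same direction have aligned gradients.
    print(cosine(quadratic_grad(w, [1.0, 1.0]), quadratic_grad(w, [2.0, 2.0])))
    # Objectives that pull in opposite directions have conflicting gradients.
    print(cosine(quadratic_grad(w, [1.0, 0.0]), quadratic_grad(w, [-1.0, 0.0])))
```

In this toy setting, a cosine near 1 means a step that lowers one loss also lowers the other, while a cosine near -1 signals the conflict that gradient alignment is meant to reduce. In the actual setting described above, the gradients would be taken with respect to the learnable prompts.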

Hoang Phan, Lam Tran, Quyen Tran, Trung Le • 2024

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Unsupervised Domain Adaptation | Office-Home (test) | Average Accuracy 73.9 | 332 |
| Unsupervised Domain Adaptation | Office-Home | Average Accuracy 79.4 | 238 |
| Image Classification | DomainNet (test) | Average Accuracy 55.4 | 209 |
| Image Classification | OfficeHome | Average Accuracy 88.4 | 131 |
| Domain Adaptation | Office-Home (test) | Mean Accuracy 89.4 | 112 |
| Unsupervised Domain Adaptation | DomainNet | Average Accuracy 56.2 | 100 |
| Unsupervised Domain Adaptation | Office-Home 101 (test) | Accuracy (Ar→Cl) 56.1 | 17 |
| Image Classification | ImageCLEF | F1 Score 94.2 | 14 |
| Unsupervised Domain Adaptation | ImageCLEF | Accuracy (Domain C) 97.4 | 12 |
| Image Classification | S2RDA-49 synthetic-to-real (test) | Accuracy 74.1 | 9 |
Showing 10 of 12 rows
