PGrad: Learning Principal Gradients For Domain Generalization

About

Machine learning models often perform poorly on out-of-distribution (OOD) domains; learning models that generalize to such unseen domains is the challenging task known as domain generalization (DG). In this work, we develop a novel DG training strategy, which we call PGrad, that learns a robust gradient direction and thereby improves models' generalization ability on unseen domains. The proposed gradient aggregates the principal directions of a sampled roll-out optimization trajectory that measures the training dynamics across all training domains. PGrad's gradient design forces DG training to ignore domain-dependent noise signals and updates the model on all training domains along a robust direction that covers the main components of the parameter dynamics. We further improve PGrad via a bijection-based computational refinement and via directional and length-based calibrations. Our theoretical proof connects PGrad to the spectral analysis of the Hessian in training neural networks. Experiments on the DomainBed and WILDS benchmarks demonstrate that our approach effectively enables robust DG optimization and leads to smoothly decreasing loss curves. Empirically, PGrad achieves competitive results across seven datasets, demonstrating its efficacy under both synthetic and real-world distributional shifts. Code is available at https://github.com/QData/PGrad.

Zhe Wang, Jake Grigsby, Yanjun Qi (2023)
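The principal-gradient idea in the abstract can be illustrated with a small NumPy sketch. It is not the authors' implementation (see the linked repository for that); it rests on several assumptions of our own: each training domain has a toy quadratic loss, the roll-out takes one SGD step per domain for two cycles, the "bijection-based refinement" is read as eigendecomposing the small Gram matrix of trajectory snapshots instead of the parameter-sized covariance, and the direction and length calibrations use simple rules chosen here for illustration. The helper names `rollout`, `principal_gradient`, and `make_domain_grad` are hypothetical.

```python
import numpy as np


def rollout(theta0, domain_grads, lr, n_cycles=2):
    """Sample a roll-out trajectory: one SGD step per training domain, repeated n_cycles times."""
    theta = theta0.copy()
    snapshots = [theta.copy()]
    for _ in range(n_cycles):
        for grad_fn in domain_grads:
            theta = theta - lr * grad_fn(theta)
            snapshots.append(theta.copy())
    return np.stack(snapshots)  # shape (n_steps + 1, d)


def principal_gradient(traj, lr, k=1, tol=1e-12):
    """Aggregate the top-k principal directions of a trajectory into one update direction."""
    n_steps = traj.shape[0] - 1
    centered = traj - traj.mean(axis=0, keepdims=True)  # (n, d)

    # Gram-matrix trick (our reading of the bijection refinement): eigendecompose the
    # small n x n matrix X X^T instead of the d x d covariance X^T X; both share the
    # same nonzero eigenvalues.
    evals, evecs = np.linalg.eigh(centered @ centered.T)  # ascending order
    order = np.argsort(evals)[::-1][:k]
    evals, evecs = evals[order], evecs[:, order]
    keep = evals > tol  # drop numerically zero directions
    evals, evecs = evals[keep], evecs[:, keep]
    if evals.size == 0:
        return np.zeros(traj.shape[1])

    # Map each Gram eigenvector u back to parameter space: v = X^T u / ||X^T u||.
    dirs = centered.T @ evecs  # (d, k')
    dirs /= np.linalg.norm(dirs, axis=0, keepdims=True)

    # Direction calibration (assumed rule): flip each direction so it points along
    # theta_0 - theta_end, the way a gradient points for the roll-out's descent.
    ref = traj[0] - traj[-1]
    dirs *= np.where(dirs.T @ ref >= 0, 1.0, -1.0)

    # Weight the directions by their share of the trajectory variance.
    g = dirs @ (evals / evals.sum())

    # Length calibration (assumed rule): match the average per-step gradient norm.
    target_norm = np.linalg.norm(ref) / (lr * n_steps)
    return g * target_norm / (np.linalg.norm(g) + 1e-12)


def make_domain_grad(center):
    # Toy domain loss 0.5 * ||theta - center||^2, whose gradient is theta - center.
    return lambda theta: theta - center


def average_loss(theta, centers):
    return float(np.mean([0.5 * np.sum((theta - c) ** 2) for c in centers]))


rng = np.random.default_rng(0)
centers = [3.0 + 0.5 * rng.standard_normal(5) for _ in range(3)]  # three toy domains
domain_grads = [make_domain_grad(c) for c in centers]

lr = 0.1
theta = np.zeros(5)
initial_loss = average_loss(theta, centers)
for _ in range(20):
    traj = rollout(theta, domain_grads, lr)
    theta = theta - lr * principal_gradient(traj, lr, k=1)
final_loss = average_loss(theta, centers)
print(f"average loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

Working on the Gram matrix keeps the eigendecomposition at the size of the roll-out (a handful of snapshots) rather than the number of model parameters, which is what makes a principal-direction update affordable for large networks.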

Related benchmarks

Task | Dataset | Metric | Result | Rank
Image Classification | PACS | Overall Average Accuracy | 70.6 | 230
Domain Generalization | PACS | Accuracy (Art) | 87.6 | 221
Domain Generalization | Office-Home | Average Accuracy | 69.3 | 63
Domain Generalization | VLCS | Accuracy (L) | 64.4 | 27
Image Classification | OfficeHome | Average Accuracy | 55.4 | 24
Image Classification | VLCS GINIDG setting (test) | Average Accuracy | 77.6 | 24
Image Classification | PACS GINIDG setting (test) | Accuracy (Overall) | 0.843 | 24
Image Classification | OfficeHome TotalHeavyTail setting (test) | Average Accuracy | 48 | 24
Image Classification | VLCS | Average Accuracy | 55.7 | 24
Image Classification | VLCS TotalHeavyTail setting (test) | Average Accuracy | 72.7 | 24

(10 of 13 rows shown.)
