
Improving Generalization in Federated Learning by Seeking Flat Minima

About

Models trained in federated settings often suffer from degraded performance and fail to generalize, especially in heterogeneous scenarios. In this work, we investigate such behavior through the lens of the geometry of the loss and the Hessian eigenspectrum, linking the model's lack of generalization capacity to the sharpness of the solution. Motivated by prior studies connecting the sharpness of the loss surface and the generalization gap, we show that i) training clients locally with Sharpness-Aware Minimization (SAM) or its adaptive version (ASAM) and ii) averaging stochastic weights (SWA) on the server side can substantially improve generalization in Federated Learning and help bridge the gap with centralized models. By seeking parameters in neighborhoods with uniformly low loss, the model converges towards flatter minima and its generalization significantly improves in both homogeneous and heterogeneous scenarios. Empirical results demonstrate the effectiveness of those optimizers across a variety of benchmark vision datasets (e.g., CIFAR10/100, Landmarks-User-160k, IDDA) and tasks (large-scale classification, semantic segmentation, domain generalization).
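The two ingredients named in the abstract, SAM on the clients and weight averaging on the server, can be illustrated with a minimal NumPy sketch. It is a toy under stated assumptions, not the authors' implementation: the model is a linear regressor, client heterogeneity is mimicked by shifting each client's feature mean, the server average is a plain running mean of the global weights over the last rounds, and every name and hyperparameter (`sam_step`, `federated_round`, `rho`, `swa_start`, and so on) is hypothetical. It does not reproduce the paper's models, schedules, or results.

```python
import numpy as np

def loss_and_grad(w, X, y):
    """Halved mean squared error of a linear model and its gradient."""
    residual = X @ w - y
    return 0.5 * np.mean(residual ** 2), X.T @ residual / len(y)

def sam_step(w, X, y, lr=0.1, rho=0.05):
    """One SAM update: move to the approximate worst-case point inside an
    L2 ball of radius rho, then descend using the gradient taken there."""
    _, g = loss_and_grad(w, X, y)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # first-order ascent direction
    _, g_perturbed = loss_and_grad(w + eps, X, y)
    return w - lr * g_perturbed

def federated_round(w_global, clients, local_steps=5, lr=0.1, rho=0.05):
    """FedAvg-style round: each client runs SAM locally from the global
    weights; the server averages the results weighted by dataset size."""
    local_models, sizes = [], []
    for X, y in clients:
        w = w_global.copy()
        for _ in range(local_steps):
            w = sam_step(w, X, y, lr, rho)
        local_models.append(w)
        sizes.append(len(y))
    weights = np.array(sizes) / sum(sizes)
    return sum(a * m for a, m in zip(weights, local_models))

def make_clients(num_clients=4, n=50, dim=3, seed=0):
    """Synthetic clients (assumption): same true weights, shifted features."""
    rng = np.random.default_rng(seed)
    w_true = np.arange(1, dim + 1, dtype=float)
    clients = []
    for k in range(num_clients):
        X = rng.normal(loc=0.3 * k, size=(n, dim))
        y = X @ w_true + 0.01 * rng.normal(size=n)
        clients.append((X, y))
    return clients, w_true

def run(rounds=40, swa_start=30, dim=3):
    """Train for `rounds` rounds; keep a running mean of the global
    weights from round `swa_start` onward (simplified server-side SWA)."""
    clients, w_true = make_clients(dim=dim)
    w = np.zeros(dim)
    w_swa, n_avg = np.zeros(dim), 0
    for r in range(rounds):
        w = federated_round(w, clients)
        if r >= swa_start:
            n_avg += 1
            w_swa += (w - w_swa) / n_avg
    return w, w_swa, w_true
```

Calling `run()` returns the last global weights, their running average, and the ground-truth weights of the toy problem. On a convex quadratic like this one, flatness plays no real role; the sketch only shows where the perturbation step and the server-side average sit in the training loop.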

Debora Caldarola, Barbara Caputo, Marco Ciccone • 2022

Related benchmarks

| Task | Dataset | Result (Accuracy) | Rank |
|---|---|---|---|
| Image Classification | MNIST (test) | 97.08 | 882 |
| Image Classification | CIFAR100 (test) | 49.17 | 112 |
| Image Classification | CIFAR10 centralized performance (test) | 76.98 | 104 |
| Image Classification | CIFAR100 alpha=0 | 42.64 | 21 |
| Image Classification | CIFAR100 (alpha=0.5) | 49.17 | 21 |
| Image Classification | CIFAR100 alpha=1000 | 54.97 | 21 |
| Image Classification | CIFAR10 alpha=0 | 76.44 | 21 |
| Image Classification | CIFAR10 (alpha=0.05) | 76.98 | 21 |
| Image Classification | CIFAR10 alpha=100 | 84.88 | 21 |
| Image Classification | CIFAR100-PAM | 55.44 | 21 |
Showing 10 of 20 rows
