
Gradient Masked Averaging for Federated Learning

About

Federated learning (FL) is an emerging paradigm that permits a large number of clients with heterogeneous data to coordinate learning of a unified global model without the need to share data with each other. A major challenge in federated learning is the heterogeneity of data across clients, which can degrade the performance of standard FL algorithms. Standard FL algorithms average model parameters or gradient updates to approximate the global model at the server. However, we argue that in heterogeneous settings, averaging can result in information loss and lead to poor generalization due to the bias induced by dominant client gradients. We hypothesize that to generalize better across non-i.i.d. datasets, the algorithms should focus on learning the invariant mechanism that is shared across clients while ignoring spurious mechanisms that differ between them. Inspired by recent work in out-of-distribution generalization, we propose a gradient masked averaging approach for FL as an alternative to the standard averaging of client updates. This aggregation technique for client updates can be adopted as a drop-in replacement in most existing federated algorithms. We perform extensive experiments on multiple FL algorithms with in-distribution, real-world, feature-skewed out-of-distribution, and quantity-imbalanced datasets and show that gradient masked averaging provides consistent improvements, particularly in the case of heterogeneous clients.

Irene Tenison, Sai Aravind Sreeramadas, Vaikkunth Mugunthan, Edouard Oyallon, Irina Rish, Eugene Belilovsky • 2022
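
The abstract describes the aggregation step only at a high level, so the following is a minimal sketch of one way a masked average could work, not the authors' implementation. It assumes the mask is built from per-coordinate sign agreement across client updates: coordinates where clients largely agree on the direction keep the plain average, and the rest are scaled down by their agreement score. The function name `masked_average`, the threshold `tau`, and the soft-mask rule are illustrative assumptions that the text above does not specify.

```python
from typing import List, Sequence


def sign(x: float) -> int:
    """Return -1, 0, or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def masked_average(
    client_updates: Sequence[Sequence[float]],
    tau: float = 0.5,  # hypothetical agreement threshold, not from the source
) -> List[float]:
    """Average client updates, down-weighting coordinates where signs disagree.

    Assumed rule (illustrative only): for each coordinate, the agreement
    score is |mean of the update signs|. If the score reaches tau, the
    coordinate keeps the plain average; otherwise the average is
    multiplied by the score itself.
    """
    if not client_updates:
        raise ValueError("need at least one client update")
    n_clients = len(client_updates)
    dim = len(client_updates[0])
    if any(len(update) != dim for update in client_updates):
        raise ValueError("all client updates must have the same length")

    aggregated = []
    for j in range(dim):
        column = [update[j] for update in client_updates]
        mean_update = sum(column) / n_clients
        agreement = abs(sum(sign(v) for v in column)) / n_clients
        mask = 1.0 if agreement >= tau else agreement
        aggregated.append(mask * mean_update)
    return aggregated


# Four clients, two parameters: all clients agree on the direction of the
# first parameter, while they split evenly on the second.
updates = [[1.0, 2.0], [1.0, -1.0], [1.0, 2.0], [1.0, -1.0]]
result = masked_average(updates, tau=0.5)
print(result)
```

In this toy case a plain average would move the second parameter by 0.5 even though the clients are evenly split on its direction. Under the assumed rule, that coordinate is masked out while the first one is left unchanged.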

Related benchmarks

Task                    Dataset        Result                     Rank
Domain Generalization   PACS           Accuracy (Art): 83.88      221
Domain Generalization   Office-Home    Average Accuracy: 68.25    63
Domain Generalization   VLCS (test)    Average Accuracy: 77.32    62
