Towards Understanding and Mitigating Dimensional Collapse in Heterogeneous Federated Learning

About

Federated learning aims to train models collaboratively across different clients without sharing data, for privacy reasons. However, one major challenge for this learning paradigm is the data heterogeneity problem, which refers to the discrepancies among the local data distributions of different clients. To tackle this problem, we first study how data heterogeneity affects the representations of the globally aggregated models. Interestingly, we find that heterogeneous data causes the global model to suffer from severe dimensional collapse, in which representations tend to reside in a lower-dimensional subspace instead of spanning the ambient space. Moreover, we observe a similar phenomenon in models trained locally on each client and deduce that the dimensional collapse of the global model is inherited from the local models. In addition, we theoretically analyze the gradient flow dynamics to shed light on how data heterogeneity results in dimensional collapse for local models. To remedy this problem, we propose FedDecorr, a novel method that can effectively mitigate dimensional collapse in federated learning. Specifically, FedDecorr applies a regularization term during local training that encourages different dimensions of the representations to be uncorrelated. FedDecorr, which is implementation-friendly and computationally efficient, yields consistent improvements over baselines on standard benchmark datasets. Code: https://github.com/bytedance/FedDecorr.
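
The two ideas in the abstract, diagnosing dimensional collapse and penalizing correlated representation dimensions, can be illustrated with a small NumPy sketch. This is not the official FedDecorr code (see the repository linked above). The function names `covariance_singular_values` and `decorrelation_penalty`, the `eps` constant, and the 1/d^2 scaling are assumptions made for illustration. The sketch treats collapse as a fast-decaying singular value spectrum of the representation covariance matrix, and uses the squared Frobenius norm of the correlation matrix as one plausible decorrelation term.

```python
import numpy as np


def covariance_singular_values(z):
    """Singular values (descending) of the covariance of z, shape (n, d).

    Many near-zero values suggest the representations occupy a
    lower-dimensional subspace (dimensional collapse).
    """
    zc = z - z.mean(axis=0, keepdims=True)
    cov = zc.T @ zc / z.shape[0]
    return np.linalg.svd(cov, compute_uv=False)


def decorrelation_penalty(z, eps=1e-8):
    """Squared Frobenius norm of the correlation matrix of z, scaled by 1/d^2.

    Assumed form of a decorrelation regularizer: it is smallest when the
    d dimensions are uncorrelated and grows as they become correlated.
    """
    n, d = z.shape
    zc = z - z.mean(axis=0, keepdims=True)
    # Standardize each dimension; eps guards against zero variance.
    zs = zc / np.sqrt(zc.var(axis=0, keepdims=True) + eps)
    corr = zs.T @ zs / n
    return float(np.sum(corr ** 2) / d ** 2)


rng = np.random.default_rng(0)
# Eight roughly independent dimensions.
spread = rng.standard_normal((512, 8))
# Eight dimensions that are all linear mixes of only two factors (rank 2).
collapsed = rng.standard_normal((512, 2)) @ rng.standard_normal((2, 8))

sv_spread = covariance_singular_values(spread)
sv_collapsed = covariance_singular_values(collapsed)
penalty_spread = decorrelation_penalty(spread)
penalty_collapsed = decorrelation_penalty(collapsed)
```

In this toy setting the collapsed batch has only two non-negligible singular values and a clearly larger penalty than the spread batch. In a training loop, a term of this kind would be added to the local task loss with a weighting coefficient, which is a hyperparameter not specified here.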

Yujun Shi, Jian Liang, Wenqing Zhang, Vincent Y. F. Tan, Song Bai (2022)

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Image Classification | CIFAR10 Cross-Silo (K=10) α = 0.1 | Accuracy 80.1 | 36 |
| Image Classification | CIFAR100 Cross-Silo (K=10) α = 0.1 | Accuracy 47.93 | 36 |
| Image Classification | CIFAR10 Cross-Device (K=100) α = 0.1 | Accuracy 58.68 | 36 |
| Image Classification | CIFAR100 Cross-Device (K=100) α = 0.1 | Accuracy 20.6 | 36 |
| Federated Learning | CIFAR-100 500 clients, 1% participation Dirichlet 0.3 (train test) | Accuracy (500 Rounds) 31.03 | 13 |
| Federated Learning | CIFAR-10 500 clients, 1% participation Dirichlet 0.3 (train test) | Accuracy (500 Rounds) 56.62 | 13 |
| Federated Learning | CIFAR-10 100 clients Dirichlet 0.3 | Accuracy (Round 500) 71.29 | 13 |
| Federated Learning | CIFAR-100 100 clients Dirichlet 0.3 | Accuracy (500 Rounds) 39.42 | 13 |
| Image Classification | CIFAR-10 Dirichlet 0.6, 100 clients, 5% participation (test) | Accuracy (500 Rounds) 81.01 | 13 |
| Image Classification | CIFAR-100 i.i.d. 500 clients 2% participation (test) | Accuracy (500R) 30.41 | 13 |
Showing 10 of 21 rows
