Federated Representation Learning in the Under-Parameterized Regime
About
Federated representation learning (FRL) is a popular personalized federated learning (FL) framework in which clients collaboratively train a common representation while retaining their own personalized heads. Existing studies, however, largely focus on the over-parameterized regime. In this paper, we take a first step toward investigating FRL in the under-parameterized regime, where the FL model is insufficient to express the variations across all ground-truth models. We propose a novel FRL algorithm, FLUTE, and theoretically characterize its sample complexity and convergence rate for linear models in the under-parameterized regime. To the best of our knowledge, this is the first FRL algorithm with provable performance guarantees in this regime. FLUTE features a data-independent random initialization and a carefully designed objective function that helps distill the subspace spanned by the globally optimal representation from the misaligned local representations. On the technical side, we bridge low-rank matrix approximation techniques with FL analysis, which may be of broader interest. We also extend FLUTE beyond linear representations. Experimental results show that FLUTE outperforms state-of-the-art FRL solutions on both synthetic and real-world tasks.
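To make the under-parameterized setting concrete, the following is a minimal NumPy sketch under stated assumptions: linear clients y = Xθ_i + ε with isotropic Gaussian features and squared loss, a shared orthonormal representation B of width k, and per-client heads w_i. Under these assumptions the best width-k representation spans the top-k left singular vectors of Θ = [θ_1, ..., θ_M] (the Eckart-Young low-rank approximation), and when rank(Θ) > k even this optimum leaves a nonzero residual. The training loop is a generic FedRep-style alternating update (exact least-squares heads, one averaged gradient step on B, QR re-orthonormalization) started from a data-independent random initialization. It is not FLUTE, whose objective the abstract does not spell out; all function names, the synthetic spectrum, and the hyperparameters are hypothetical.

```python
import numpy as np

# Hypothetical spectrum of Theta Theta^T / m: rank 4, clear gap after index 2.
EIGS = np.array([4.0, 2.25, 0.25, 0.09])


def make_clients(d=10, m=20, n=5000, noise=0.1, seed=0):
    """Hypothetical synthetic clients: y_i = X_i theta_i + eps, X_i isotropic Gaussian."""
    rng = np.random.default_rng(seed)
    u, _ = np.linalg.qr(rng.standard_normal((d, len(EIGS))))
    v, _ = np.linalg.qr(rng.standard_normal((m, len(EIGS))))
    theta = u @ np.diag(np.sqrt(m * EIGS)) @ v.T  # d x m ground-truth matrix, rank 4
    xs, ys = [], []
    for i in range(m):
        x = rng.standard_normal((n, d))
        xs.append(x)
        ys.append(x @ theta[:, i] + noise * rng.standard_normal(n))
    return xs, ys, theta


def optimal_representation(theta, k):
    """Top-k left singular vectors of Theta (optimal under the assumptions above)."""
    u, _, _ = np.linalg.svd(theta, full_matrices=False)
    return u[:, :k]


def residual_fraction(theta, b):
    """Share of ||Theta||_F^2 that span(b) cannot express (> 0 when k < rank)."""
    resid = theta - b @ (b.T @ theta)
    return np.linalg.norm(resid) ** 2 / np.linalg.norm(theta) ** 2


def subspace_distance(b1, b2):
    """Sine of the largest principal angle between two orthonormal bases."""
    return np.linalg.norm(b1 - b2 @ (b2.T @ b1), 2)


def federated_rep_learning(xs, ys, k, rounds=150, lr=0.1, seed=1):
    """Generic FedRep-style loop (NOT FLUTE): local heads, averaged update of B."""
    rng = np.random.default_rng(seed)
    d = xs[0].shape[1]
    # Data-independent random initialization: orthonormalized Gaussian matrix.
    b, _ = np.linalg.qr(rng.standard_normal((d, k)))
    for _ in range(rounds):
        grads = []
        for x, y in zip(xs, ys):
            z = x @ b                                  # local features, n x k
            w, *_ = np.linalg.lstsq(z, y, rcond=None)  # personalized head
            resid = z @ w - y
            # Gradient of (1 / 2n) ||X B w - y||^2 with respect to B.
            grads.append(x.T @ np.outer(resid, w) / len(y))
        # Server averages client gradients, takes a step, re-orthonormalizes.
        b, _ = np.linalg.qr(b - lr * np.mean(grads, axis=0))
    return b


xs, ys, theta = make_clients()
b_hat = federated_rep_learning(xs, ys, k=2)
b_opt = optimal_representation(theta, k=2)
print(subspace_distance(b_hat, b_opt), residual_fraction(theta, b_opt))
```

Choosing k below rank(Θ) is what makes this toy problem under-parameterized: with the hypothetical spectrum above, even the optimal width-2 representation leaves 0.34/6.59 (about 5.2%) of ||Θ||_F^2 unexplained. The printed subspace distance only indicates how close this simple loop gets to that optimum on synthetic data; it is not a reproduction of the paper's results.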
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | CIFAR-100 Dir-0.1 | Accuracy | 31.61 | 65 |
| Image Classification | CIFAR-10 Dir(0.5) | Accuracy | 48.25 | 59 |
| Image Classification | EMNIST Dir(0.1) (test) | Test Accuracy | 80.32 | 41 |
| Image Classification | CIFAR-100 Dir-0.5 | Accuracy | 12.63 | 37 |
| Image Classification | EMNIST Dir(0.5) (test) | Test Accuracy | 63.87 | 31 |
| Image Classification | MNIST (Dir(0.5)) | Accuracy | 0.8048 | 19 |
| Image Classification | MEDMNISTA (Dir(0.1)) | Accuracy | 67.47 | 13 |
| Image Classification | MEDMNISTA Dir(0.5) | Accuracy | 41.14 | 13 |
| Image Classification | MEDMNISTC (Dir(0.1)) | Accuracy | 66.79 | 13 |
| Image Classification | MEDMNISTC (Dir(0.5)) | Accuracy | 41.27 | 13 |