FedHPro: Federated Hyper-Prototype Learning via Gradient Matching
About
Federated Learning (FL) enables collaborative training of distributed clients while protecting privacy. To enhance generalization capability in FL, prototype-based FL is in the spotlight, since shared global prototypes offer semantic anchors for aligning client-specific local prototypes. However, existing methods update global prototypes at the prototype level, either by averaging local prototypes or by refining global anchors, which often leads to semantic drift across clients and subsequently yields a misaligned global signal. To alleviate this issue, we introduce hyper-prototypes, defined as a set of learnable global class-wise prototypes that preserve underlying semantic knowledge across clients. The hyper-prototypes are optimized via gradient matching to align with class-relevant characteristics distilled directly from clients' real samples, rather than from prototype-level descriptors. We further propose FedHPro, a Federated Hyper-Prototype Learning framework, which leverages hyper-prototypes to promote inter-class separability via mutual-contrastive learning with a client-specific margin, while encouraging intra-class uniformity through a consistency penalty. Comprehensive experiments under diverse heterogeneous scenarios confirm that 1) hyper-prototypes produce a more semantically consistent global signal, and 2) FedHPro achieves state-of-the-art performance on several benchmark datasets. Code is available at https://github.com/mala-lab/FedHPro.
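To make the gradient-matching idea concrete, here is a minimal NumPy sketch. It is not the FedHPro implementation. It assumes a linear softmax classifier, a single class, and a squared Frobenius distance between gradients; the paper may use a different model and distance. All names (`grad_wrt_weights`, `matching_loss_and_grad`, `X_real`) are hypothetical. A learnable prototype `p` is updated so that the classifier gradient it induces approaches the mean gradient induced by real samples of the same class.

```python
import numpy as np

def softmax(z):
    z = z - z.max()            # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def grad_wrt_weights(W, x, label):
    """Cross-entropy gradient w.r.t. W for one input x (linear softmax model)."""
    r = softmax(W @ x)
    r[label] -= 1.0            # softmax output minus one-hot target
    return np.outer(r, x)

def matching_loss_and_grad(W, p, label, g_real):
    """Squared Frobenius distance between the prototype gradient and g_real,
    together with its analytic gradient w.r.t. the prototype p."""
    s = softmax(W @ p)
    r = s.copy()
    r[label] -= 1.0
    diff = np.outer(r, p) - g_real
    loss = np.sum(diff ** 2)
    D = 2.0 * diff                             # d(loss) / d(prototype gradient)
    J = np.diag(s) - np.outer(s, s)            # softmax Jacobian (symmetric)
    grad_p = D.T @ r + W.T @ (J @ (D @ p))     # chain rule through r and p
    return loss, grad_p

# Toy setup (assumed values, for illustration only).
rng = np.random.default_rng(0)
num_classes, dim, label = 3, 4, 1
W = rng.normal(size=(num_classes, dim))
X_real = rng.normal(loc=1.0, size=(32, dim))   # stand-in for one client's real features
g_real = np.mean([grad_wrt_weights(W, x, label) for x in X_real], axis=0)

p = np.zeros(dim)                              # learnable prototype for class `label`
initial_loss, _ = matching_loss_and_grad(W, p, label, g_real)

step = 0.1
for _ in range(200):
    loss, grad_p = matching_loss_and_grad(W, p, label, g_real)
    # Backtracking: halve the step until the matching loss does not increase.
    while step > 1e-8:
        candidate = p - step * grad_p
        if matching_loss_and_grad(W, candidate, label, g_real)[0] <= loss:
            p = candidate
            break
        step *= 0.5

final_loss, _ = matching_loss_and_grad(W, p, label, g_real)
print(f"matching loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

In this sketch only the averaged gradient `g_real`, not the raw samples, enters the prototype update. This is meant to illustrate how a prototype can be driven by sample-level signals rather than by averaging local prototypes.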
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Image Classification | CIFAR-10 long-tailed (test) | -- | 211 |
| Image Classification | Office-Caltech | Average Accuracy: 0.6452 | 44 |
| Image Classification | SUN397 | Accuracy: 73.41 | 28 |
| Image Classification | CIFAR10 (Non-IID 2) | Accuracy: 79.7 | 17 |
| Image Classification | Tiny-ImageNet Non-IID 2 | Accuracy: 40.52 | 13 |
| Federated Image Classification | Digits | Accuracy (MNIST): 98.52 | 9 |
| Federated Image Classification | Office-Caltech | Accuracy (Caltech): 64.61 | 9 |
| Image Classification | CIFAR10 NID1, α=0.2 | Accuracy: 85.98 | 9 |
| Image Classification | CIFAR10 NID1, α=0.5 | Accuracy: 89.56 | 9 |
| Image Classification | HAM10000 NID1, α=0.2 | Accuracy: 50.23 | 9 |