RETRIEVE: Coreset Selection for Efficient and Robust Semi-Supervised Learning
About
Semi-supervised learning (SSL) algorithms have had great success in recent years in limited labeled data regimes. However, current state-of-the-art SSL algorithms are computationally expensive, with significant compute time and energy requirements. This can be a major limitation for many smaller companies and academic groups. Our main insight is that training on a subset of the unlabeled data, rather than all of it, enables current SSL algorithms to converge faster, significantly reducing computational costs. In this work, we propose RETRIEVE, a coreset selection framework for efficient and robust semi-supervised learning. RETRIEVE selects the coreset by solving a mixed discrete-continuous bi-level optimization problem such that the selected coreset minimizes the labeled set loss. We use a one-step gradient approximation and show that the discrete optimization problem is approximately submodular, enabling simple greedy algorithms to obtain the coreset. We empirically demonstrate on several real-world datasets that existing SSL algorithms such as VAT, Mean-Teacher, and FixMatch, when used with RETRIEVE, achieve (a) faster training times and (b) better performance when the unlabeled data contains Out-of-Distribution (OOD) data or is imbalanced. More specifically, we show that with minimal accuracy degradation, RETRIEVE achieves a speedup of around $3\times$ in the traditional SSL setting and a speedup of $5\times$ compared to state-of-the-art (SOTA) robust SSL algorithms in the case of imbalance and OOD data. RETRIEVE is available as part of the CORDS toolkit: https://github.com/decile-team/cords.
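To make the selection idea concrete, here is a minimal NumPy sketch of greedy coreset selection under a one-step gradient approximation. It is an illustration under stated assumptions, not the authors' implementation and not the CORDS API: the function name `greedy_coreset`, its parameters, and the toy quadratic labeled loss are all hypothetical, and the per-sample unlabeled-loss gradients are treated as fixed inputs. At each step it adds the unlabeled sample whose inclusion in a single gradient step yields the lowest labeled-set loss.

```python
import numpy as np

def greedy_coreset(theta, unlabeled_grads, labeled_loss, budget, lr=0.5):
    """Greedily pick `budget` unlabeled indices (hypothetical helper).

    theta           : current parameters, shape (d,)
    unlabeled_grads : per-sample unlabeled-loss gradients, shape (n, d)
    labeled_loss    : callable mapping parameters -> scalar labeled-set loss
    """
    selected = []
    grad_sum = np.zeros_like(theta)          # sum of gradients of chosen samples
    remaining = set(range(len(unlabeled_grads)))
    for _ in range(budget):
        best_j, best_loss = None, np.inf
        for j in sorted(remaining):
            # One-step approximation: parameters after a single gradient
            # step on the candidate subset (selected so far plus sample j).
            stepped = theta - lr * (grad_sum + unlabeled_grads[j])
            loss = labeled_loss(stepped)
            if loss < best_loss:
                best_j, best_loss = j, loss
        selected.append(best_j)
        grad_sum += unlabeled_grads[best_j]
        remaining.remove(best_j)
    return selected

# Toy setup (assumed for illustration): the labeled loss is the squared
# distance of the parameters to a target point.
target = np.array([1.0, 0.0])
labeled_loss = lambda p: float(np.sum((p - target) ** 2))
grads = np.array([[-1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [-0.5, 0.5]])

selected = greedy_coreset(np.zeros(2), grads, labeled_loss, budget=2)
print(selected)  # [0, 3]
```

This sketch re-evaluates the labeled loss for every candidate, which costs one loss evaluation per remaining sample per greedy step. That is acceptable for a toy example but would need a cheaper gain estimate at realistic dataset sizes.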
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | CIFAR-100 (test) | Accuracy | 31.2 | 3518 |
| Image Classification | CIFAR-10 (test) | Accuracy | 53.8 | 3381 |
| Image Classification | FashionMNIST (test) | -- | -- | 218 |
| Image Classification | F-MNIST (test) | Accuracy | 60.4 | 64 |
| Image Classification | ImageNet-10 (test) | Accuracy | 95.7 | 42 |
| Image Classification | ImageNet-50 (test) | Test Accuracy | 27.7 | 39 |
| Image Classification | ImageNet 10/50-class | Accuracy | 59.3 | 8 |
| Semi-Supervised Learning | CIFAR10 OOD (test) | Top-1 Acc (25%) | 0.7926 | 5 |
| Semi-Supervised Learning | MNIST OOD (test) | Top-1 Acc (25%) | 97.3 | 5 |
| Semi-Supervised Learning | CIFAR10 Imbalance (test) | Top-1 Acc (10%) | 66.88 | 5 |