
RETRIEVE: Coreset Selection for Efficient and Robust Semi-Supervised Learning

About

Semi-supervised learning (SSL) algorithms have had great success in recent years in limited labeled data regimes. However, current state-of-the-art SSL algorithms are computationally expensive and entail significant compute time and energy requirements. This can be a major limitation for many smaller companies and academic groups. Our main insight is that training on a subset of the unlabeled data, instead of the entire unlabeled set, enables current SSL algorithms to converge faster, significantly reducing computational costs. In this work, we propose RETRIEVE, a coreset selection framework for efficient and robust semi-supervised learning. RETRIEVE selects the coreset by solving a mixed discrete-continuous bi-level optimization problem such that the selected coreset minimizes the labeled set loss. We use a one-step gradient approximation and show that the resulting discrete optimization problem is approximately submodular, which enables simple greedy algorithms to obtain the coreset. We empirically demonstrate on several real-world datasets that existing SSL algorithms such as VAT, Mean-Teacher, and FixMatch, when used with RETRIEVE, achieve (a) faster training times and (b) better performance when the unlabeled data contains Out-of-Distribution (OOD) data and imbalance. More specifically, we show that with minimal accuracy degradation, RETRIEVE achieves a speedup of around $3\times$ in the traditional SSL setting and a speedup of $5\times$ compared to state-of-the-art (SOTA) robust SSL algorithms in the case of imbalance and OOD data. RETRIEVE is available as part of the CORDS toolkit: https://github.com/decile-team/cords.
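The selection idea described above (a one-step gradient approximation followed by greedy selection against the labeled-set loss) can be illustrated with a toy sketch. This is not the CORDS API or the authors' implementation: the function name `greedy_coreset`, its parameters, and the quadratic labeled loss in the usage example are all hypothetical, and the per-example unlabeled gradients are assumed to be precomputed.

```python
import numpy as np

def greedy_coreset(theta, unl_grads, labeled_loss, k, lr=0.1):
    """Toy greedy selection (hypothetical helper, not the CORDS API).

    theta        : (d,) current model parameters
    unl_grads    : (n, d) per-example gradients of the unlabeled loss at theta
    labeled_loss : callable mapping a (d,) parameter vector to a scalar
    k            : number of unlabeled examples to select
    lr           : step size used in the one-step approximation
    """
    theta = np.asarray(theta, dtype=float)
    unl_grads = np.asarray(unl_grads, dtype=float)
    selected = []
    grad_sum = np.zeros_like(theta)
    remaining = set(range(len(unl_grads)))
    for _ in range(min(k, len(unl_grads))):
        best_i, best_loss = None, np.inf
        for i in sorted(remaining):
            # One-step approximation: replace the inner training problem
            # with a single gradient step on the candidate subset.
            candidate = theta - lr * (grad_sum + unl_grads[i])
            loss = labeled_loss(candidate)
            if loss < best_loss:
                best_i, best_loss = i, loss
        # Greedy step: keep the example that lowers the labeled loss most.
        selected.append(best_i)
        grad_sum += unl_grads[best_i]
        remaining.remove(best_i)
    return selected

# Usage with an assumed quadratic labeled loss centered at (1, 0).
target = np.array([1.0, 0.0])
labeled_loss = lambda p: float(np.sum((p - target) ** 2))
grads = [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [-1.0, 0.1]]
picked = greedy_coreset(np.zeros(2), grads, labeled_loss, k=2, lr=0.5)
print(picked)  # [0, 3]
```

In this toy setup each greedy step re-evaluates the labeled loss for every remaining candidate, which is only practical for small problems; it is meant to show the shape of the objective, not an efficient selection routine.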

Krishnateja Killamsetty, Xujiang Zhao, Feng Chen, Rishabh Iyer • 2021

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | CIFAR-100 (test) | Accuracy | 31.2 | 3518 |
| Image Classification | CIFAR-10 (test) | Accuracy | 53.8 | 3381 |
| Image Classification | FashionMNIST (test) | -- | -- | 218 |
| Image Classification | F-MNIST (test) | Accuracy | 60.4 | 64 |
| Image Classification | ImageNet-10 (test) | Accuracy | 95.7 | 42 |
| Image Classification | ImageNet-50 (test) | Test Accuracy | 27.7 | 39 |
| Image Classification | ImageNet 10/50-class | Accuracy | 59.3 | 8 |
| Semi-Supervised Learning | CIFAR10 OOD (test) | Top-1 Acc (25%) | 0.7926 | 5 |
| Semi-Supervised Learning | MNIST OOD (test) | Top-1 Acc (25%) | 97.3 | 5 |
| Semi-Supervised Learning | CIFAR10 Imbalance (test) | Top-1 Acc (10%) | 66.88 | 5 |
