
Scaling Up Dataset Distillation to ImageNet-1K with Constant Memory

About

Dataset distillation is a newly emerging area that aims to distill large datasets into much smaller, highly informative synthetic ones in order to accelerate training and reduce storage. Among the various dataset distillation methods, trajectory-matching-based methods (MTT) have achieved SOTA performance on many tasks, e.g., on CIFAR-10/100. However, due to the exorbitant memory consumption of unrolling optimization through SGD steps, MTT fails to scale to large datasets such as ImageNet-1K. Can we scale this SOTA method to ImageNet-1K, and does its effectiveness on CIFAR transfer to ImageNet-1K? To answer these questions, we first propose a procedure to exactly compute the unrolled gradient with constant memory complexity, which allows us to scale MTT to ImageNet-1K seamlessly with a ~6x reduction in memory footprint. We further discover that MTT struggles on datasets with a large number of classes, and propose a novel soft label assignment that drastically improves its convergence. The resulting algorithm sets a new SOTA on ImageNet-1K: we can scale up to 50 IPC (Images Per Class) on ImageNet-1K on a single GPU (all previous methods scale only to 2 IPC on ImageNet-1K), yielding the best accuracy (only a 5.9% accuracy drop against full-dataset training) while using only 4.2% of the data points - an 18.2% absolute gain over the prior SOTA. Our code is available at https://github.com/justincui03/tesla
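To give intuition for the constant-memory idea (this is a toy illustration, not the paper's actual algorithm): naive reverse-mode backpropagation through N unrolled SGD steps must store all N intermediate parameter states, so memory grows with N. For a 1-D model and a single scalar synthetic point, the same exact gradient can instead be obtained by propagating the derivative of the state alongside the forward pass, keeping memory constant in N. All names and values below are illustrative.

```python
# Toy sketch (hypothetical setup, not the paper's method):
# Inner objective on a synthetic scalar x:  l(theta) = 0.5 * (theta - x)**2
# Inner SGD step:                           theta <- theta - alpha * (theta - x)
# Matching loss after N steps:              L = (theta_N - theta_star)**2
# Instead of storing all N states for reverse-mode autodiff (O(N) memory),
# we carry d(theta)/dx through the loop, so memory stays O(1) in N.

def unrolled_loss_and_grad(x, theta0, theta_star, alpha=0.1, steps=50):
    theta, dtheta_dx = theta0, 0.0   # state and its derivative w.r.t. x
    for _ in range(steps):
        # Differentiate the update theta <- (1 - alpha) * theta + alpha * x
        dtheta_dx = (1.0 - alpha) * dtheta_dx + alpha
        theta = (1.0 - alpha) * theta + alpha * x
    loss = (theta - theta_star) ** 2
    grad = 2.0 * (theta - theta_star) * dtheta_dx   # chain rule through theta_N
    return loss, grad

# Sanity check against central finite differences.
x, th0, th_star, eps = 0.3, 1.0, 0.5, 1e-6
lp, _ = unrolled_loss_and_grad(x + eps, th0, th_star)
lm, _ = unrolled_loss_and_grad(x - eps, th0, th_star)
_, g = unrolled_loss_and_grad(x, th0, th_star)
print(abs(g - (lp - lm) / (2 * eps)) < 1e-6)  # True: exact gradient, O(1) memory
```

The paper achieves the analogous result for the full high-dimensional trajectory-matching loss by decomposing the unrolled gradient into per-step contributions that can be accumulated without retaining the whole computation graph.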

Justin Cui, Ruochen Wang, Si Si, Cho-Jui Hsieh • 2022

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Image Classification | CIFAR-100 (test) | Accuracy | 47.9 | 3518 |
| Image Classification | CIFAR-10 (test) | Accuracy | 72.6 | 3381 |
| Image Classification | ImageNet-1K | Top-1 Acc | 7.7 | 836 |
| Image Classification | CIFAR-100 (val) | -- | -- | 661 |
| Image Classification | ImageNet-1k (val) | Top-1 Accuracy | 27.9 | 512 |
| Image Classification | CIFAR-10 (val) | Top-1 Accuracy | 72.6 | 329 |
| Dataset Condensation | CIFAR-100 (train) | Training Speed (s/iter) | 5.71 | 8 |
| Dataset Condensation | Tiny-ImageNet (train) | Training Speed (s/iter) | 42.01 | 5 |
