Dataset Distillation for Pre-Trained Self-Supervised Vision Models
About
The task of dataset distillation aims to find a small set of synthetic images such that training a model on them reproduces the performance of the same model trained on a much larger dataset of real samples. Existing distillation methods focus on synthesizing datasets that enable training randomly initialized models. In contrast, state-of-the-art vision approaches are increasingly building on large, pre-trained self-supervised models rather than training from scratch. In this paper, we investigate the problem of distilling datasets that enable us to optimally train linear probes on top of such large, pre-trained vision models. We introduce a method of dataset distillation for this task called Linear Gradient Matching that optimizes the synthetic images such that, when passed through a pre-trained feature extractor, they induce gradients in the linear classifier similar to those produced by the real data. Our method yields synthetic data that outperform all real-image baselines and, remarkably, generalize across pre-trained vision models, enabling us, for instance, to train a linear CLIP probe that performs competitively using a dataset distilled via a DINO backbone. Further, we show that our distilled datasets are exceptionally effective for fine-grained classification and provide a valuable tool for model interpretability, predicting, among other things, how similar two models' embedding spaces are under the platonic representation hypothesis or whether a model is sensitive to spurious correlations in adversarial datasets.
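The core idea of Linear Gradient Matching, as described above, is to compare the gradient a linear probe receives from real features with the gradient it receives from synthetic features, then update the synthetic images to close that gap. The following is a minimal numpy sketch of that objective under several labeled assumptions. It is not the authors' implementation. The "backbone" is a hypothetical stand-in: a fixed random projection followed by `tanh`. The probe weights `W` are held fixed. The distance is assumed to be 1 minus the cosine similarity of the flattened gradients. The image update uses finite differences with a simple step-halving line search instead of autodiff through a real pre-trained model. All function names (`linear_probe_grad`, `gradient_matching_loss`, `optimize_synthetic`, and so on) are illustrative.

```python
import numpy as np

def softmax(logits):
    # Numerically stable row-wise softmax.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def linear_probe_grad(features, labels, W, num_classes):
    # Gradient of mean softmax cross-entropy w.r.t. linear probe weights W (d, c).
    probs = softmax(features @ W)                          # (n, c)
    onehot = np.eye(num_classes)[labels]                   # (n, c)
    return features.T @ (probs - onehot) / len(labels)     # (d, c)

def frozen_backbone(images, P):
    # Hypothetical stand-in for a frozen pre-trained feature extractor.
    return np.tanh(images @ P)

def gradient_matching_loss(syn_x, syn_y, real_x, real_y, P, W, num_classes):
    g_real = linear_probe_grad(frozen_backbone(real_x, P), real_y, W, num_classes)
    g_syn = linear_probe_grad(frozen_backbone(syn_x, P), syn_y, W, num_classes)
    # Assumed distance: 1 - cosine similarity of the flattened probe gradients.
    cos = np.sum(g_real * g_syn) / (np.linalg.norm(g_real) * np.linalg.norm(g_syn) + 1e-12)
    return 1.0 - cos

def fd_grad(f, x, eps=1e-5):
    # Central finite differences; mutates x temporarily and restores each entry.
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        old = x[idx]
        x[idx] = old + eps
        f_plus = f(x)
        x[idx] = old - eps
        f_minus = f(x)
        x[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * eps)
    return g

def optimize_synthetic(f, x, steps=30, lr=1.0):
    # Gradient descent on the synthetic images; a step is accepted only if the loss drops.
    x = x.copy()
    loss = f(x)
    for _ in range(steps):
        g = fd_grad(f, x)
        step = lr
        while step > 1e-8:
            cand = x - step * g
            cand_loss = f(cand)
            if cand_loss < loss:
                x, loss = cand, cand_loss
                break
            step /= 2
    return x, loss

rng = np.random.default_rng(0)
num_classes, pixels, feat_dim = 4, 12, 8
P = rng.normal(size=(pixels, feat_dim)) / np.sqrt(pixels)   # frozen toy "backbone"
W = rng.normal(size=(feat_dim, num_classes)) * 0.1          # fixed linear probe
real_x = rng.normal(size=(200, pixels))
real_y = rng.integers(0, num_classes, size=200)
syn_y = np.arange(num_classes)                              # one synthetic image per class
syn_x = rng.normal(size=(num_classes, pixels))

def objective(x):
    return gradient_matching_loss(x, syn_y, real_x, real_y, P, W, num_classes)

initial_loss = objective(syn_x)
syn_x, final_loss = optimize_synthetic(objective, syn_x)
print(f"matching loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

With a real backbone, the finite-difference loop would be replaced by automatic differentiation through the frozen feature extractor. Sampling fresh probe weights during optimization is another plausible variant, but this sketch does not assume the paper does so.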
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Image Classification | Waterbirds | Average Accuracy | 79 | 209 |
| Image Classification | Spawrious | Accuracy | 80.8 | 28 |
| Image Classification | CUB-200 1.0 (test) | Average Accuracy | 66.2 | 16 |
| Image Classification | ImageNet-100 1.0 (test) | CLIP Score | 84.9 | 8 |
| Image Classification | Stanford Dogs 1.0 (test) | CLIP Score | 52.1 | 8 |
| Dataset Distillation | ImageNet 100 IPC=1 (train val) | DINOv2 Score | 91.4 | 7 |
| Dataset Distillation | ImageNet-1K IPC=1 (train val) | DINOv2 Accuracy | 75 | 6 |
| Dataset Distillation | ImageNet-100 | Time (s) | 4.25 | 3 |