
Test-Time Distribution Normalization for Contrastively Learned Vision-language Models

About

Advances in vision-language contrastive learning have made it possible for many downstream applications to be carried out efficiently and accurately by simply taking the dot product between image and text representations. One of the most representative recent approaches, CLIP, has garnered widespread adoption due to its effectiveness. CLIP is trained with an InfoNCE loss that takes into account both positive and negative samples to help learn a much more robust representation space. This paper reveals that the common downstream practice of taking a dot product is only a zeroth-order approximation of the optimization goal, resulting in a loss of information at test time. Intuitively, since the model has been optimized based on the InfoNCE loss, test-time procedures should also be in alignment with it. The question lies in how one can retrieve any semblance of negative-sample information during inference in a computationally efficient way. To this end, we propose Distribution Normalization (DN), where we approximate the mean representation of a batch of test samples and use that mean to play the role of the negative samples in the InfoNCE loss. DN requires no retraining or fine-tuning and can be effortlessly applied during inference. Extensive experiments on a wide variety of downstream tasks exhibit a clear advantage of DN over the dot product, on top of other existing test-time augmentation methods.
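The abstract's recipe (estimate a mean embedding per modality from a batch of test samples, then score pairs with mean-shifted features instead of a raw dot product) can be sketched as follows. This is a minimal illustration, not the authors' released code: the function name `distribution_normalize` and the scaling factor `lam=0.5` (a first-order-correction choice) are assumptions for the example.

```python
import numpy as np

def distribution_normalize(img_feats, txt_feats, lam=0.5):
    """Score image-text pairs with a Distribution Normalization-style shift.

    Rather than the plain dot product between L2-normalized embeddings,
    subtract a scaled estimate of each modality's mean embedding
    (computed over a batch of test samples) before taking the dot
    product. The mean acts as a cheap stand-in for negative-sample
    statistics from the InfoNCE objective.
    """
    # L2-normalize, as CLIP does before comparing embeddings
    img = img_feats / np.linalg.norm(img_feats, axis=-1, keepdims=True)
    txt = txt_feats / np.linalg.norm(txt_feats, axis=-1, keepdims=True)
    # Batch means approximate the test-time embedding distribution
    img_mu = img.mean(axis=0, keepdims=True)
    txt_mu = txt.mean(axis=0, keepdims=True)
    # Similarity between mean-shifted representations; lam=0 recovers
    # the ordinary dot product between normalized features
    return (img - lam * img_mu) @ (txt - lam * txt_mu).T
```

With `lam=0` this reduces exactly to the zeroth-order dot-product scoring the abstract critiques, which makes the shift easy to A/B test in an existing retrieval or zero-shot classification pipeline.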

Yifei Zhou, Juntao Ren, Fengyu Li, Ramin Zabih, Ser-Nam Lim • 2023

Related benchmarks

Task                 | Dataset                                                  | Metric       | Result | Rank
Image Classification | EuroSAT                                                  | Accuracy     | 53.3   | 497
Image Classification | DTD                                                      | Accuracy     | 45.7   | 487
Image Classification | UCF101                                                   | Top-1 Acc    | 68.4   | 404
Classification       | Cars                                                     | Accuracy     | 64     | 314
Image Classification | CUB                                                      | Accuracy     | 56.1   | 249
Image Classification | FGVCAircraft                                             | Accuracy     | 24.3   | 225
Image Classification | Pets                                                     | Accuracy     | 87.7   | 204
Image Classification | Flowers                                                  | Accuracy     | 68     | 127
Image Classification | 11 Downstream Classification Datasets (ImageNet, Flowers102, DTD, OxfordPets, StanfordCars, UCF101, Caltech101, Food101, SUN397, FGVC-Aircraft, EuroSAT), standard (test) | DTD Accuracy | 41.21  | 115
Image Classification | Caltech                                                  | Accuracy     | 93.6   | 98
Showing 10 of 18 rows

Other info

Code
