
Long-Tailed Classification by Keeping the Good and Removing the Bad Momentum Causal Effect

About

As the number of classes grows, maintaining a balanced dataset across many classes becomes challenging because the data are long-tailed in nature. It is even impossible when the samples of interest co-exist in one collectable unit, e.g., multiple visual instances in one image. Therefore, long-tailed classification is key to deep learning at scale. However, existing methods are mainly based on re-weighting/re-sampling heuristics that lack a fundamental theory. In this paper, we establish a causal inference framework, which not only unravels the whys of previous methods, but also derives a new principled solution. Specifically, our theory shows that the SGD momentum is essentially a confounder in long-tailed classification. On one hand, it has a harmful causal effect that biases tail predictions towards the head. On the other hand, its induced mediation also benefits representation learning and head prediction. Our framework disentangles these paradoxical effects of the momentum by pursuing the direct causal effect caused by an input sample. In particular, we use causal intervention in training and counterfactual reasoning in inference to remove the "bad" while keeping the "good". We achieve new state-of-the-art results on three long-tailed visual recognition benchmarks: Long-tailed CIFAR-10/-100 and ImageNet-LT for image classification, and LVIS for instance segmentation.
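As a rough illustration of "counterfactual reasoning in inference", the NumPy sketch below keeps a moving-average feature vector as a stand-in for the momentum-accumulated head direction d. At test time it subtracts the part of each logit that is explained by the input's projection onto d, leaving the direct effect of the sample. Everything here is an assumption for illustration and is not the authors' released code: the helper names (ema_update, normalized_logits, tde_logits), the K-head split of the feature vector, the softened weight norm (||w|| + gamma), and the default values of tau, gamma, alpha, K and mu.

```python
import numpy as np


def ema_update(mean_feat, batch_feats, mu=0.9):
    # Hypothetical helper: moving average of training features, used here as
    # the momentum-accumulated "head" direction. mu is an assumed decay rate.
    return mu * mean_feat + (1.0 - mu) * batch_feats.mean(axis=0)


def normalized_logits(x, W, tau=16.0, gamma=1.0 / 32, K=2):
    # Assumed training-time classifier: the feature x (D,) and weights W (C, D)
    # are split into K heads; each head gives a cosine-style logit with a
    # softened weight norm (||w|| + gamma). D must be divisible by K.
    out = np.zeros(W.shape[0])
    for xk, wk in zip(np.split(x, K), np.split(W, K, axis=1)):
        out += (wk @ xk) / ((np.linalg.norm(wk, axis=1) + gamma) * np.linalg.norm(xk))
    return tau / K * out


def tde_logits(x, W, mean_feat, alpha=1.0, tau=16.0, gamma=1.0 / 32, K=2):
    # Sketch of counterfactual inference: per head, remove alpha times the
    # component of the normalized feature that lies along the mean direction.
    # With alpha = 0 this reduces to normalized_logits.
    out = np.zeros(W.shape[0])
    heads = zip(np.split(x, K), np.split(W, K, axis=1), np.split(mean_feat, K))
    for xk, wk, dk in heads:
        x_hat = xk / np.linalg.norm(xk)
        d_hat = dk / np.linalg.norm(dk)
        # Counterfactual input: only the projection of x_hat onto d_hat.
        x_cf = (x_hat @ d_hat) * d_hat
        out += (wk @ (x_hat - alpha * x_cf)) / (np.linalg.norm(wk, axis=1) + gamma)
    return tau / K * out
```

In use, ema_update would be called once per training batch to accumulate the mean feature, and tde_logits would replace normalized_logits only at test time. In this sketch, alpha controls how much of the head-biased component is removed: alpha = 0 recovers the plain training-time logits, while a feature lying exactly along the mean direction yields all-zero logits at alpha = 1.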

Kaihua Tang, Jianqiang Huang, Hanwang Zhang • 2020

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Object Detection | LVIS v1.0 (val) | AP (bbox) | 30 | 529 |
| Image Classification | iNaturalist 2018 | Top-1 Accuracy | 64 | 291 |
| Image Classification | ImageNet LT | Top-1 Accuracy | 51.8 | 264 |
| Image Classification | CIFAR-100 Long-Tailed (test) | Top-1 Accuracy | 59.6 | 234 |
| Long-Tailed Image Classification | ImageNet-LT (test) | Top-1 Accuracy (Overall) | 52 | 220 |
| Image Classification | iNaturalist 2018 (test) | Top-1 Accuracy | 69.6 | 207 |
| Image Classification | CIFAR-10 Long-Tailed (test) | Top-1 Accuracy | 88.5 | 201 |
| Instance Segmentation | LVIS v1.0 (val) | AP (Rare) | 16 | 189 |
| Image Classification | ImageNet-LT (test) | Top-1 Accuracy (All) | 52 | 159 |
| Image Classification | CIFAR-100 LT | Top-1 Accuracy | 59.6 | 131 |
Showing 10 of 68 rows
