Inducing Neural Collapse in Imbalanced Learning: Do We Really Need a Learnable Classifier at the End of Deep Neural Network?

About

Modern deep neural networks for classification usually jointly learn a backbone for representation and a linear classifier that outputs the logit of each class. A recent study identified a phenomenon called neural collapse: at the terminal phase of training on a balanced dataset, the within-class means of features and the classifier vectors converge to the vertices of a simplex equiangular tight frame (ETF). Since the ETF geometric structure maximally separates the pair-wise angles of all classes in the classifier, it is natural to ask: why spend effort learning a classifier when we already know its optimal geometric structure? In this paper, we study the potential of learning a neural network for classification with the classifier randomly initialized as an ETF and fixed during training. Our analysis based on the layer-peeled model indicates that feature learning with a fixed ETF classifier naturally leads to the neural collapse state even when the dataset is imbalanced among classes. We further show that in this case the cross-entropy (CE) loss is not necessary and can be replaced by a simple squared loss that shares the same global optimality but enjoys a better convergence property. Our experimental results show that our method brings significant improvements with faster convergence on multiple imbalanced datasets.
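
To make the fixed-classifier idea concrete, here is a minimal NumPy sketch of one standard way to build a simplex ETF: W = sqrt(K/(K-1)) · U · (I_K - (1/K)·1·1^T), where U has orthonormal columns. The function names `simplex_etf` and `fixed_classifier_logits`, the seed handling, and the requirement that the feature dimension be at least the number of classes are assumptions made for this illustration, not details taken from the paper's code.

```python
import numpy as np


def simplex_etf(num_classes: int, feat_dim: int, seed: int = 0) -> np.ndarray:
    """Return a (feat_dim, num_classes) matrix whose columns form a simplex ETF.

    Illustrative construction only; assumes feat_dim >= num_classes >= 2.
    """
    if num_classes < 2 or feat_dim < num_classes:
        raise ValueError("need feat_dim >= num_classes >= 2 for this construction")
    rng = np.random.default_rng(seed)
    # Reduced QR of a random Gaussian matrix gives U with orthonormal columns.
    u, _ = np.linalg.qr(rng.standard_normal((feat_dim, num_classes)))
    k = num_classes
    # Centering matrix I - (1/K) * 1 1^T removes the mean direction.
    centering = np.eye(k) - np.ones((k, k)) / k
    return np.sqrt(k / (k - 1)) * (u @ centering)


def fixed_classifier_logits(features: np.ndarray, etf: np.ndarray) -> np.ndarray:
    """Compute logits with the ETF held fixed; only `features` would be learned."""
    return features @ etf
```

Each column of the returned matrix has unit norm, and every pair of distinct columns has cosine similarity -1/(K-1), which is the maximal equiangular separation the abstract refers to. In a training loop, this matrix would be created once and excluded from the optimizer's parameters.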

Yibo Yang, Shixiang Chen, Xiangtai Li, Liang Xie, Zhouchen Lin, Dacheng Tao • 2022

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Fine-grained Image Classification | CUB200 2011 (test) | Accuracy | 87 | 536 |
| Long-Tailed Image Classification | ImageNet-LT 1.0 (test) | Top-1 Accuracy (Overall) | 44.7 | 21 |
| Long-tailed classification | CIFAR-10 ratio 0.005 | Accuracy | 72.9 | 4 |
| Long-tailed classification | CIFAR-10 ratio 0.01 | Accuracy | 78.5 | 4 |
| Long-tailed classification | CIFAR-100 ratio 0.005 | Accuracy | 40.9 | 4 |
| Long-tailed classification | CIFAR-100 ratio 0.01 | Accuracy | 45.3 | 4 |
| Long-tailed classification | CIFAR-100 ratio 0.02 | Accuracy | 50.4 | 4 |
| Long-tailed classification | SVHN ratio 0.005 | Accuracy (%) | 42.8 | 4 |
| Long-tailed classification | SVHN ratio 0.01 | Accuracy | 45.7 | 4 |
| Long-tailed classification | SVHN ratio 0.02 | Accuracy | 49.8 | 4 |
Showing 10 of 14 rows
