
Inducing Neural Collapse in Imbalanced Learning: Do We Really Need a Learnable Classifier at the End of Deep Neural Network?

About

Modern deep neural networks for classification usually jointly learn a backbone for representation and a linear classifier that outputs the logit of each class. A recent study identified a phenomenon called neural collapse: at the terminal phase of training on a balanced dataset, the within-class means of the features and the classifier vectors converge to the vertices of a simplex equiangular tight frame (ETF). Since the ETF geometric structure maximally separates the pair-wise angles of all classes in the classifier, it is natural to ask: why spend effort learning a classifier when we already know its optimal geometric structure? In this paper, we study the potential of learning a neural network for classification with the classifier randomly initialized as an ETF and fixed during training. Our analysis, based on the layer-peeled model, indicates that feature learning with a fixed ETF classifier naturally leads to the neural collapse state even when the dataset is imbalanced among classes. We further show that in this case the cross entropy (CE) loss is not necessary and can be replaced by a simple squared loss that shares the same global optimality but enjoys a better convergence property. Our experimental results show that our method brings significant improvements with faster convergence on multiple imbalanced datasets.
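
The fixed classifier mentioned in the abstract can be illustrated with a short NumPy sketch. It is not the authors' code. It assumes the standard simplex ETF construction M = sqrt(K/(K-1)) · U · (I_K − (1/K)·1·1ᵀ), where U has orthonormal columns. The function name `simplex_etf`, its parameters, and the restriction `feat_dim >= num_classes` are assumptions made for this illustration.

```python
import numpy as np


def simplex_etf(num_classes: int, feat_dim: int, seed: int = 0) -> np.ndarray:
    """Return a (feat_dim, num_classes) matrix whose columns form a simplex ETF.

    Hypothetical helper for illustration; assumes feat_dim >= num_classes.
    """
    if num_classes < 2:
        raise ValueError("need at least two classes")
    if feat_dim < num_classes:
        raise ValueError("this sketch assumes feat_dim >= num_classes")

    rng = np.random.default_rng(seed)
    # Reduced QR of a random Gaussian matrix gives U with orthonormal
    # columns, i.e. U^T U = I_K (a random rotation of the frame).
    u, _ = np.linalg.qr(rng.standard_normal((feat_dim, num_classes)))

    k = num_classes
    # Centering matrix I_K - (1/K) 1 1^T removes the mean of the columns.
    center = np.eye(k) - np.ones((k, k)) / k
    return np.sqrt(k / (k - 1)) * (u @ center)


# Usage: build classifier vectors for 10 classes in a 64-dimensional space.
W = simplex_etf(num_classes=10, feat_dim=64)
gram = W.T @ W  # pairwise inner products between classifier vectors
```

Under this construction every column has unit norm and every pair of distinct columns has inner product −1/(K−1), which is the maximal equiangular separation the abstract refers to. In a training setup that follows the abstract, these columns would be kept frozen and only the backbone would be updated.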

Yibo Yang, Shixiang Chen, Xiangtai Li, Liang Xie, Zhouchen Lin, Dacheng Tao • 2022

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Image Classification | CIFAR-10 | Accuracy: 94.4 | 875 |
| Image Classification | Food-101 | -- | 570 |
| Fine-grained Image Classification | CUB200 2011 (test) | Accuracy: 87 | 567 |
| Image Classification | CIFAR-100 | Accuracy: 72.1 | 357 |
| Long-Tailed Image Classification | ImageNet-LT (test) | Top-1 Acc (Overall): 44.7 | 246 |
| Image Classification | Oxford Flowers 102 | -- | 234 |
| Image Classification | ImageNet-100 | Accuracy: 84.5 | 163 |
| Image Classification | Stanford Cars | Top-1 Accuracy: 9.8 | 104 |
| Image Classification | DTD (Describable Textures Dataset) | -- | 80 |
| Image Classification | CIFAR-10-LT (IF 50) | Top-1 Accuracy: 81 | 75 |
Showing 10 of 35 rows
