AdaFocal: Calibration-aware Adaptive Focal Loss

About

Much recent work has been devoted to the problem of ensuring that a neural network's confidence scores match the true probability of being correct, i.e. the calibration problem. Of note, it was found that training with focal loss leads to better calibration than cross-entropy while achieving a similar level of accuracy \cite{mukhoti2020}. This success stems from focal loss regularizing the entropy of the model's prediction (controlled by the parameter $\gamma$), thereby reining in the model's overconfidence. Further improvement is expected if $\gamma$ is selected independently for each training sample, as in Sample-Dependent Focal Loss (FLSD-53) \cite{mukhoti2020}. However, FLSD-53 is based on heuristics and does not generalize well. In this paper, we propose a calibration-aware adaptive focal loss called AdaFocal that utilizes the calibration properties of focal (and inverse-focal) loss and adaptively modifies $\gamma_t$ for different groups of samples based on $\gamma_{t-1}$ from the previous step and knowledge of the model's under/over-confidence on the validation set. We evaluate AdaFocal on several image recognition tasks and one NLP task, covering a wide variety of network architectures, to confirm the improvement in calibration while achieving similar levels of accuracy. Additionally, we show that models trained with AdaFocal achieve a significant boost in out-of-distribution detection.
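
To make the role of $\gamma$ concrete, the following is a minimal sketch of the standard per-sample focal loss, $-(1-p)^{\gamma}\log p$, together with one possible way of adapting $\gamma$ from a validation calibration gap. The abstract does not state the update rule, so `update_gamma`, its multiplicative exponential form, and the parameters `lam` and `gamma_max` are illustrative assumptions and should not be read as the authors' method. The inverse-focal branch mentioned in the abstract is omitted.

```python
import math


def focal_loss(p_true: float, gamma: float) -> float:
    """Per-sample focal loss: -(1 - p)^gamma * log(p).

    p_true is the predicted probability of the correct class.
    With gamma = 0 this reduces to ordinary cross-entropy.
    """
    p = min(max(p_true, 1e-12), 1.0)  # clamp to avoid log(0)
    return -((1.0 - p) ** gamma) * math.log(p)


def update_gamma(gamma_prev: float,
                 val_confidence: float,
                 val_accuracy: float,
                 lam: float = 1.0,
                 gamma_max: float = 20.0) -> float:
    """Hypothetical update of gamma for one group of samples.

    ASSUMPTION: the multiplicative form below is illustrative only.
    It raises gamma when the group is over-confident on the validation
    set (confidence > accuracy) and lowers it when under-confident.
    """
    gap = val_confidence - val_accuracy
    return min(gamma_prev * math.exp(lam * gap), gamma_max)


if __name__ == "__main__":
    # A larger gamma down-weights samples the model already gets right.
    for g in (0.0, 1.0, 3.0):
        print(f"gamma={g}: loss at p=0.9 is {focal_loss(0.9, g):.5f}")
    # Over-confident group: gamma increases from its previous value.
    print(update_gamma(1.0, val_confidence=0.9, val_accuracy=0.8))
```

In this sketch a positive gap (over-confidence) increases $\gamma$, which strengthens the entropy regularization described above, while a negative gap relaxes it.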

Arindam Ghosh, Thomas Schaaf, Matthew R. Gormley • 2022

Related benchmarks

Task                              Dataset                  Result                       Rank
Image Classification              CIFAR-10 (test)          --                           410
Image Classification              Tiny ImageNet (test)     --                           362
Image Classification              CIFAR-100 (test)         --                           175
Image Classification Calibration  CIFAR100                 Classwise ECE 0.2            90
Image Classification Calibration  CIFAR10                  Classwise ECE 0.33           84
Model Calibration                 CIFAR-100                ECE 1.26                     81
Model Calibration                 CIFAR-10                 ECE 56                       68
Classwise Calibration             CIFAR-10-LT              Average Classwise ECE 1.31   56
Image Classification              CIFAR-10 LT-100 (test)   Error Rate 28.79             40
Image Classification              CIFAR-100 (test)         Accuracy 77.15               38
Showing 10 of 21 rows
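
For readers unfamiliar with the ECE metric reported above, here is a minimal sketch of the common binned estimate of expected calibration error. The function name, the equal-width binning, and the default of 15 bins are assumptions for illustration. The listing does not say how its ECE values were computed or whether they are fractions or percentages, and the classwise variants in the table are a different (per-class) estimate that is not shown here.

```python
def expected_calibration_error(confidences, correct, n_bins=15):
    """Binned ECE: weighted average of |mean confidence - accuracy| per bin.

    confidences: predicted probability of the chosen class, each in [0, 1].
    correct:     booleans, True where the prediction was right.
    ASSUMPTION: equal-width bins [k/n_bins, (k+1)/n_bins); a confidence of
    exactly 1.0 is placed in the last bin.
    """
    n = len(confidences)
    if n == 0:
        raise ValueError("need at least one prediction")
    bins = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, ok))

    ece = 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, ok in members if ok) / len(members)
        ece += (len(members) / n) * abs(avg_conf - accuracy)
    return ece


if __name__ == "__main__":
    # Four predictions at 0.9 confidence, three of them correct.
    confs = [0.9, 0.9, 0.9, 0.9]
    hits = [True, True, True, False]
    print(expected_calibration_error(confs, hits))
```

A value of zero means that, within every bin, average confidence equals empirical accuracy; larger values indicate over- or under-confidence.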
