A Geometric Perspective towards Neural Calibration via Sensitivity Decomposition

About

It is well known that vision classification models suffer from poor calibration in the face of data distribution shifts. In this paper, we take a geometric approach to this problem. We propose Geometric Sensitivity Decomposition (GSD) which decomposes the norm of a sample feature embedding and the angular similarity to a target classifier into an instance-dependent and an instance-independent component. The instance-dependent component captures the sensitive information about changes in the input while the instance-independent component represents the insensitive information serving solely to minimize the loss on the training dataset. Inspired by the decomposition, we analytically derive a simple extension to current softmax-linear models, which learns to disentangle the two components during training. On several common vision models, the disentangled model outperforms other calibration methods on standard calibration metrics in the face of out-of-distribution (OOD) data and corruption with significantly less complexity. Specifically, we surpass the current state of the art by 30.8% relative improvement on corrupted CIFAR100 in Expected Calibration Error. Code available at https://github.com/GT-RIPL/Geometric-Sensitivity-Decomposition.git.
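The abstract does not give the GSD formulas, so the following is only a minimal sketch of the two geometric quantities it names: the norm of a feature embedding and its angular similarity to a class weight vector. It relies on the standard identity w·x = ‖w‖ ‖x‖ cos θ for a bias-free linear logit and does not implement the instance-dependent/instance-independent split itself. The vectors `x` and `w` and all function names are hypothetical.

```python
import math

def norm(v):
    """Euclidean norm of a vector given as a list of floats."""
    return math.sqrt(sum(a * a for a in v))

def cosine(u, v):
    """Cosine of the angle between vectors u and v."""
    return sum(a * b for a, b in zip(u, v)) / (norm(u) * norm(v))

def logit_from_geometry(x, w):
    """Rebuild the bias-free linear logit w.x from the two norms and the angle."""
    return norm(w) * norm(x) * cosine(x, w)

# Hypothetical 3-dimensional embedding and class weight vector (illustration only).
x = [1.0, 2.0, 2.0]
w = [2.0, 0.0, 1.0]

dot = sum(a * b for a, b in zip(x, w))   # ordinary linear logit
rebuilt = logit_from_geometry(x, w)      # same value, up to rounding, from norm and angle
```

Scaling `x` changes its norm but leaves the cosine term untouched, which is why the two factors can be examined separately.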

Junjiao Tian, Dylan Yung, Yen-Chang Hsu, Zsolt Kira • 2021

Related benchmarks

| Task | Dataset | Result | Rank |
| Out-of-Distribution Detection | CIFAR-10 (ID) vs SVHN (OOD) (test) | AUROC 99.05 | 79 |
| OOD Detection | CIFAR-100 IND SVHN OOD | AUROC (%) 94.46 | 74 |
| Image Classification | CIFAR100 Clean (test) | Accuracy 83.09 | 38 |
| Calibration | CIFAR10 Noise levels 1-5 (val) | NLL 0.531 | 20 |
| Image Classification | CIFAR100 Noise (val) | Brier Score 0.003 | 20 |
| Image Classification | CIFAR10 Noise (val) | Brier Score 0.022 | 20 |
| Calibration | CIFAR100 Noise levels 1-5 (val) | NLL 1.786 | 20 |
| Image Classification | CIFAR10 Corrupted | Accuracy 77.9 | 20 |
| OOD Detection | CIFAR-100 vs SVHN (test) | -- | 18 |
| Image Classification | CIFAR100 Corrupted (test) | Accuracy 54.1 | 16 |
Showing 10 of 15 rows
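For readers unfamiliar with the NLL and Brier Score entries above, here is a minimal sketch of how these two metrics are commonly computed from predicted class probabilities. The page does not state which Brier normalization it uses; averaging over classes as well as samples is an assumption made here, and the toy `probs` and `labels` are hypothetical. Lower is better for both metrics.

```python
import math

def nll(probs, labels):
    """Mean negative log-likelihood of the true class over all samples."""
    return -sum(math.log(p[y]) for p, y in zip(probs, labels)) / len(labels)

def brier(probs, labels):
    """Squared error against one-hot labels, averaged over samples and classes.

    Assumption: per-class averaging; some definitions sum over classes instead.
    """
    total = 0.0
    for p, y in zip(probs, labels):
        sq_err = sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
        total += sq_err / len(p)
    return total / len(labels)

# Hypothetical two-class predictions for two samples, both with true class 0.
probs = [[0.5, 0.5], [0.9, 0.1]]
labels = [0, 0]

nll_value = nll(probs, labels)
brier_value = brier(probs, labels)
```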
