A Learning Based Hypothesis Test for Harmful Covariate Shift

About

The ability to quickly and accurately identify covariate shift at test time is a critical and often overlooked component of safe machine learning systems deployed in high-risk domains. Methods exist for detecting when predictions should not be made on individual out-of-distribution test examples. Identifying distribution-level differences between training and test time, by contrast, can help determine when a model should be removed from deployment and retrained. In this work, we define harmful covariate shift (HCS) as a change in distribution that may weaken the generalization of a predictive model. To detect HCS, we use the discordance within an ensemble of classifiers trained to agree on training data and disagree on test data. We derive a loss function for training this ensemble and show that the disagreement rate and entropy are powerful discriminative statistics for HCS. Empirically, we demonstrate that our method detects harmful covariate shift with statistical significance on a variety of high-dimensional datasets. Across numerous domains and modalities, we show state-of-the-art performance compared to existing methods, particularly when the number of observed test samples is small.
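The abstract names the ensemble's disagreement rate and entropy as the test statistics for HCS. The sketch below shows one plausible way to compute both from the ensemble's hard predictions on a batch of test samples and to turn a statistic into a one-sided empirical p-value. It is not the authors' implementation: the function names, the "vote entropy" definition, and the use of a null distribution built from held-out in-distribution batches are all assumptions. Training the ensemble to agree on training data and disagree on test data (the paper's derived loss) is not shown.

```python
import math
from collections import Counter
from typing import Sequence

# Assumed input layout (hypothetical): preds[m][i] is the label that
# ensemble member m predicts for test sample i.


def disagreement_rate(preds: Sequence[Sequence[int]]) -> float:
    """Fraction of samples on which the ensemble members do not all agree."""
    n_samples = len(preds[0])
    disagreed = sum(
        1 for i in range(n_samples)
        if len({member[i] for member in preds}) > 1
    )
    return disagreed / n_samples


def mean_vote_entropy(preds: Sequence[Sequence[int]]) -> float:
    """Average entropy (in nats) of each sample's distribution of votes."""
    n_models = len(preds)
    n_samples = len(preds[0])
    total = 0.0
    for i in range(n_samples):
        counts = Counter(member[i] for member in preds)
        # Entropy -sum(p * log p) over the labels that received votes.
        total -= sum((c / n_models) * math.log(c / n_models) for c in counts.values())
    return total / n_samples


def empirical_p_value(observed: float, null_stats: Sequence[float]) -> float:
    """One-sided p-value: how often in-distribution batches score at least as high."""
    exceed = sum(1 for s in null_stats if s >= observed)
    return (1 + exceed) / (1 + len(null_stats))


# Usage: three hypothetical ensemble members, four test samples.
preds = [
    [0, 1, 2, 1],
    [0, 1, 0, 2],
    [0, 1, 2, 0],
]
print(disagreement_rate(preds))  # 0.5 (samples 2 and 3 disagree)
print(mean_vote_entropy(preds))
print(empirical_p_value(disagreement_rate(preds), [0.0, 0.1, 0.6]))  # 0.5
```

Under these assumptions, the null statistics should come from in-distribution batches of the same size as the test batch, since both statistics are averages whose spread depends on batch size. A p-value below a chosen significance level (for example 0.05) would then be read as evidence of harmful shift.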

Tom Ginsberg, Zhongyuan Liang, Rahul G. Krishnan (2022)

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Out-of-Distribution Detection | MNIST (In-distribution) vs Fashion-MNIST (OOD) (test) | AUPR | 0.8375 | 36 |
| Out-of-Distribution Detection | CIFAR10 (In-distribution) vs SVHN (OOD) (test) | AUPR | 90 | 18 |
| Image Classification | CIFAR-10 1v6 | Error (Actual) | 9.3 | 12 |
| Out-of-Distribution Detection | NHANESI | FDP | 60 | 12 |
| OOD Detection | Diabetes Retinopathy Detection (DRD) 8 (test) | AUROC | 0.9074 | 7 |
| OOD Detection | CIFAR10 (In-dist) vs Fashion-MNIST (OOD) (test) | AUC | 76.46 | 7 |
| OOD Detection | MNIST (In-dist) vs OMNIGLOT (OOD) (test) | AUC | 95.71 | 7 |
| OOD Detection | MNIST (In-dist) vs SVHN (OOD) (test) | AUC | 0.7792 | 7 |
| OOD Detection | CIFAR10 (In-dist) vs OMNIGLOT (OOD) (test) | AUC | 76.99 | 7 |
| Harmful shift detection | UCI Heart Disease (UCI-HD) 5-class (test) | AUC-ROC | 0.995 | 6 |

Showing 10 of 14 rows.
