
Characterizing Out-of-Distribution Error via Optimal Transport

About

Out-of-distribution (OOD) data poses serious challenges for deployed machine learning models, so methods for predicting a model's performance on OOD data without labels are important for machine learning safety. While prior work has proposed a number of such methods, they often underestimate the actual error, sometimes by a large margin, which greatly limits their applicability to real tasks. In this work, we identify pseudo-label shift, the difference between the predicted and true OOD label distributions, as a key indicator of this underestimation. Based on this observation, we introduce Confidence Optimal Transport (COT), a novel method for estimating model performance that leverages optimal transport theory, and show that it provably provides more robust error estimates in the presence of pseudo-label shift. Additionally, we introduce an empirically motivated variant of COT, Confidence Optimal Transport with Thresholding (COTT), which applies thresholding to the individual transport costs and further improves the accuracy of COT's error estimates. We evaluate COT and COTT on a variety of standard benchmarks that induce various types of distribution shift (synthetic, novel subpopulation, and natural) and show that our approaches significantly outperform existing state-of-the-art methods, with up to 3x lower prediction error.
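
To make the idea in the abstract concrete, the following is a minimal sketch of an optimal-transport error estimate, not the authors' implementation. It matches each softmax confidence vector on unlabeled target data to a one-hot label drawn from an assumed label distribution, and averages the matching costs. Several things here are assumptions that the abstract does not specify: the max-norm cost, the use of an exact assignment solver, the function names `cot_costs`, `cot_estimate`, and `cott_estimate`, and the way the threshold is supplied (it is simply passed in as a parameter).

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def cot_costs(probs, label_dist):
    """Per-sample transport costs between confidence vectors and one-hot labels.

    probs:      (n, k) softmax outputs on unlabeled target data.
    label_dist: (k,) label marginal assumed for the target data
                (hypothetical input; e.g. the source label distribution).
    """
    probs = np.asarray(probs, dtype=float)
    n, k = probs.shape

    # Turn the label marginal into n integer label counts
    # (largest-remainder rounding so the counts sum to n).
    raw = np.asarray(label_dist, dtype=float) * n
    counts = np.floor(raw).astype(int)
    remainder = n - counts.sum()
    order = np.argsort(-(raw - counts))
    counts[order[:remainder]] += 1

    labels = np.repeat(np.arange(k), counts)
    onehot = np.eye(k)[labels]

    # Assumed cost: max-norm distance between a confidence vector and a one-hot label.
    cost = np.abs(probs[:, None, :] - onehot[None, :, :]).max(axis=2)

    # Exact optimal one-to-one matching between samples and labels.
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols]


def cot_estimate(probs, label_dist):
    """Error estimate as the mean transport cost."""
    return float(cot_costs(probs, label_dist).mean())


def cott_estimate(probs, label_dist, threshold):
    """Thresholded variant: fraction of samples whose transport cost exceeds `threshold`."""
    return float((cot_costs(probs, label_dist) > threshold).mean())
```

The pairwise cost array in this sketch grows as n × n × k, so it is only practical for small batches; a larger evaluation set would need batching or an approximate solver.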

Yuzhe Lu, Yilong Qin, Runtian Zhai, Andrew Shen, Ketong Chen, Zhenlin Wang, Soheil Kolouri, Simon Stepputtis, Joseph Campbell, Katia Sycara • 2023

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Accuracy Estimation | PACS | R² | 0.891 | 50 |
| Accuracy Estimation | Nonliving-26 Subpopulation Shift | R² | 0.982 | 36 |
| Unsupervised Accuracy Estimation | Office-Home | R² | 0.863 | 36 |
| Unsupervised Accuracy Estimation | DomainNet | R² | 0.903 | 36 |
| Accuracy Estimation | Living-17 Subpopulation Shift | R² | 0.972 | 36 |
| Unsupervised Accuracy Estimation | RR1-WILDS | R² | 0.969 | 36 |
| Accuracy Estimation | Entity-13 Subpopulation Shift | R² | 0.96 | 36 |
| Accuracy Estimation | Entity-30 Subpopulation Shift | R² | 0.971 | 36 |
| Accuracy Estimation | CIFAR-10 | MAE | 0.355 | 27 |
| Accuracy Estimation | CIFAR-100 | MAE | 0.355 | 27 |
Showing 10 of 18 rows
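
For readers unfamiliar with the two metrics in the table, here is a short sketch of how MAE and R² between estimated and true accuracies could be computed. The page does not say exactly how the benchmark computes them, so the coefficient-of-determination form of R² used here is an assumption, and the function names and sample values are hypothetical illustrations rather than reported results.

```python
import numpy as np


def mae(estimated, true):
    """Mean absolute error between estimated and true accuracies."""
    estimated = np.asarray(estimated, dtype=float)
    true = np.asarray(true, dtype=float)
    return float(np.mean(np.abs(estimated - true)))


def r_squared(estimated, true):
    """Coefficient of determination of the estimates against the true values."""
    estimated = np.asarray(estimated, dtype=float)
    true = np.asarray(true, dtype=float)
    ss_res = np.sum((true - estimated) ** 2)
    ss_tot = np.sum((true - true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


# Hypothetical accuracies on three shifted test sets (not from the paper).
true_acc = [0.8, 0.7, 0.6]
est_acc = [0.9, 0.7, 0.5]
print(mae(est_acc, true_acc), r_squared(est_acc, true_acc))
```

Lower MAE and higher R² both indicate that the estimated accuracies track the true ones more closely.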
