
Faithful Heteroscedastic Regression with Neural Networks

About

Heteroscedastic regression models a Gaussian variable's mean and variance as functions of covariates. Parametric methods that use neural networks for these parameter maps can capture complex relationships in the data. Yet optimizing network parameters via log-likelihood gradients can yield suboptimal mean estimates and uncalibrated variance estimates. Current solutions sidestep this optimization problem with surrogate objectives or Bayesian treatments. Instead, we make two simple modifications to optimization. Notably, their combination produces a heteroscedastic model whose mean estimates are provably as accurate as those from its homoscedastic counterpart (i.e., a model that fits only the mean under squared-error loss). For a wide variety of network and task complexities, we find that mean estimates from existing heteroscedastic solutions can be significantly less accurate than those from an equivalently expressive mean-only model. Our approach provably retains the accuracy of an equally flexible mean-only model while also offering best-in-class variance calibration. Lastly, we show how to leverage our method to recover the underlying heteroscedastic noise variance.
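The claim that log-likelihood gradients can degrade the mean can be illustrated with the per-sample Gaussian negative log-likelihood (NLL). The sketch below is ours, not the authors' code, and it does not implement their two modifications (the abstract does not describe them). It only shows that the NLL gradient with respect to the mean equals the squared-error gradient scaled by the predicted precision 1/σ², so samples assigned a large variance pull the mean less. Scalars stand in for network outputs μ(x) and σ²(x); all function names are hypothetical.

```python
import math

def gaussian_nll(y, mu, var):
    """Per-sample negative log-likelihood of y under N(mu, var)."""
    return 0.5 * (math.log(2 * math.pi * var) + (y - mu) ** 2 / var)

def grad_mu_nll(y, mu, var):
    """dNLL/dmu: the squared-error gradient scaled by the precision 1/var."""
    return (mu - y) / var

def grad_mu_mse(y, mu):
    """d/dmu of 0.5 * (y - mu)^2, the homoscedastic (mean-only) objective."""
    return mu - y

def grad_var_nll(y, mu, var):
    """dNLL/dvar; it is zero exactly when var equals the squared residual."""
    return 0.5 * (1.0 / var - (y - mu) ** 2 / var ** 2)

if __name__ == "__main__":
    y, mu = 3.0, 1.0  # residual of 2.0
    for var in (0.5, 1.0, 10.0):
        # Compare how strongly each objective pushes the mean toward y.
        print(var, grad_mu_mse(y, mu), grad_mu_nll(y, mu, var))
```

With a residual of 2, a predicted variance of 10 shrinks the mean gradient tenfold relative to squared error, while a variance of 0.5 doubles it. This precision weighting is one commonly cited reason that NLL training can under-fit the mean wherever the model predicts high variance.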

Andrew Stirn, Hans-Hermann Wessels, Megan Schertzer, Laura Pereira, Neville E. Sanjana, David A. Knowles • 2022

Related benchmarks

Task | Dataset | Target | Metric | Result | Rank
Molecular property prediction | QM9 (test) | -- | -- | -- | 229
Out-of-Distribution Detection | QM9 (ID) / Alchemy (OOD) | epsilon_LUMO | FPR@95 | 67.6 | 6
Out-of-Distribution Detection | QM9 (ID) / Alchemy (OOD) | alpha | FPR@95 | 62.4 | 6
Out-of-Distribution Detection | QM9 (ID) / Alchemy (OOD) | U0 | FPR@95 | 81.8 | 6
Out-of-Distribution Detection | QM9 (ID) / Alchemy (OOD) | U | FPR@95 | 78.4 | 6
Out-of-Distribution Detection | QM9 (ID) / Alchemy (OOD) | H | FPR@95 | 78.4 | 6
Out-of-Distribution Detection | QM9 (ID) / Alchemy (OOD) | G | FPR@95 | 78.4 | 6
Out-of-Distribution Detection | QM9 (ID) / Alchemy (OOD) | Delta_epsilon | FPR@95 | 94.4 | 6
Out-of-Distribution Detection | QM9 (ID) / Alchemy (OOD) | ZPVE | FPR@95 | 0.875 | 6
Out-of-Distribution Detection | QM9 (ID) / Alchemy (OOD) | cv | FPR@95 | 74.4 | 6

FPR@95: false positive rate at a 95% true positive rate (lower is better). 10 of 11 rows shown.
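FPR@95 can be computed from per-sample OOD scores roughly as sketched below. Assumptions: in-distribution (QM9) samples are the positive class, a higher score means "more in-distribution", and the threshold is chosen to accept at least 95% of ID samples. This page does not say how the leaderboard derived scores from a heteroscedastic model; using the negative predicted variance is one plausible choice and is only an assumption here. `fpr_at_tpr` is a hypothetical helper, not a library function.

```python
import math

def fpr_at_tpr(id_scores, ood_scores, tpr=0.95):
    """Fraction of OOD samples accepted at the threshold that accepts `tpr` of ID samples.

    Assumed convention: higher score = more in-distribution; ID is the positive class.
    """
    ranked = sorted(id_scores, reverse=True)
    # Smallest number of ID samples that must score at or above the threshold.
    k = max(1, math.ceil(tpr * len(ranked)))
    threshold = ranked[k - 1]
    # OOD samples scoring at or above the threshold count as false positives.
    false_pos = sum(1 for s in ood_scores if s >= threshold)
    return false_pos / len(ood_scores)

if __name__ == "__main__":
    # Hypothetical scores, e.g. negative predicted variances from a heteroscedastic model.
    id_scores = [-0.1 * i for i in range(20)]
    ood_scores = [-0.5, -1.0, -2.0, -3.0]
    print(fpr_at_tpr(id_scores, ood_scores))  # prints 0.5
```

The function returns a fraction in [0, 1]; multiply by 100 to express it as a percentage.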
