On the Pitfalls of Heteroscedastic Uncertainty Estimation with Probabilistic Neural Networks

About

Capturing aleatoric uncertainty is a critical part of many machine learning systems. In deep learning, a common approach to this end is to train a neural network to estimate the parameters of a heteroscedastic Gaussian distribution by maximizing the log-likelihood of the observed data. In this work, we examine this approach and identify potential hazards associated with using the log-likelihood in conjunction with gradient-based optimizers. First, we present a synthetic example illustrating how this approach can lead to very poor but stable parameter estimates. Second, we identify the log-likelihood loss as the culprit, along with certain conditions that exacerbate the issue. Third, we present an alternative formulation, termed $\beta$-NLL, in which each data point's contribution to the loss is weighted by the $\beta$-exponentiated variance estimate. We show that using an appropriate $\beta$ largely mitigates the issue in our illustrative example. Fourth, we evaluate this approach on a range of domains and tasks and show that it achieves considerable improvements and is more robust to hyperparameter choices, in terms of both predictive RMSE and log-likelihood.

Maximilian Seitzer, Arash Tavakoli, Dimitrije Antic, Georg Martius · 2022
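
The $\beta$-NLL objective described in the abstract can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes a per-sample Gaussian NLL with the constant $\tfrac{1}{2}\log 2\pi$ dropped, and it reads "weighted by the $\beta$-exponentiated variance estimate" as multiplying each sample's NLL by $(\sigma^2)^\beta$, treated as a constant (a stop-gradient) so that it rescales gradients without being optimized itself. The function names `beta_nll` and `beta_nll_grads` are hypothetical. The gradients are written out analytically in plain Python so the effect of the weight is visible without an autodiff library.

```python
import math
from typing import Sequence, Tuple


def beta_nll(y: Sequence[float], mu: Sequence[float], var: Sequence[float],
             beta: float = 0.5) -> float:
    """Mean beta-NLL over a batch (hypothetical helper; log(2*pi)/2 dropped).

    Each sample's Gaussian NLL is multiplied by var**beta. In an autodiff
    framework this weight would be wrapped in a stop-gradient.
    """
    if not y:
        raise ValueError("empty batch")
    total = 0.0
    for yi, mi, vi in zip(y, mu, var):
        nll = 0.5 * math.log(vi) + (yi - mi) ** 2 / (2.0 * vi)
        total += (vi ** beta) * nll  # beta = 0 recovers the plain NLL
    return total / len(y)


def beta_nll_grads(yi: float, mi: float, vi: float,
                   beta: float) -> Tuple[float, float]:
    """Per-sample gradients w.r.t. mean and variance (hypothetical helper).

    The weight var**beta is held constant, mirroring a stop-gradient.
    """
    w = vi ** beta
    d_mu = w * (mi - yi) / vi  # equals (mi - yi) / vi**(1 - beta)
    d_var = w * (1.0 / (2.0 * vi) - (yi - mi) ** 2 / (2.0 * vi ** 2))
    return d_mu, d_var
```

With `beta=0`, the gradient with respect to the mean is $(\mu - y)/\sigma^2$, so points with a large predicted variance contribute very little to fitting the mean. With `beta=1`, it becomes $\mu - y$ regardless of the variance, which matches the gradient of the squared error $\tfrac{1}{2}(\mu - y)^2$. Intermediate values of $\beta$ interpolate between these two regimes.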

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Molecular property prediction | QM9 (test) | -- | -- | 229 |
| Regression | UCI ENERGY (test) | NLL | -2.55 | 47 |
| Regression | UCI KIN8NM (test) | NLL | -1.6 | 25 |
| Regression | UCI housing (test) | RMSE | 1.65 | 10 |
| Out-of-distribution detection | QM9 (ID) / Alchemy (OOD), target alpha | FPR@95 | 49.8 | 6 |
| Out-of-distribution detection | QM9 (ID) / Alchemy (OOD), target U | FPR@95 | 35.8 | 6 |
| Out-of-distribution detection | QM9 (ID) / Alchemy (OOD), target epsilon_HOMO | FPR@95 | 86.5 | 6 |
| Out-of-distribution detection | QM9 (ID) / Alchemy (OOD), target ZPVE | FPR@95 | 0.678 | 6 |
| Out-of-distribution detection | QM9 (ID) / Alchemy (OOD), target U0 | FPR@95 | 45.4 | 6 |
| Out-of-distribution detection | QM9 (ID) / Alchemy (OOD), target H | FPR@95 | 40.3 | 6 |

(10 of 23 rows shown.)
