Tractable Function-Space Variational Inference in Bayesian Neural Networks

About

Reliable predictive uncertainty estimation plays an important role in enabling the deployment of neural networks to safety-critical settings. A popular approach for estimating the predictive uncertainty of neural networks is to define a prior distribution over the network parameters, infer an approximate posterior distribution, and use it to make stochastic predictions. However, explicit inference over neural network parameters makes it difficult to incorporate meaningful prior information about the data-generating process into the model. In this paper, we pursue an alternative approach. Recognizing that the primary object of interest in most settings is the distribution over functions induced by the posterior distribution over neural network parameters, we frame Bayesian inference in neural networks explicitly as inferring a posterior distribution over functions and propose a scalable function-space variational inference method that allows incorporating prior information and results in reliable predictive uncertainty estimates. We show that the proposed method leads to state-of-the-art uncertainty estimation and predictive performance on a range of prediction tasks and demonstrate that it performs well on a challenging safety-critical medical diagnosis task in which reliable uncertainty estimation is essential.
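The parameter-space recipe that the abstract contrasts against (place a distribution over network parameters, then average predictions over parameter samples) can be illustrated with a minimal sketch. This is not the paper's function-space method. The toy two-class linear classifier, the `mean` and `std` values of the factorized Gaussian approximate posterior, and all helper names are assumptions made purely for illustration.

```python
import math
import random


def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_weights(mean, std, rng):
    # Draw one weight matrix from a factorized Gaussian q(w) = N(mean, std^2).
    return [[rng.gauss(m, s) for m, s in zip(mrow, srow)]
            for mrow, srow in zip(mean, std)]


def predict_mc(x, mean, std, num_samples=200, seed=0):
    # Monte Carlo estimate of the predictive distribution:
    # average the softmax outputs over sampled parameter values.
    rng = random.Random(seed)
    avg = [0.0] * len(mean)
    for _ in range(num_samples):
        w = sample_weights(mean, std, rng)
        logits = [sum(wi * xi for wi, xi in zip(row, x)) for row in w]
        probs = softmax(logits)
        avg = [a + p / num_samples for a, p in zip(avg, probs)]
    return avg


def entropy(probs):
    # Predictive entropy in nats; higher means more uncertain.
    return -sum(p * math.log(p) for p in probs if p > 0.0)


# Hypothetical approximate posterior for a 2-class, 2-feature linear model.
mean = [[2.0, 0.0], [0.0, 2.0]]
std = [[0.5, 0.5], [0.5, 0.5]]

probs_informative = predict_mc([1.0, 0.0], mean, std)  # input favouring class 0
probs_ambiguous = predict_mc([0.0, 0.0], mean, std)    # input carrying no signal

print(probs_informative, entropy(probs_informative))
print(probs_ambiguous, entropy(probs_ambiguous))
```

In this toy setting the predictive entropy is lower for the informative input than for the all-zero input, where every sampled model outputs a uniform prediction. The prior here lives on the weights, which is the limitation the paper's function-space formulation is meant to address.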

Tim G. J. Rudner, Zonghao Chen, Yee Whye Teh, Yarin Gal • 2023

Related benchmarks

Task                            Dataset                           Result          Rank
Image Classification            CIFAR-10                          Accuracy 95.19  507
Image Classification            FashionMNIST (test)               Accuracy 94.44  218
Out-of-Distribution Detection   SVHN                              AUROC 99.19     62
Out-of-Distribution Detection   FashionMNIST (ID) vs MNIST (OoD)  AUROC 0.998     61
Diabetic Retinopathy Diagnosis  APTOS 2019 (Population Shift)     AUC 94.6        36
Diabetic Retinopathy Diagnosis  EyePACS In-Domain                 AUC 95.2        36
Regression                      Boston UCI (test)                 --              26
Regression                      UCI KIN8NM (test)                 --              25
Image Classification            CIFAR10 Corrupted                 Accuracy 81.35  20
Out-of-Distribution Detection   FashionMNIST (test)               --              14
Showing 10 of 33 rows
