
FSP-Laplace: Function-Space Priors for the Laplace Approximation in Bayesian Deep Learning

About

Laplace approximations are popular techniques for endowing deep networks with epistemic uncertainty estimates as they can be applied without altering the predictions of the trained network, and they scale to large models and datasets. While the choice of prior strongly affects the resulting posterior distribution, computational tractability and lack of interpretability of the weight space typically limit the Laplace approximation to isotropic Gaussian priors, which are known to cause pathological behavior as depth increases. As a remedy, we directly place a prior on function space. More precisely, since Lebesgue densities do not exist on infinite-dimensional function spaces, we recast training as finding the so-called weak mode of the posterior measure under a Gaussian process (GP) prior restricted to the space of functions representable by the neural network. Through the GP prior, one can express structured and interpretable inductive biases, such as regularity or periodicity, directly in function space, while still exploiting the implicit inductive biases that allow deep networks to generalize. After model linearization, the training objective induces a negative log-posterior density to which we apply a Laplace approximation, leveraging highly scalable methods from matrix-free linear algebra. Our method provides improved results where prior knowledge is abundant (as is the case in many scientific inference tasks). At the same time, it stays competitive for black-box supervised learning problems, where neural networks typically excel.
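To make the mechanics concrete, the following toy sketch (not the paper's code; all names and the random-feature basis are illustrative choices) shows the standard Laplace approximation on a linearized model: with features phi(x), a Gaussian prior on the weights, and Gaussian noise, the posterior is approximated by a Gaussian centered at the MAP weights with covariance given by the inverse Hessian of the negative log-posterior. The paper replaces the isotropic weight-space prior used here with a GP prior in function space, but the Laplace step itself has the same shape.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 1-D regression data
X = np.linspace(-1.0, 1.0, 20)
y = np.sin(3.0 * X) + 0.1 * rng.standard_normal(20)

def phi(x):
    """Illustrative feature map: 10 Gaussian bumps on [-1, 1]."""
    centers = np.linspace(-1.0, 1.0, 10)
    return np.exp(-0.5 * ((x[:, None] - centers) / 0.3) ** 2)

Phi = phi(X)                    # (20, 10) design matrix
alpha, beta = 1.0, 100.0        # prior precision, observation-noise precision

# Hessian of the negative log-posterior (here exact, since the model is linear);
# an isotropic Gaussian prior contributes the alpha * I term.
A = beta * Phi.T @ Phi + alpha * np.eye(10)

# MAP weights = mode of the posterior (the ridge-regression solution)
w_map = beta * np.linalg.solve(A, Phi.T @ y)

# Laplace posterior N(w_map, A^{-1}) gives closed-form predictive uncertainty:
# predictive variance = noise variance + phi(x*)^T A^{-1} phi(x*)
Xs = np.array([0.0, 0.5])
Phis = phi(Xs)
pred_mean = Phis @ w_map
pred_var = 1.0 / beta + np.einsum("ij,jk,ik->i", Phis, np.linalg.inv(A), Phis)
```

In the paper's setting, A is far too large to form or invert explicitly, which is where the matrix-free linear-algebra routines mentioned above come in; this sketch only illustrates the small-scale math.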

Tristan Cinquin, Marvin Pförtner, Vincent Fortuin, Philipp Hennig, Robert Bamler • 2024

Related benchmarks

Task                             | Dataset                          | Metric                  | Result | Rank
---------------------------------|----------------------------------|-------------------------|--------|-----
Image Classification             | FashionMNIST (test)              | Accuracy                | 90.9   | 218
Regression                       | Boston UCI (test)                | --                      | --     | 26
Regression                       | UCI KIN8NM (test)                | --                      | --     | 25
Out-of-Distribution Detection    | FashionMNIST (test)              | --                      | --     | 14
Image Classification             | MNIST (test)                     | Log-likelihood          | -0.037 | 7
Out-of-Distribution Detection    | MNIST Out-of-Distribution (test) | OOD Accuracy            | 97.7   | 7
Regression                       | Denmark (UCI) (test)             | Expected Log-Likelihood | -0.364 | 6
Out-of-Distribution Detection    | Concrete UCI (test)              | OOD Detection Accuracy  | 91.4   | 5
Out-of-Distribution Detection    | Energy UCI (test)                | OOD Detection Accuracy  | 100    | 5
Out-of-Distribution Detection    | Naval UCI (test)                 | OOD Accuracy            | 100    | 5

Showing 10 of 20 rows
