
Amortized Probabilistic Conditioning for Optimization, Simulation and Inference

About

Amortized meta-learning methods based on pre-training have propelled fields like natural language processing and vision. Transformer-based neural processes and their variants are leading models for probabilistic meta-learning with a tractable objective. Often trained on synthetic data, these models implicitly capture essential latent information in the data-generation process. However, existing methods do not allow users to flexibly inject (condition on) and extract (predict) this probabilistic latent information at runtime, which is key to many tasks. We introduce the Amortized Conditioning Engine (ACE), a new transformer-based meta-learning model that explicitly represents latent variables of interest. ACE affords conditioning on both observed data and interpretable latent variables, the inclusion of priors at runtime, and outputs predictive distributions for discrete and continuous data and latents. We show ACE's modeling flexibility and performance in diverse tasks such as image completion and classification, Bayesian optimization, and simulation-based inference.
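To make the distinction between injecting (conditioning on) and extracting (predicting) a latent variable concrete, here is a minimal sketch under an assumed toy model. The data are Gaussian with known noise variance, and the interpretable latent is the unknown mean μ with a Gaussian prior. It computes exactly the kinds of quantities ACE is described as predicting: a posterior over the latent given data, a data distribution given an injected latent value, and the effect of swapping the prior at runtime. This is exact conjugate Bayes, not ACE itself (ACE is a transformer trained to approximate such maps), and all names here are hypothetical.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class Normal:
    mean: float
    var: float


def posterior_over_mean(data: Sequence[float], prior: Normal, noise_var: float) -> Normal:
    """Extract (predict) the latent mu from data.

    Conjugate update for x_i ~ N(mu, noise_var) with mu ~ N(prior.mean, prior.var).
    """
    n = len(data)
    precision = 1.0 / prior.var + n / noise_var
    var = 1.0 / precision
    mean = var * (prior.mean / prior.var + sum(data) / noise_var)
    return Normal(mean, var)


def predictive_given_latent(mu: float, noise_var: float) -> Normal:
    """Inject (condition on) a known latent mu and predict a new data point."""
    return Normal(mu, noise_var)


def posterior_predictive(data: Sequence[float], prior: Normal, noise_var: float) -> Normal:
    """Predict a new data point with the latent mu marginalized out."""
    post = posterior_over_mean(data, prior, noise_var)
    return Normal(post.mean, post.var + noise_var)


if __name__ == "__main__":
    data = [1.0, 2.0, 3.0]
    # The same data under two priors supplied "at runtime".
    print(posterior_over_mean(data, Normal(0.0, 1.0), noise_var=1.0))
    print(posterior_over_mean(data, Normal(0.0, 0.1), noise_var=1.0))
    print(predictive_given_latent(2.0, noise_var=1.0))
    print(posterior_predictive(data, Normal(0.0, 1.0), noise_var=1.0))
```

With the unit-variance prior, the posterior over μ is N(1.5, 0.25). The narrower prior pulls the mean toward 0 (to 6/13). An amortized model has to reproduce this prior sensitivity without re-running inference for each new prior.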

Paul E. Chang, Nasrulloh Loka, Daolang Huang, Ulpu Remes, Samuel Kaski, Luigi Acerbi · 2024

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Data Prediction | OUP | RMSE | 0.22 | 9 |
| Data Prediction | Turin | RMSE | 0.16 | 9 |
| Approximate Bayesian Inference | Inverse-gamma prior and normal-variance likelihood (wide meta-prior) | KL divergence | 0.0048 | 7 |
| Approximate Bayesian Inference | Inverse-gamma prior and normal-variance likelihood (narrow meta-prior) | KL divergence | 0.0094 | 7 |
| Posterior Inference | Two Moons, q_strong(theta) | RMSE | 0.09 | 6 |
| Approximate Bayesian Inference | 5D-input Gaussian process (Experiment 4.2.1) | Expected NLL (hyperparameters) | 0.35 | 6 |
| Calibration | GP hyperposterior (Experiment 4.2.1, test) | MMD | 0.0041 | 6 |
| Calibration | Quantum system parameter inference (Experiment 4.2.2, test) | MMD | 0.0067 | 5 |
| Calibration | GP posterior predictive distribution (Experiment 4.2.1, test) | MMD | 0.0035 | 4 |
| Posterior Inference | Two Moons, q_mild(theta) | RMSE | 0.34 | 3 |

(10 of 26 benchmark rows shown.)
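Two rows above pair an inverse-gamma prior with a normal-variance likelihood and report a KL divergence. Because this model is conjugate, an exact posterior is available to compare against. The sketch below rests on several assumptions not stated on this page. It uses the textbook parameterization: known mean μ, and variance σ² ~ InvGamma(α, β) with shape α and scale β. It assumes the score is a KL divergence between the exact posterior and an approximate one, though the page specifies neither the KL direction nor the meta-prior settings. The function names are hypothetical.

```python
import math
from typing import Sequence, Tuple


def digamma(x: float) -> float:
    """Digamma for x > 0: recurrence up to x >= 10, then the asymptotic series."""
    result = 0.0
    while x < 10.0:
        result -= 1.0 / x  # psi(x) = psi(x + 1) - 1/x
        x += 1.0
    inv2 = 1.0 / (x * x)
    # ln x - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6)
    return result + math.log(x) - 0.5 / x - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252))


def inverse_gamma_posterior(
    data: Sequence[float], mu: float, alpha: float, beta: float
) -> Tuple[float, float]:
    """Exact posterior (shape, scale) for sigma^2 given x_i ~ N(mu, sigma^2)."""
    sq = sum((x - mu) ** 2 for x in data)
    return alpha + len(data) / 2.0, beta + sq / 2.0


def kl_inverse_gamma(a1: float, b1: float, a2: float, b2: float) -> float:
    """KL(InvGamma(a1, b1) || InvGamma(a2, b2)), shape/scale parameterization.

    KL is invariant under sigma^2 -> 1/sigma^2, which maps InvGamma(a, b) to
    Gamma(a, rate=b), so the Gamma-Gamma closed form applies.
    """
    return (
        (a1 - a2) * digamma(a1)
        - math.lgamma(a1)
        + math.lgamma(a2)
        + a2 * (math.log(b1) - math.log(b2))
        + a1 * (b2 - b1) / b1
    )


if __name__ == "__main__":
    shape, scale = inverse_gamma_posterior([1.0, 3.0], mu=2.0, alpha=2.0, beta=1.0)
    print(shape, scale)  # exact posterior parameters
    # Hypothetical approximate posterior, e.g. fitted to a model's output.
    print(kl_inverse_gamma(shape, scale, 3.2, 2.1))
```

In use, the exact posterior parameters and the (shape, scale) of an approximate posterior are passed to `kl_inverse_gamma` to get one number per dataset, which would then be averaged over test datasets. The divergence is zero only when the two parameter pairs coincide.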
