Amortized Factor Inference Networks for Posterior Inference
About
Amortized inference promises fast test-time Bayesian inference, but existing methods are inherently tied to fixed models. Extending amortization to unseen models typically requires retraining or costly test-time finetuning. In this paper, we ask: is it possible to build a single inference network capable of generalizing across varying priors, likelihoods, and dimensionality? We introduce Amortized Factor Inference Networks (AFINs), a family of encode-merge-decode inference networks built on dimension-independent modules that map a model specification and its observations to the parameters of a variational posterior. Experimentally, a single trained AFIN achieves posterior accuracy comparable to NUTS and several variational inference methods, while requiring 2 to 4 orders of magnitude less test-time compute. Code is available at https://github.com/joohwanko/AFINs.
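To make the encode-merge-decode idea concrete, here is a minimal NumPy sketch of an inference network whose weights do not depend on the number of observations N or the parameter dimension d. It is an illustration only, not the AFIN architecture from the paper: the function names (`init_params`, `amortized_posterior`), the per-dimension prior encoding (`prior_spec` holding a prior mean and scale), the mean pooling, and the diagonal Gaussian output are all assumptions made for this example, and the weights are random rather than trained.

```python
import numpy as np

def init_params(hidden=16, seed=0):
    # Hypothetical, untrained weights; none of the shapes depend on N or d.
    rng = np.random.default_rng(seed)
    return {
        "W_enc": rng.normal(scale=0.5, size=(1, hidden)),
        "b_enc": np.zeros(hidden),
        "W_dec": rng.normal(scale=0.5, size=(hidden + 2, 2)),
        "b_dec": np.zeros(2),
    }

def amortized_posterior(params, y, prior_spec):
    """Map observations and a prior description to diagonal Gaussian parameters.

    y:          (N, d) observations
    prior_spec: (d, 2) assumed encoding of the model: prior mean and scale per dimension
    """
    # Encode: the same small network is applied to every scalar observation.
    h = np.tanh(y[..., None] @ params["W_enc"] + params["b_enc"])  # (N, d, H)
    # Merge: mean pooling over observations, so any N is accepted and
    # the result does not depend on the order of the observations.
    pooled = h.mean(axis=0)                                        # (d, H)
    # Decode: a shared head applied per dimension, so any d is accepted.
    z = np.concatenate([pooled, prior_spec], axis=-1)              # (d, H + 2)
    out = z @ params["W_dec"] + params["b_dec"]                    # (d, 2)
    mu, log_sigma = out[:, 0], out[:, 1]
    return mu, np.exp(log_sigma)
```

Because every weight matrix is shared across observations and dimensions, the same `params` can be called with `y` of shape `(5, 2)` or `(50, 7)` without modification, which is the property the abstract refers to as dimension independence.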
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Posterior Covariance Estimation | Synthetic OOD N | Covariance Frobenius Error | 0.0011 | 4 |
| Posterior Covariance Estimation | Synthetic OOD d, N | Covariance Frobenius Error | 0.0175 | 4 |
| Posterior Inference | Synthetic Extrapolation Stress Test N=512 (OOD N) | Sliced Wasserstein-2 Distance | 0.0013 | 4 |
| Posterior Inference | Synthetic Extrapolation Stress Test OOD d=32, N=512 | SW-2 Distance | 0.0039 | 4 |
| Posterior Mean Estimation | Synthetic extrapolation tasks OOD N | Mean L2 Error | 0.0019 | 4 |
| Posterior Mean Estimation | Synthetic extrapolation tasks OOD d, N split | Mean L2 Error | 0.0098 | 4 |
| Posterior Covariance Estimation | Synthetic OOD d | Covariance Frobenius Error | 0.234 | 4 |
| Posterior Inference | Synthetic Extrapolation Stress Test OOD d=32 | SW2 Distance | 0.024 | 4 |
| Posterior Mean Estimation | Synthetic extrapolation tasks (OOD d) | Mean L2 Error | 0.068 | 4 |
| Binary Classification | OpenML 16 binary v2 (70/30 train test) | Accuracy | 85.7 | 2 |
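For readers unfamiliar with the three posterior metrics named in the table, the sketch below shows one common way to estimate them from two sets of samples (for example, draws from an approximate posterior and from a reference sampler). The function names, the number of random projections, and the requirement of equal sample counts are assumptions of this example; the benchmarks' exact definitions and normalizations are not stated here and may differ.

```python
import numpy as np

def mean_l2_error(a, b):
    # Euclidean distance between the two sample means; a and b are (S, d).
    return float(np.linalg.norm(a.mean(axis=0) - b.mean(axis=0)))

def cov_frobenius_error(a, b):
    # Frobenius norm of the difference between the sample covariances.
    ca = np.atleast_2d(np.cov(a, rowvar=False))
    cb = np.atleast_2d(np.cov(b, rowvar=False))
    return float(np.linalg.norm(ca - cb, ord="fro"))

def sliced_w2(a, b, n_proj=128, seed=0):
    # Monte Carlo estimate of the sliced Wasserstein-2 distance.
    # Assumes a and b contain the same number of samples.
    rng = np.random.default_rng(seed)
    dirs = rng.normal(size=(n_proj, a.shape[1]))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)  # random unit directions
    # In one dimension, optimal transport matches sorted samples.
    pa = np.sort(a @ dirs.T, axis=0)                     # (S, n_proj)
    pb = np.sort(b @ dirs.T, axis=0)
    return float(np.sqrt(np.mean((pa - pb) ** 2)))
```

All three functions return zero when the two sample sets are identical, and lower values indicate closer agreement with the reference posterior.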