Foundation Models for Causal Inference via Prior-Data Fitted Networks
About
Prior-data fitted networks (PFNs) have recently been proposed as a promising way to train tabular foundation models. PFNs are transformers that are pre-trained on synthetic data generated from a prespecified prior distribution and that enable Bayesian inference through in-context learning. In this paper, we introduce CausalFM, a comprehensive framework for training PFN-based foundation models in various causal inference settings. First, we formalize the construction of Bayesian priors for causal inference based on structural causal models (SCMs) in a principled way and derive necessary criteria for the validity of such priors. Building on this, we propose a novel family of prior distributions using causality-inspired Bayesian neural networks that enable CausalFM to perform Bayesian causal inference in various settings, including back-door, front-door, and instrumental variable (IV) adjustment. Finally, we instantiate CausalFM and explicitly train models to perform in-context learning in these settings. We show that CausalFM achieves competitive in-context learning performance even when compared to baselines that are specifically trained for the task at hand. In sum, our framework can be used as a general recipe to train foundation models for various causal inference settings. In contrast to the current state of the art in causal inference, CausalFM offers a novel paradigm with the potential to fundamentally change how practitioners perform causal inference in medicine, economics, and other disciplines.
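To make the pre-training idea concrete, here is a minimal Python sketch of drawing one synthetic training task from an SCM-based prior in the back-door setting. It is not CausalFM's prior. The random one-hidden-layer tanh networks, the logistic propensity, the additive Gaussian outcome noise, and the names `random_mlp` and `sample_backdoor_dataset` are all hypothetical choices made for illustration. Because the SCM is fully known, the ground-truth CATE can be computed by intervening on the treatment. That quantity is what a PFN would be trained to predict from the observed context rows alone.

```python
import numpy as np


def random_mlp(rng, in_dim, hidden=16):
    # Hypothetical mechanism prior: a one-hidden-layer tanh network with
    # Gaussian weights, standing in for a Bayesian-neural-network draw.
    w1 = rng.normal(size=(in_dim, hidden))
    b1 = rng.normal(size=hidden)
    w2 = rng.normal(size=hidden) / np.sqrt(hidden)
    return lambda z: np.tanh(z @ w1 + b1) @ w2


def sample_backdoor_dataset(rng, n=200, d=3, noise_sd=0.1):
    """Draw one synthetic back-door task (graph: X -> T, X -> Y, T -> Y).

    Returns the observed data (X, T, Y) and the ground-truth CATE.
    """
    X = rng.normal(size=(n, d))                     # observed confounders
    f_t = random_mlp(rng, d)                         # treatment mechanism
    f_y = random_mlp(rng, d + 1)                     # outcome mechanism
    propensity = 1.0 / (1.0 + np.exp(-f_t(X)))       # P(T = 1 | X)
    T = rng.binomial(1, propensity)                  # binary treatment
    eps = rng.normal(scale=noise_sd, size=n)         # additive outcome noise

    def mu(t):
        # Noise-free outcome surface f_y(x, t) evaluated at treatment vector t
        return f_y(np.column_stack([X, t]))

    Y = mu(T) + eps                                  # factual outcome
    # With additive noise, CATE(x) = f_y(x, 1) - f_y(x, 0): compare do(T=1) and do(T=0)
    cate = mu(np.ones(n)) - mu(np.zeros(n))
    return X, T, Y, cate


rng = np.random.default_rng(0)
X, T, Y, cate = sample_backdoor_dataset(rng)

# Hypothetical split: a PFN would condition on the context rows and predict the
# CATE at the query covariates.
n_ctx = 150
context = (X[:n_ctx], T[:n_ctx], Y[:n_ctx])
query_X, query_cate = X[n_ctx:], cate[n_ctx:]
print(context[0].shape, query_X.shape)
```

Repeating this draw many times with fresh random mechanisms yields a stream of (context, query, target) tasks for pre-training. In this sketch, a front-door or IV prior would change the causal graph being sampled rather than the overall loop.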
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Conditional average treatment effect (CATE) estimation | Binary IV | PEHE | 0.422 | 9 |
| CATE estimation | 10 Synthetic Datasets | PEHE | 0.515 | 9 |
| CATE estimation | Jobs semi-synthetic | PEHE | 0.478 | 9 |
| CATE estimation | Standard CATE setting (average per dataset) | Latency (s) | 0.49 | 9 |
| IV estimation | IV setting | Time (s) | 0.472 | 9 |
| CATE estimation | IV Continuous | PEHE | 0.579 | 9 |
| CATE estimation | Front-door adjustment setting | PEHE | 0.847 | 5 |
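PEHE (precision in estimating heterogeneous effects) measures how far estimated CATEs are from the true ones, so lower is better. The same holds for the latency and time rows. A common form is the root version, $\sqrt{\mathrm{PEHE}} = \sqrt{\tfrac{1}{n}\sum_{i=1}^{n}\big(\hat\tau(x_i) - \tau(x_i)\big)^2}$. The table does not state whether its values are the root or the squared variant, so the sketch below assumes the root form. The function name `pehe` and the toy numbers are hypothetical and not taken from the benchmarks.

```python
import numpy as np


def pehe(tau_hat, tau_true):
    """Root PEHE: RMSE between estimated and true CATEs (assumed convention)."""
    tau_hat = np.asarray(tau_hat, dtype=float)
    tau_true = np.asarray(tau_true, dtype=float)
    return float(np.sqrt(np.mean((tau_hat - tau_true) ** 2)))


# Toy values for illustration only
tau_true = np.array([1.0, 0.5, -0.2, 0.0])
tau_hat = np.array([0.8, 0.7, -0.2, 0.4])
print(round(pehe(tau_hat, tau_true), 4))  # sqrt(0.24 / 4) ≈ 0.2449
```

If a source reports the squared variant, dropping the final square root recovers it.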