Causal Foundation Models with Continuous Treatments
About
Causal inference, the estimation of causal effects from observational data, is a fundamental tool in many disciplines. Of particular importance across a variety of domains is the continuous treatment setting, where the intervention variable takes values in a continuous range. This setting is far less explored than the binary treatment setting and represents a substantial shift from it, since models must represent effects across a continuum of treatment values. In this paper, we present the first causal foundation model for the continuous treatment setting. Our model meta-learns the ability to predict causal effects across a wide variety of unseen tasks without additional training or fine-tuning. First, we design a novel prior over data-generating processes with continuous treatment variables in order to generate a rich causal training corpus. We then train a transformer to reconstruct individual treatment-response curves given only observational data, leveraging in-context learning to amortize expensive Bayesian posterior inference. Our model achieves state-of-the-art performance on individual treatment-response curve reconstruction tasks compared with causal models trained specifically for those tasks.
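To make the idea of a prior over data-generating processes with continuous treatments more concrete, here is a minimal Python sketch of a toy prior. It assumes Gaussian covariates, a confounded treatment squashed into [0, 1] by a sigmoid, and a randomly parameterized sinusoidal response. This is not the paper's prior: `sample_task`, `SyntheticTask`, and every distributional choice below are hypothetical. Each call draws one task, consisting of the observational triples (x, t, y) a model would see in context plus the noise-free response function that supplies ground-truth individual treatment-response curves as training targets.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class SyntheticTask:
    """One task from a hypothetical toy prior: observational data plus ground truth."""
    covariates: List[List[float]]                    # X, shape (n, d)
    treatments: List[float]                          # observed continuous treatment t_i in [0, 1]
    outcomes: List[float]                            # observed (factual) outcome y_i
    response: Callable[[List[float], float], float]  # noise-free y(x, t), hidden from the model


def sample_task(n: int = 200, d: int = 3, noise: float = 0.1, seed: int = 0) -> SyntheticTask:
    """Draw one data-generating process, then n observational samples from it (toy prior)."""
    rng = random.Random(seed)
    # Randomly drawn mechanism parameters define this particular data-generating process.
    w_treat = [rng.gauss(0.0, 1.0) for _ in range(d)]  # confounding path X -> T
    w_out = [rng.gauss(0.0, 1.0) for _ in range(d)]    # direct path X -> Y
    amp = rng.uniform(0.5, 2.0)
    freq = rng.uniform(1.0, 4.0)
    phase = rng.uniform(0.0, math.pi)

    def response(x: List[float], t: float) -> float:
        # Nonlinear treatment effect whose shape shifts with the covariates (heterogeneity).
        base = sum(wi * xi for wi, xi in zip(w_out, x))
        return base + amp * math.sin(freq * t + phase + 0.5 * base)

    xs: List[List[float]] = []
    ts: List[float] = []
    ys: List[float] = []
    for _ in range(n):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        logit = sum(wi * xi for wi, xi in zip(w_treat, x)) + rng.gauss(0.0, 0.5)
        t = 1.0 / (1.0 + math.exp(-logit))          # treatment assignment depends on X
        y = response(x, t) + rng.gauss(0.0, noise)  # only the factual outcome is recorded
        xs.append(x)
        ts.append(t)
        ys.append(y)
    return SyntheticTask(xs, ts, ys, response)
```

For example, `task = sample_task(seed=1)` followed by `[task.response(task.covariates[0], k / 10) for k in range(11)]` gives unit 0's true curve on an 11-point grid. Only the generator knows this curve, because the observed data record a single treatment per unit. A pretraining corpus would repeat the draw with fresh seeds so that the model sees many different mechanisms.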
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Dosage Policy Estimation (DPE) | News (test) | Mean DPE | 3.71 | 12 |
| Dosage Policy Estimation (DPE) | Aggregate (Debt, Warfarin, TCGA, News, NewsHet) | Average rank | 6 | 12 |
| Dosage Policy Estimation (DPE) | Warfarin (test) | Mean DPE | 2.91 | 12 |
| Dosage Policy Estimation (DPE) | NewsHet (test) | Mean DPE | 1.64 | 12 |
| Dosage Policy Estimation (DPE) | Debt (test) | Mean DPE | 0.29 | 12 |
| Dosage Policy Estimation (DPE) | TCGA (test) | Mean DPE | 38.9 | 11 |
| Treatment-response curve estimation | Debt (test) | Mean MISE | 2.26 | 9 |
| Treatment-response curve estimation | Warfarin (test) | Mean MISE | 40.4 | 9 |
| Treatment-response curve estimation | TCGA (test) | Average rank | 2.8 | 9 |
| Treatment-response curve estimation | MVICU (test) | Mean MISE | 1.45 | 9 |
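The two metrics in the table can be illustrated with a short Python sketch. It assumes treatments rescaled to [0, 1], curves given as callables, and a 101-point evaluation grid. `mise` and `dpe` are hypothetical helpers that follow one common definition of each metric; they are not the benchmarks' official evaluation code. MISE averages, over units, the integral of the squared gap between the true and estimated curves (approximated here with the trapezoidal rule). DPE picks the dose that maximizes each estimated curve and averages the squared loss in *true* outcome relative to the truly optimal dose. Some papers report the square root of MISE, and the normalization behind the table values is not stated here.

```python
import math
from typing import Callable, List, Sequence

Curve = Callable[[float], float]  # maps a treatment t in [0, 1] to an outcome


def _grid(num_points: int) -> List[float]:
    # Evenly spaced treatment values covering [0, 1], endpoints included (assumed range).
    return [k / (num_points - 1) for k in range(num_points)]


def mise(true_curves: Sequence[Curve], est_curves: Sequence[Curve], num_points: int = 101) -> float:
    """Mean over units of the integrated squared error on [0, 1] (trapezoidal rule). Lower is better."""
    ts = _grid(num_points)
    h = ts[1] - ts[0]
    total = 0.0
    for f, g in zip(true_curves, est_curves):
        sq = [(f(t) - g(t)) ** 2 for t in ts]
        total += h * (sum(sq) - 0.5 * (sq[0] + sq[-1]))
    return total / len(true_curves)


def dpe(true_curves: Sequence[Curve], est_curves: Sequence[Curve], num_points: int = 101) -> float:
    """Mean squared loss in true outcome from dosing at the estimated optimum. Lower is better."""
    ts = _grid(num_points)
    total = 0.0
    for f, g in zip(true_curves, est_curves):
        t_star = max(ts, key=f)  # best dose under the true curve (grid search)
        t_hat = max(ts, key=g)   # dose the estimated curve would recommend
        total += (f(t_star) - f(t_hat)) ** 2
    return total / len(true_curves)
```

A curve with the correct shape but a constant offset incurs MISE yet zero DPE, because it still recommends the right dose; this is one reason the two metrics can rank models differently.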