
Causal Foundation Models with Continuous Treatments

About

Causal inference, the estimation of causal effects from observational data, is a fundamental tool in many disciplines. Of particular importance across a variety of domains is the continuous treatment setting, where the intervention variable takes values in a continuous range. This setting is far less explored and represents a substantial shift from the binary treatment setting, because models must represent effects across a continuum of treatment values. In this paper, we present the first causal foundation model for the continuous treatment setting. Our model meta-learns the ability to predict causal effects across a wide variety of unseen tasks without additional training or fine-tuning. First, we design a novel prior over data-generating processes with continuous treatment variables in order to generate a rich causal training corpus. We then train a transformer to reconstruct individual treatment-response curves given only observational data, leveraging in-context learning to amortize expensive Bayesian posterior inference. Our model achieves state-of-the-art performance on individual treatment-response curve reconstruction tasks, compared with causal models trained specifically for those tasks.
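To make the training setup concrete, here is a minimal sketch of what "sampling a task from a prior over data-generating processes" can look like. The prior below is a hypothetical toy prior (random linear confounding, a sine-shaped dose effect, Gaussian noise), and the functions `sample_dgp` and `sample_task` are illustrative names, not the paper's prior or code. Each task pairs observational triples (x, t, y), where only one factual outcome per unit is seen, with the hidden treatment-response curves that a transformer would be trained to reconstruct in context.

```python
import math
import random


def sample_dgp(rng, n_features=3):
    # Hypothetical toy prior: every call draws fresh random weights, so each
    # call returns a different data-generating process (DGP).
    w_t = [rng.gauss(0.0, 1.0) for _ in range(n_features)]  # treatment-assignment weights
    w_y = [rng.gauss(0.0, 1.0) for _ in range(n_features)]  # covariate -> outcome weights
    strength = rng.uniform(0.5, 3.0)                        # size of the treatment effect
    noise_sd = rng.uniform(0.05, 0.3)                       # outcome noise level

    def assign_treatment(x):
        # The dose depends on the covariates, which induces confounding.
        score = sum(w * xi for w, xi in zip(w_t, x))
        mean_dose = 1.0 / (1.0 + math.exp(-score))
        return min(1.0, max(0.0, mean_dose + rng.gauss(0.0, 0.1)))

    def mean_outcome(x, t):
        # Noise-free potential outcome mu(x, t); as a function of t this is
        # the individual treatment-response curve for a unit with covariates x.
        baseline = sum(w * xi for w, xi in zip(w_y, x))
        return baseline + strength * math.sin(math.pi * t) * (1.0 + 0.5 * x[0])

    return assign_treatment, mean_outcome, noise_sd


def sample_task(rng, n_units=50, n_features=3, grid_size=21):
    """Draw one synthetic task: observational data plus hidden ground-truth curves."""
    assign_treatment, mean_outcome, noise_sd = sample_dgp(rng, n_features)
    grid = [i / (grid_size - 1) for i in range(grid_size)]  # treatment values in [0, 1]
    observed, curves = [], []
    for _ in range(n_units):
        x = [rng.gauss(0.0, 1.0) for _ in range(n_features)]
        t = assign_treatment(x)
        y = mean_outcome(x, t) + rng.gauss(0.0, noise_sd)  # one factual outcome per unit
        observed.append((x, t, y))
        curves.append([mean_outcome(x, s) for s in grid])  # never observed; training target
    return observed, grid, curves


rng = random.Random(0)
observed, grid, curves = sample_task(rng, n_units=5)
```

Repeating `sample_task` many times yields a training corpus in which no single DGP is reused, which is what lets a model trained on it generalize to unseen tasks without fine-tuning; a realistic prior would be far richer than this toy one.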

Christopher Stith, Medha Barath, Vahid Balazadeh, Jesse C. Cresswell, Rahul G. Krishnan • 2026

Related benchmarks

Task | Dataset | Metric | Result | Rank
Dosage Policy Estimation (DPE) | News (test) | Mean DPE | 3.71 | 12
Dosage Policy Estimation (DPE) | Aggregate (Debt, Warfarin, TCGA, News, NewsHet) | Average rank | 6 | 12
Dosage Policy Estimation (DPE) | Warfarin (test) | Mean DPE | 2.91 | 12
Dosage Policy Estimation (DPE) | NewsHet (test) | Mean DPE | 1.64 | 12
Dosage Policy Estimation (DPE) | Debt (test) | Mean DPE | 0.29 | 12
Dosage Policy Estimation (DPE) | TCGA (test) | Mean DPE | 38.9 | 11
Treatment-response curve estimation | Debt (test) | Mean MISE | 2.26 | 9
Treatment-response curve estimation | Warfarin (test) | Mean MISE | 40.4 | 9
Treatment-response curve estimation | TCGA (test) | Average rank | 2.8 | 9
Treatment-response curve estimation | MVICU (test) | Mean MISE | 1.45 | 9
Showing 10 of 12 rows
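The table reports mean MISE and mean DPE, which this page does not define. The sketch below assumes the definitions commonly used in dose-response benchmarks: MISE integrates the squared gap between the true and estimated treatment-response curve over the treatment range (approximated here with the trapezoid rule on a grid) and averages over units; DPE picks the dosage that maximizes the estimated curve and measures the squared loss in true outcome relative to the truly optimal dosage, assuming higher outcomes are better. Some papers report square roots or normalize differently, so the helper names `trapezoid`, `mise`, and `dpe` and the exact normalization are assumptions, not the benchmark's code.

```python
def trapezoid(values, grid):
    """Trapezoid-rule integral of sampled values over a 1-D grid."""
    return sum(
        0.5 * (values[i] + values[i + 1]) * (grid[i + 1] - grid[i])
        for i in range(len(grid) - 1)
    )


def mise(true_curves, est_curves, grid):
    """Mean integrated squared error over units (curves sampled on `grid`)."""
    errors = [
        trapezoid([(a - b) ** 2 for a, b in zip(tc, ec)], grid)
        for tc, ec in zip(true_curves, est_curves)
    ]
    return sum(errors) / len(errors)


def dpe(true_curves, est_curves):
    """Dosage policy error: squared outcome loss from dosing at the estimated optimum."""
    errors = []
    for tc, ec in zip(true_curves, est_curves):
        best_true = max(range(len(tc)), key=tc.__getitem__)  # index of the true optimal dosage
        best_est = max(range(len(ec)), key=ec.__getitem__)   # dosage the model would choose
        errors.append((tc[best_true] - tc[best_est]) ** 2)   # judged on the TRUE curve
    return sum(errors) / len(errors)


grid = [0.0, 0.5, 1.0]
true_curves = [[0.0, 1.0, 0.0]]
est_curves = [[0.0, 0.5, 1.0]]
print(mise(true_curves, est_curves, grid))  # 0.375
print(dpe(true_curves, est_curves))         # 1.0
```

The two metrics capture different failures: an estimate can have a small MISE yet a large DPE if it misplaces the peak of the curve, which matters when the curve is used to choose a dose.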
