
DiffPO: A causal diffusion model for learning distributions of potential outcomes

About

Predicting potential outcomes of interventions from observational data is crucial for decision-making in medicine, but the task is challenging due to the fundamental problem of causal inference. Existing methods are largely limited to point estimates of potential outcomes with no uncertainty quantification; thus, the full information about the distributions of potential outcomes is typically ignored. In this paper, we propose a novel causal diffusion model called DiffPO, which is carefully designed for reliable inference in medicine by learning the distribution of potential outcomes. In DiffPO, we leverage a tailored conditional denoising diffusion model to learn complex distributions, and we address selection bias through a novel orthogonal diffusion loss. Another strength of DiffPO is its flexibility: it can also be used to estimate other causal quantities such as the conditional average treatment effect (CATE). Across a wide range of experiments, we show that our method achieves state-of-the-art performance.
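The abstract does not spell out the orthogonal diffusion loss, so the following is only a rough illustration of the general shape of such an objective. It is a minimal sketch, assuming a standard DDPM noise-prediction loss for outcomes conditioned on covariates x and treatment a, reweighted by the inverse propensity 1/π(a|x) as a swapped-in stand-in for selection-bias correction; it is not DiffPO's actual orthogonal loss. The linear noise schedule, `forward_noise`, `weighted_denoising_loss`, `dummy_denoiser`, and the toy data are all hypothetical.

```python
import numpy as np

# Linear beta schedule as in standard DDPM (assumption; DiffPO's schedule may differ).
T = 100
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)


def forward_noise(y0, t, rng):
    """Draw y_t ~ q(y_t | y_0) = N(sqrt(abar_t) * y_0, (1 - abar_t) * I)."""
    eps = rng.standard_normal(y0.shape)
    ab = alpha_bar[t][:, None]  # one noise level per sample, broadcast over outcome dims
    y_t = np.sqrt(ab) * y0 + np.sqrt(1.0 - ab) * eps
    return y_t, eps


def weighted_denoising_loss(eps_pred, eps, a, propensity, clip=1e-3):
    """Squared noise-prediction error, reweighted by 1 / pi(a | x).

    `propensity` holds an estimated P(A = 1 | x) per sample (hypothetical input).
    This inverse-propensity weighting is a stand-in, not the paper's orthogonal loss.
    """
    pi_a = np.where(a == 1, propensity, 1.0 - propensity)
    weights = 1.0 / np.clip(pi_a, clip, None)  # clip to avoid exploding weights
    per_sample = np.sum((eps_pred - eps) ** 2, axis=1)
    return float(np.mean(weights * per_sample))


def dummy_denoiser(y_t, t, x, a):
    """Hypothetical stand-in for eps_theta(y_t, t, x, a); always predicts zero noise."""
    return np.zeros_like(y_t)


rng = np.random.default_rng(0)
n = 256
x = rng.standard_normal((n, 3))                  # toy covariates
propensity = 1.0 / (1.0 + np.exp(-x[:, 0]))      # toy P(A = 1 | x): treatment depends on x
a = (rng.uniform(size=n) < propensity).astype(int)
y0 = x[:, :1] + a[:, None] + 0.1 * rng.standard_normal((n, 1))  # observed factual outcomes
t = rng.integers(0, T, size=n)                   # random diffusion step per sample

y_t, eps = forward_noise(y0, t, rng)
loss = weighted_denoising_loss(dummy_denoiser(y_t, t, x, a), eps, a, propensity)
```

In a real model the denoiser would be a neural network trained to minimize this loss; the weighting upweights samples whose observed treatment was unlikely given x, which is one common way to counter the covariate shift between treated and control groups.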

Yuchen Ma, Valentyn Melnychuk, Jonas Schweisthal, Stefan Feuerriegel • 2024

Related benchmarks

Task | Dataset | Metric | Result | Rank
Counterfactual distribution estimation | Colored MNIST (a=0) | Mean out-of-sample W2 | 25.48 | 16
Conditional distribution estimation | IHDP100 (a=1, out-of-sample) | W2 | 0.033 | 16
Conditional distribution estimation | IHDP100 (a=0, out-of-sample) | W2 | 0.05 | 16
Counterfactual distribution estimation | Colored MNIST (a=1) | Mean out-of-sample W2 | 22.53 | 16
Counterfactual distribution estimation | Colored MNIST (a=2) | Mean out-of-sample W2 | 24.24 | 16
Counterfactual distribution estimation | Colored MNIST (a=3) | Mean out-of-sample W2 | 23.13 | 16
Counterfactual distribution estimation | Colored MNIST (a=4) | Mean out-of-sample W2 | 23.26 | 16
CATE estimation | ACIC 2016, 77 datasets (in-sample) | % best | 26.73 | 9
CATE estimation | ACIC 2016, 77 datasets (out-of-sample) | % best | 25.82 | 9
CATE estimation | ACIC 2018 (in-sample) | % best | 31.34 | 9

Showing 10 of 18 rows.
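The W2 results above presumably report a Wasserstein-2 distance between estimated and ground-truth outcome distributions, and CATE can be read off sampled potential outcomes as a difference of means. The sketch below shows both for one-dimensional outcomes, assuming equal-size samples; the function names and toy distributions are hypothetical, and the leaderboard's exact evaluation protocol (for example, for image-valued Colored MNIST outcomes) is not given here.

```python
import numpy as np


def wasserstein2_1d(u, v):
    """Empirical W2 between two equal-size 1-D samples.

    In one dimension the optimal coupling matches sorted samples, so
    W2 = sqrt(mean((sort(u) - sort(v))^2)).
    """
    u = np.sort(np.asarray(u, dtype=float))
    v = np.sort(np.asarray(v, dtype=float))
    if u.shape != v.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    return float(np.sqrt(np.mean((u - v) ** 2)))


def cate_from_samples(y1_samples, y0_samples):
    """CATE(x) estimate: mean of sampled Y(1) minus mean of sampled Y(0) at one x."""
    return float(np.mean(y1_samples) - np.mean(y0_samples))


rng = np.random.default_rng(1)
y1_true = rng.normal(2.0, 1.0, size=5000)  # toy ground-truth draws of Y(1) | x
y1_hat = rng.normal(2.2, 1.0, size=5000)   # hypothetical model samples of Y(1) | x
y0_hat = rng.normal(0.5, 1.0, size=5000)   # hypothetical model samples of Y(0) | x

w2 = wasserstein2_1d(y1_hat, y1_true)      # near 0.2, the population W2 for a 0.2 mean shift
cate = cate_from_samples(y1_hat, y0_hat)   # near 1.7, the difference of the two means
```

Because a distributional model yields samples rather than a single number, the same draws support both a distribution-level metric such as W2 and point summaries such as CATE.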
