
Causal Transformer for Estimating Counterfactual Outcomes

About

Estimating counterfactual outcomes over time from observational data is relevant for many applications (e.g., personalized medicine). Yet, state-of-the-art methods build upon simple long short-term memory (LSTM) networks, which makes inference over complex, long-range dependencies challenging. In this paper, we develop a novel Causal Transformer for estimating counterfactual outcomes over time. Our model is specifically designed to capture complex, long-range dependencies among time-varying confounders. For this, we combine three transformer subnetworks with separate inputs for time-varying covariates, previous treatments, and previous outcomes into a joint network with in-between cross-attentions. We further develop a custom, end-to-end training procedure for our Causal Transformer. Specifically, we propose a novel counterfactual domain confusion loss to address confounding bias: it aims to learn adversarially balanced representations that are predictive of the next outcome but non-predictive of the current treatment assignment. We evaluate our Causal Transformer on synthetic and real-world datasets, where it achieves superior performance over current baselines. To the best of our knowledge, this is the first work proposing a transformer-based architecture for estimating counterfactual outcomes from longitudinal data.
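The counterfactual domain confusion loss is described above only in words. The sketch below shows one way such an adversarial objective can be written, under an assumption this page does not confirm: that the confusion term is the cross-entropy between the treatment classifier's predicted distribution and a uniform distribution over treatments, so it is smallest when the representation reveals nothing about which treatment was assigned. The function names, the weight `alpha`, and the plain-Python setting (precomputed predictions and logits, no transformer, no gradients) are hypothetical simplifications, not the authors' implementation.

```python
import math
from typing import List, Sequence, Tuple


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over treatment logits."""
    m = max(logits)
    # The max term contributes exp(0) = 1, so the sum is >= 1 and log is safe.
    log_total = math.log(sum(math.exp(x - m) for x in logits))
    return [x - m - log_total for x in logits]


def treatment_ce_loss(logits: Sequence[float], treatment: int) -> float:
    """Cross-entropy of the treatment classifier against the observed treatment."""
    return -log_softmax(logits)[treatment]


def domain_confusion_loss(logits: Sequence[float]) -> float:
    """Cross-entropy against a uniform target over all K treatments.

    Assumed form of the confusion term: it reaches its minimum (log K)
    exactly when the classifier's prediction is uniform.
    """
    log_probs = log_softmax(logits)
    return -sum(log_probs) / len(log_probs)


def adversarial_objectives(
    outcome_pred: Sequence[float],
    outcome_true: Sequence[float],
    treat_logits: Sequence[Sequence[float]],
    treatments: Sequence[int],
    alpha: float,
) -> Tuple[float, float]:
    """Return (representation objective, classifier objective) for one batch.

    Hypothetical helper; `alpha` weights the confusion term.
    """
    n = len(outcome_true)
    # Outcome head: mean squared error on the next outcome.
    mse = sum((p - y) ** 2 for p, y in zip(outcome_pred, outcome_true)) / n
    # Representation is rewarded for confusing the treatment classifier.
    conf = sum(domain_confusion_loss(lg) for lg in treat_logits) / n
    # Treatment classifier tries to recover the observed treatment.
    ce = sum(treatment_ce_loss(lg, a) for lg, a in zip(treat_logits, treatments)) / n
    return mse + alpha * conf, ce


# Usage: uniform logits mean the representation carries no treatment information.
rep_obj, clf_obj = adversarial_objectives(
    outcome_pred=[1.0, 2.0],
    outcome_true=[1.5, 2.0],
    treat_logits=[[0.0, 0.0], [0.0, 0.0]],
    treatments=[0, 1],
    alpha=1.0,
)
print(rep_obj, clf_obj)
```

In a training loop built on this sketch, one would alternate two updates: the representation and outcome head minimize the first returned value, while the treatment classifier minimizes the second. That alternation is what makes the balancing adversarial; how the paper schedules and weights these updates is not stated on this page.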

Valentyn Melnychuk, Dennis Frauen, Stefan Feuerriegel • 2022

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Factual outcome prediction | MIMIC-III extract | RMSE | 9.05 | 105 |
| Counterfactual Outcome Estimation | Tumor Growth tau=2 synthetic (test) | RMSE | 3.44 | 77 |
| Counterfactual outcome prediction | MIMIC III semi-synthetic (800/200/200) | RMSE | 0.2 | 57 |
| Counterfactual outcome prediction | MIMIC-III semi-synthetic (N=1000) (test) | RMSE | 0.33 | 35 |
| Counterfactual outcome prediction | MIMIC-III semi-synthetic (N=2000) (test) | RMSE | 0.31 | 35 |
| Counterfactual outcome prediction | MIMIC-III semi-synthetic (N=3000) (test) | RMSE | 0.32 | 35 |
| Counterfactual Response Estimation | Synthetic Cancer Simulation Dataset, training sequence length 60 (test) | NRMSE | 0.84 | 20 |
| System Dynamics Prediction | COVID-19 (test) | TMSE | 0.309 | 9 |
| System Dynamics Prediction | Lung Cancer with Chemo. (test) | TMSE | 0.348 | 9 |
| System Dynamics Prediction | Lung Cancer (with Chemo. & Radio.) (test) | TMSE | 0.216 | 9 |

(10 of 17 benchmark results shown.)
