Causal Transformer for Estimating Counterfactual Outcomes
About
Estimating counterfactual outcomes over time from observational data is relevant for many applications (e.g., personalized medicine). Yet state-of-the-art methods build upon simple long short-term memory (LSTM) networks, which makes inference over complex, long-range dependencies challenging. In this paper, we develop a novel Causal Transformer for estimating counterfactual outcomes over time. Our model is specifically designed to capture complex, long-range dependencies among time-varying confounders. To this end, we combine three transformer subnetworks, with separate inputs for time-varying covariates, previous treatments, and previous outcomes, into a joint network with in-between cross-attentions. We further develop a custom, end-to-end training procedure for our Causal Transformer. Specifically, we propose a novel counterfactual domain confusion loss to address confounding bias: it aims to learn adversarially balanced representations that are predictive of the next outcome but non-predictive of the current treatment assignment. We evaluate our Causal Transformer on synthetic and real-world datasets, where it achieves superior performance over current baselines. To the best of our knowledge, this is the first work to propose a transformer-based architecture for estimating counterfactual outcomes from longitudinal data.
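The balancing idea in the abstract can be illustrated with a generic domain confusion objective. The following is a minimal sketch, not the authors' implementation: it assumes K mutually exclusive treatments and a treatment classifier head that emits one logit per treatment, and all function names are hypothetical. The classifier head would be trained with ordinary cross-entropy, while the representation would be trained to push the classifier's prediction toward the uniform distribution, so that the representation carries no information about the current treatment.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over a list of treatment logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_norm for x in logits]


def treatment_cross_entropy(logits: List[float], treatment: int) -> float:
    """Loss for the (hypothetical) treatment classifier head:
    standard cross-entropy against the observed treatment index."""
    return -log_softmax(logits)[treatment]


def domain_confusion_loss(logits: List[float]) -> float:
    """Loss for the representation: cross-entropy between the uniform
    distribution over treatments and the classifier's prediction.
    It is minimized when every treatment is predicted as equally likely."""
    log_probs = log_softmax(logits)
    return -sum(log_probs) / len(log_probs)


# Assumed toy logits for K = 4 treatments.
uninformative = [0.0, 0.0, 0.0, 0.0]  # classifier cannot tell treatments apart
informative = [4.0, 0.0, 0.0, 0.0]    # representation leaks the treatment

loss_uninformative = domain_confusion_loss(uninformative)
loss_informative = domain_confusion_loss(informative)
```

In this sketch the confusion loss reaches its minimum, log K, for the uninformative logits and grows as the classifier becomes more confident, which is the sense in which a balanced representation is non-predictive of the treatment assignment. In an actual training loop the two losses would be applied to different parameter groups (classifier head versus representation network), alongside the outcome prediction loss.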
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Factual outcome prediction | MIMIC-III extract | RMSE 9.05 | 105 |
| Counterfactual Outcome Estimation | Tumor Growth tau=2 synthetic (test) | RMSE 3.44 | 77 |
| Counterfactual outcome prediction | MIMIC III semi-synthetic (800/200/200) | RMSE 0.2 | 57 |
| Counterfactual outcome prediction | MIMIC-III semi-synthetic (N=1000) (test) | RMSE 0.33 | 35 |
| Counterfactual outcome prediction | MIMIC-III semi-synthetic (N=2000) (test) | RMSE 0.31 | 35 |
| Counterfactual outcome prediction | MIMIC-III semi-synthetic (N=3000) (test) | RMSE 0.32 | 35 |
| Counterfactual Response Estimation | Synthetic Cancer Simulation Dataset training sequence length 60 (test) | NRMSE 0.84 | 20 |
| System Dynamics Prediction | COVID-19 (test) | TMSE 0.309 | 9 |
| System Dynamics Prediction | Lung Cancer with Chemo. (test) | TMSE 0.348 | 9 |
| System Dynamics Prediction | Lung Cancer (with Chemo. & Radio.) (test) | TMSE 0.216 | 9 |