
A Fixed-Point Approach for Causal Generative Modeling

About

We propose a novel formalism for describing Structural Causal Models (SCMs) as fixed-point problems on causally ordered variables, eliminating the need for Directed Acyclic Graphs (DAGs), and establish the weakest known conditions for their unique recovery given the topological ordering (TO). Based on this, we design a two-stage causal generative model that first infers a valid TO from observations in a zero-shot manner, and then learns the generative SCM on the ordered variables. To infer TOs, we propose to amortize their learning on synthetically generated datasets by sequentially predicting the leaves of the graphs seen during training. To learn SCMs, we design a transformer-based architecture that exploits a new attention mechanism enabling the modeling of causal structures, and show that this parameterization is consistent with our formalism. Finally, we conduct an extensive evaluation of each method individually, and show that when combined, our model outperforms various baselines on generated out-of-distribution problems. The code is available on GitHub: https://github.com/microsoft/causica/tree/main/research_experiments/fip
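To make the fixed-point formulation concrete, here is a minimal Python/NumPy sketch under simplifying assumptions of our own: an additive-noise SCM x = F(x, n) with F_i(x, n) = tanh(Σ_j W_ij x_j) + n_i, where the variables are already sorted in a valid topological order, so W is strictly lower triangular. This is not the authors' transformer parameterization or their code; the helper names (`random_lower_triangular_weights`, `scm_map`, `fixed_point_sample`) and the tanh mechanism are hypothetical. Because each variable depends only on its predecessors, plain fixed-point iteration from zero fixes one more variable per step and reaches the unique solution after d iterations.

```python
import numpy as np


def random_lower_triangular_weights(d, rng, edge_prob=0.5):
    """Hypothetical helper: random weights for d variables already in topological order.

    Keeping only the strictly lower triangle means variable i can depend only on
    variables 0..i-1, so the given ordering is a valid topological order.
    """
    mask = rng.random((d, d)) < edge_prob
    return np.tril(rng.normal(size=(d, d)) * mask, k=-1)


def scm_map(x, noise, W, nonlinearity=np.tanh):
    """Hypothetical F(x, n): each variable is a function of its predecessors plus noise.

    x and noise have shape (n_samples, d); x @ W.T computes sum_j W[i, j] * x_j per row.
    """
    return nonlinearity(x @ W.T) + noise


def fixed_point_sample(noise, W, nonlinearity=np.tanh):
    """Solve x = F(x, n) by plain fixed-point iteration starting from zeros.

    With W strictly lower triangular, after k iterations the first k variables
    are exact, so d iterations reach the unique fixed point.
    """
    d = W.shape[0]
    x = np.zeros_like(noise)
    for _ in range(d):
        x = scm_map(x, noise, W, nonlinearity)
    return x


rng = np.random.default_rng(0)
d, n = 5, 1000
W = random_lower_triangular_weights(d, rng)
noise = rng.normal(size=(n, d))

# Nonlinear SCM: the result is a fixed point of F.
x = fixed_point_sample(noise, W)
print(np.allclose(scm_map(x, noise, W), x))  # expected: True

# Linear SCM (identity nonlinearity): matches the closed form x = (I - W)^{-1} n.
x_lin = fixed_point_sample(noise, W, nonlinearity=lambda z: z)
closed_form = np.linalg.solve(np.eye(d) - W, noise.T).T
print(np.allclose(x_lin, closed_form))  # expected: True
```

With the identity nonlinearity, the iteration reduces to the truncated Neumann series Σ_{m<d} W^m n, which equals (I − W)^{-1} n because W^d = 0; the second check compares against that closed form. Inferring the order itself (the first stage described above) is outside the scope of this sketch.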

Meyer Scetbon, Joel Jennings, Agrin Hilmkil, Cheng Zhang, Chao Ma • 2024

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Noise Prediction | AVICI (out-of-distribution) | LIN (RMSE) | 0.04 | 32 |
| Sample Generation | AVICI RFF (In-distribution) | RMSE | 0.13 | 16 |
| Sample Generation | AVICI RFF (Out-of-distribution) | RMSE | 0.11 | 16 |
| Counterfactual Generation | AVICI (test) | LIN RMSE (IN) | 0.01 | 16 |
| Interventional Generation | AVICI In-distribution | LIN RMSE | 0.08 | 16 |
| Noise Prediction | AVICI In-distribution | LIN RMSE | 0.04 | 16 |
| Sample Generation | AVICI LIN (In-distribution) | RMSE | 0.07 | 16 |
| Sample Generation | AVICI LIN (Out-of-distribution) | RMSE | 0.08 | 16 |
| Generating observational data | ecoli | MMD (Generated vs Query) | 0.017 | 8 |
| Generating observational data | Flow Cytometry Sachs (query) | MMD (Generated Query vs True Query) | 0.015 | 4 |
