
Diffusion Based Causal Representation Learning

About

Causal reasoning can be considered a cornerstone of intelligent systems. Having access to an underlying causal graph comes with the promise of cause-effect estimation and the identification of efficient and safe interventions. However, learning causal representations remains a major challenge due to the complexity of many real-world systems. Previous work on causal representation learning has mostly focused on Variational Auto-Encoders (VAEs). These methods provide representations only from a point estimate, and they are ill-suited to high-dimensional settings. To overcome these problems, we propose a new Diffusion-based Causal Representation Learning (DCRL) algorithm, which uses diffusion-based representations for causal discovery. DCRL offers access to infinite-dimensional latent codes, which encode different levels of information. As a first proof of principle, we investigate the use of DCRL for causal representation learning. We further demonstrate experimentally that this approach performs comparably well in identifying the causal structure and causal variables.

Amir Mohammad Karimi Mamaghan, Andrea Dittadi, Stefan Bauer, Karl Henrik Johansson, Francesco Quinzan • 2023

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| DAG Structure Recovery | non-linear-1, 5000 samples | Structural Hamming Distance (SHD) | 2.9 | 48 |
| Causal Discovery | non-linear-2, d=10, 5000 samples (test) | SHD | 2.2 | 12 |
| Causal Discovery | non-linear-2, d=20, 5000 samples (test) | SHD | 7.1 | 12 |
| Causal Discovery | non-linear-2, d=50, 5000 samples (test) | SHD | 15.1 | 12 |
| Causal Structure Learning | Linear synthetic data, d=10, 5000 samples | SHD | 1.8 | 12 |
| Causal Structure Learning | Linear synthetic data, d=20, 5000 samples | SHD | 3.1 | 12 |
| Causal Discovery | non-linear-2, d=100, 5000 samples (test) | SHD | 59.5 | 12 |
| Causal Structure Learning | Linear synthetic data, d=50, 5000 samples | SHD | 18.7 | 12 |
| Causal Structure Learning | Linear synthetic data, d=100, 5000 samples | SHD | 53.3 | 12 |
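Every result in the benchmark table is a Structural Hamming Distance (SHD), where lower is better: the number of edge insertions, deletions and reversals needed to turn the estimated graph into the true one. As a rough illustration only (not the evaluation code behind these benchmarks), the sketch below computes SHD for two DAGs given as 0/1 adjacency matrices, with `adj[i][j] == 1` meaning i → j. It assumes a reversed edge counts as one error; some libraries count it as two, and the page does not say which convention these scores use. The function name `structural_hamming_distance` is hypothetical. Since SHD on a single graph is an integer, the fractional scores above must be averages, though the page does not say over how many runs.

```python
from typing import List

Matrix = List[List[int]]


def structural_hamming_distance(true_adj: Matrix, est_adj: Matrix) -> int:
    """Hypothetical helper: SHD between two DAG adjacency matrices.

    adj[i][j] == 1 means an edge i -> j. Missing, extra and reversed edges
    each cost 1 (assumed convention; some tools count a reversal twice).
    """
    d = len(true_adj)
    if len(est_adj) != d or any(len(row) != d for row in true_adj + est_adj):
        raise ValueError("adjacency matrices must be square and the same size")

    shd = 0
    # Visit each unordered node pair {i, j} once; self-loops are ignored.
    for i in range(d):
        for j in range(i + 1, d):
            true_pair = (true_adj[i][j], true_adj[j][i])
            est_pair = (est_adj[i][j], est_adj[j][i])
            # Any disagreement on the pair (missing, extra or reversed edge) costs 1.
            if true_pair != est_pair:
                shd += 1
    return shd


true_dag = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]  # 0 -> 1 -> 2
est_dag = [[0, 0, 1], [1, 0, 1], [0, 0, 0]]   # 1 -> 0 (reversed), 1 -> 2, 0 -> 2 (extra)
print(structural_hamming_distance(true_dag, est_dag))  # 2
```

Counting per node pair rather than per matrix entry is what makes a reversed edge cost one error instead of two.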
