DAG-WGAN: Causal Structure Learning With Wasserstein Generative Adversarial Networks

About

The combinatorial search space presents a significant challenge to learning causality from data. Recently, the problem has been formulated as a continuous optimization problem with an acyclicity constraint. This makes it possible to use deep generative models that better capture the data distribution and to discover Directed Acyclic Graphs (DAGs) that faithfully represent it. However, no study has yet investigated the use of the Wasserstein distance for causal structure learning via generative models. This paper proposes a new model, DAG-WGAN, which combines a Wasserstein-based adversarial loss and an auto-encoder architecture with an acyclicity constraint. DAG-WGAN simultaneously learns causal structures and improves its data generation capability by leveraging the strengths of the Wasserstein distance metric. Compared with other models, it scales well and handles both continuous and discrete data. Experiments evaluating DAG-WGAN against state-of-the-art methods demonstrate its good performance.

Hristo Petkov, Colin Hanley, Feng Dong • 2022
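The abstract combines two ingredients that can be stated concretely: a smooth function of the weighted adjacency matrix that is zero only for acyclic graphs, and the Wasserstein distance between real and generated data. The Python sketch below illustrates both under stated assumptions. For acyclicity it assumes the trace-exponential form h(W) = tr(exp(W ∘ W)) − d popularized by NOTEARS (Zheng et al., 2018); the exact constraint used in DAG-WGAN is not given here and may differ. For the distance it computes the exact empirical Wasserstein-1 (W1) distance between two equal-size 1-D samples. The helper names `acyclicity` and `empirical_w1` are hypothetical, and neither function is the authors' implementation.

```python
import numpy as np
from scipy.linalg import expm
from scipy.stats import wasserstein_distance


def acyclicity(W: np.ndarray) -> float:
    """NOTEARS-style acyclicity measure h(W) = tr(exp(W * W)) - d (an assumption,
    not necessarily DAG-WGAN's exact form).

    W[i, j] is the weight of edge i -> j and W * W is the elementwise square.
    h(W) is zero exactly when the weighted graph has no directed cycle and
    positive otherwise, so "h(W) = 0" can act as a smooth equality constraint.
    """
    d = W.shape[0]
    return float(np.trace(expm(W * W)) - d)


def empirical_w1(x: np.ndarray, y: np.ndarray) -> float:
    """Wasserstein-1 distance between two equal-size 1-D samples (hypothetical helper).

    With the same number of equally weighted points, the optimal transport
    plan matches sorted values, so W1 is the mean absolute difference
    between the sorted samples.
    """
    if x.shape != y.shape or x.ndim != 1:
        raise ValueError("expected two 1-D samples of equal size")
    return float(np.mean(np.abs(np.sort(x) - np.sort(y))))


# Chain 0 -> 1 -> 2: acyclic, so h is (numerically) zero.
W_dag = np.array([[0.0, 1.5, 0.0],
                  [0.0, 0.0, -0.8],
                  [0.0, 0.0, 0.0]])
# Adding the edge 2 -> 0 closes a cycle, so h becomes positive.
W_cyc = W_dag.copy()
W_cyc[2, 0] = 0.5

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=5000)
fake_near = rng.normal(0.1, 1.0, size=5000)  # population W1 to N(0, 1) is 0.1
fake_far = rng.normal(2.0, 1.0, size=5000)   # population W1 to N(0, 1) is 2.0

h_dag, h_cyc = acyclicity(W_dag), acyclicity(W_cyc)
w_near, w_far = empirical_w1(real, fake_near), empirical_w1(real, fake_far)

print(f"h(DAG) = {h_dag:.2e}, h(cyclic) = {h_cyc:.4f}")
print(f"W1 near = {w_near:.3f}, W1 far = {w_far:.3f}")
# Cross-check against SciPy's 1-D implementation.
print(np.isclose(w_far, wasserstein_distance(real, fake_far)))
```

A WGAN does not evaluate W1 this way: its critic network estimates it through the Kantorovich–Rubinstein dual, maximizing E[f(x_real)] − E[f(x_fake)] over 1-Lipschitz functions f, and the generator is trained to reduce that estimate. In continuous structure learning, the condition h(W) = 0 is typically enforced with a penalty or augmented Lagrangian term added to the training loss.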

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| DAG Structure Recovery | non-linear-1, 5000 samples | SHD | 6.4 | 48 |
| Bayesian network structure discovery | Hailfinder | SHD | 73 | 39 |
| Causal Discovery | Alarm | SHD | 36 | 14 |
| Causal Discovery | non-linear-2, d=50, 5000 samples (test) | SHD | 22.6 | 12 |
| Causal Discovery | non-linear-2, d=100, 5000 samples (test) | SHD | 64.2 | 12 |
| Causal Structure Learning | Linear Synthetic Data, d=50, 5000 samples | SHD | 19.6 | 12 |
| Causal Structure Learning | Linear Synthetic Data, d=100, 5000 samples | SHD | 58.6 | 12 |
| Causal Discovery | non-linear-2, d=10, 5000 samples (test) | SHD | 6.6 | 12 |
| Causal Discovery | non-linear-2, d=20, 5000 samples (test) | SHD | 15.2 | 12 |
| Causal Structure Learning | Linear Synthetic Data, d=10, 5000 samples | SHD | 5.2 | 12 |

SHD: Structural Hamming Distance. Rank: DAG-WGAN's position on the corresponding leaderboard. Ten of the 22 benchmark rows are shown.
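Every result in the benchmark table is a Structural Hamming Distance: the number of edge additions, deletions, and reversals needed to turn the estimated graph into the true one, so smaller is better. The sketch below is a minimal Python implementation assuming both graphs are DAGs given as adjacency matrices, with `B[i, j] != 0` meaning an edge i → j; the function name `shd` is illustrative, not taken from the paper. It counts a reversed edge as one error, which is a common convention; some toolkits count a reversal as two. Because a single comparison yields an integer, fractional entries such as 6.4 must be averages, for example over several sampled graphs; the listing does not say over what.

```python
import numpy as np


def shd(B_true: np.ndarray, B_est: np.ndarray) -> int:
    """Structural Hamming Distance between two DAG adjacency matrices.

    B[i, j] != 0 encodes an edge i -> j. Missing, extra, and reversed edges
    each add 1 to the distance (a reversal is counted once, not twice).
    """
    T = B_true != 0
    E = B_est != 0
    if T.shape != E.shape:
        raise ValueError("adjacency matrices must have the same shape")
    # Skeletons: is there an edge between i and j in either direction?
    T_skel = T | T.T
    E_skel = E | E.T
    # Look only at pairs i < j so each unordered pair is counted once.
    iu = np.triu_indices(T.shape[0], k=1)
    missing = np.sum(T_skel[iu] & ~E_skel[iu])
    extra = np.sum(E_skel[iu] & ~T_skel[iu])
    # Pair present in both graphs but oriented differently. In a DAG each
    # pair has at most one direction, so comparing T[i, j] with E[i, j] suffices.
    reversed_edges = np.sum(T_skel[iu] & E_skel[iu] & (T[iu] != E[iu]))
    return int(missing + extra + reversed_edges)


# True graph: 0 -> 1 -> 2 -> 3.
B_true = np.zeros((4, 4))
B_true[0, 1] = B_true[1, 2] = B_true[2, 3] = 1
# Estimate: 1 -> 0 (reversed), 1 -> 2 (correct), 2 -> 3 missing, 0 -> 3 extra.
B_est = np.zeros((4, 4))
B_est[1, 0] = B_est[1, 2] = B_est[0, 3] = 1

print(shd(B_true, B_true))  # 0
print(shd(B_true, B_est))   # 3: one reversal, one missing edge, one extra edge
```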
