
Beyond DAGs: A Latent Partial Causal Model for Multimodal Learning

About

Directed Acyclic Graphs (DAGs) are a standard tool in causal modeling, but it is questionable whether they can capture the complexity of large-scale multimodal data. In practice, real-world multimodal datasets are often collected from heterogeneous generative processes that do not conform to a single DAG. Instead, they may involve multiple, even opposing, DAG structures with reversed causal directions. To address this gap, we first propose a novel latent partial causal model tailored to multimodal representation learning. The model features two parts of coupled latent variables connected by an undirected edge, which represents the transfer of knowledge across modalities. Under specific statistical assumptions, we establish an identifiability result showing that representations learned by MultiModal Contrastive Learning (MMCL) correspond to the coupled latent variables up to a trivial transformation. This result deepens our understanding of why MMCL works, highlights its potential for representation disentanglement, and expands the utility of pre-trained models such as CLIP. Synthetic experiments confirm the robustness of our findings, even when the assumptions are partially violated. Most importantly, experiments show that a pre-trained CLIP model embodies disentangled representations, enabling few-shot learning and improving domain generalization across diverse real-world datasets. Together, these contributions advance MMCL in both theory and practice.
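In CLIP, the MMCL objective referred to above is commonly implemented as a symmetric InfoNCE loss over a batch of paired image and text embeddings: each image must pick out its own caption among all captions in the batch, and vice versa. The following is a minimal pure-Python sketch of that standard loss, not the authors' code. It assumes cosine similarity on L2-normalized embeddings and a fixed temperature of 0.07 (CLIP initializes its temperature at this value but learns it during training); the function names are hypothetical.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def _normalize(v: Vector) -> List[float]:
    # L2-normalize so that dot products become cosine similarities
    # (assumes v is not the zero vector).
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _log_softmax(row: List[float]) -> List[float]:
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def symmetric_infonce(image_emb: List[Vector], text_emb: List[Vector],
                      temperature: float = 0.07) -> float:
    """Hypothetical CLIP-style loss; pair i of both lists is the positive pair."""
    if len(image_emb) != len(text_emb) or not image_emb:
        raise ValueError("need two non-empty batches of equal size")
    imgs = [_normalize(v) for v in image_emb]
    txts = [_normalize(v) for v in text_emb]
    n = len(imgs)
    # logits[i][j]: similarity of image i and text j, divided by the temperature.
    logits = [[sum(a * b for a, b in zip(imgs[i], txts[j])) / temperature
               for j in range(n)] for i in range(n)]
    # Image-to-text direction: softmax over texts (rows); targets are the diagonal.
    loss_i2t = -sum(_log_softmax(logits[i])[i] for i in range(n)) / n
    # Text-to-image direction: softmax over images (columns).
    cols = [[logits[i][j] for i in range(n)] for j in range(n)]
    loss_t2i = -sum(_log_softmax(cols[j])[j] for j in range(n)) / n
    # Average the two directions, as in CLIP.
    return 0.5 * (loss_i2t + loss_t2i)


if __name__ == "__main__":
    images = [[1.0, 0.0], [0.0, 1.0]]
    aligned = symmetric_infonce(images, [[0.9, 0.1], [0.1, 0.9]])
    shuffled = symmetric_infonce(images, [[0.1, 0.9], [0.9, 0.1]])
    # Correctly paired embeddings give a much lower loss than mismatched ones.
    print(aligned, shuffled)
```

Minimizing this loss pulls each positive image/text pair together and pushes it away from the other pairs in the batch. The identifiability result concerns the representations that this kind of objective recovers.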

Yuhang Liu, Zhen Zhang, Dong Gong, Erdun Gao, Biwei Huang, Mingming Gong, Anton van den Hengel, Kun Zhang, Javen Qinfeng Shi• 2024

Related benchmarks

Task | Dataset | Metric | Result | Rank
--- | --- | --- | --- | ---
Image Classification | ImageNet Domain Generalization (Source: ImageNet; Targets: ImageNetV2, ImageNet-Sketch, ImageNet-A, ImageNet-R) (test) | Accuracy (ImageNetV2) | 54.81 | 84
Domain Generalization | ImageNet variants (V2, S, A, R) (test) | ImageNet-V2 Accuracy | 58.45 | 54
Image Classification | ImageNet Robustness Generalization Suite (Sketch, A, R, V2) | Top-1 Acc (V2) | 30.26 | 43
Image Classification | ImageNet | Accuracy | 46.57 | 27
16-shot Image Classification | ImageNet-1k (source) | Accuracy | 68.12 | 12
Domain Generalization | ImageNet (Sketch, R, A, V2) | Accuracy (V2) | 40.66 | 12
Image Classification | ImageNet (Sketch, R, A, V2) | Accuracy (V2) | 48.18 | 12
