
Diffusion Model with Cross Attention as an Inductive Bias for Disentanglement

About

Disentangled representation learning strives to extract the intrinsic factors within observed data. Factorizing these representations in an unsupervised manner is notably challenging and usually requires tailored loss functions or specific structural designs. In this paper, we introduce a new perspective and framework, demonstrating that diffusion models with cross-attention can serve as a powerful inductive bias that facilitates the learning of disentangled representations. We propose to encode an image into a set of concept tokens and use them as the condition of a latent diffusion model for image reconstruction, where cross-attention over the concept tokens bridges the interaction between the encoder and the diffusion model. Without any additional regularization, this framework achieves superior disentanglement performance on the benchmark datasets, surpassing all previous methods with intricate designs. We have conducted comprehensive ablation studies and visualization analysis, shedding light on the functioning of this model. This is the first work to reveal the potent disentanglement capability of diffusion models with cross-attention, requiring no complex designs. We anticipate that our findings will inspire further investigation of diffusion models for disentangled representation learning, towards more sophisticated data analysis and understanding.
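The conditioning step described above can be illustrated with a small sketch of cross-attention, in which spatial features of the denoiser act as queries and the concept tokens act as keys and values. This is not the authors' implementation: the abstract only states that cross-attention over the concept tokens connects the encoder and the diffusion model. The scaled dot-product form, the projection matrices w_q, w_k and w_v, the helper names, and all sizes below are assumptions chosen for illustration, and random numbers stand in for real features and tokens.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(features, tokens, w_q, w_k, w_v):
    """Attend from spatial features (queries) to concept tokens (keys/values).

    features: N x d_f, one row per spatial position (hypothetical denoiser features)
    tokens:   K x d_t, one row per concept token (hypothetical encoder output)
    Returns the attended output (N x d) and the attention map (N x K).
    """
    q = matmul(features, w_q)
    k = matmul(tokens, w_k)
    v = matmul(tokens, w_v)
    d = len(q[0])
    k_t = [list(col) for col in zip(*k)]  # transpose of k, shape d x K
    scores = matmul(q, k_t)               # N x K similarity scores
    attn = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(attn, v), attn


def random_matrix(rows, cols, rng):
    """Gaussian random matrix used as a stand-in for learned weights or data."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
N, K, D_F, D_T, D = 16, 4, 8, 6, 5  # illustrative sizes, not from the paper
features = random_matrix(N, D_F, rng)
tokens = random_matrix(K, D_T, rng)
w_q = random_matrix(D_F, D, rng)
w_k = random_matrix(D_T, D, rng)
w_v = random_matrix(D_T, D, rng)

out, attn = cross_attention(features, tokens, w_q, w_k, w_v)
```

Each row of attn is a distribution over the K concept tokens for one spatial position, so every position draws its conditioning signal from a weighted mix of tokens. In this sketch the tokens reach the denoiser only through that map, which is one way to read the abstract's description of cross-attention as the bridge between encoder and diffusion model.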

Tao Yang, Cuiling Lan, Yan Lu, Nanning Zheng • 2024

Related benchmarks

Task                                  Dataset                Result                 Rank
Image Generation                      CelebA 64 x 64 (test)  FID 14.8               203
Image Generation                      CelebA (test)          FID 14.8               49
Disentangled Representation Learning  Cars3D                 FactorVAE 0.773        35
Disentangled Representation Learning  Shapes3D               FactorVAE Score 0.999  18
Disentangled Representation Learning  MPI3D                  FactorVAE Score 0.872  18
Disentanglement                       Shapes3D               D 0.969                18
Disentanglement                       MPI3D                  D 0.685                18
Disentanglement                       Shapes3D (test)        FactorVAE 0.989        13
Disentanglement                       Cars3D                 FVAE 0.773             10
Disentangled Representation Learning  CelebA 64x64 (test)    TAD 0.638              10
Showing 10 of 12 rows
