
Fast Training of Diffusion Models with Masked Transformers

About

We propose an efficient approach to training large diffusion models with masked transformers. Although masked transformers have been extensively explored for representation learning, their application to generative learning in the vision domain remains less explored. Our work is the first to exploit masked training to significantly reduce the training cost of diffusion models. Specifically, we randomly mask out a high proportion (e.g., 50%) of the patches in diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder architecture consisting of a transformer encoder that operates only on unmasked patches and a lightweight transformer decoder that operates on the full set of patches. To promote long-range understanding of the full set of patches, we add an auxiliary task of reconstructing masked patches to the denoising score matching objective, which learns the score of the unmasked patches. Experiments on ImageNet-256x256 and ImageNet-512x512 show that our approach achieves generative performance competitive with, and in some cases better than, the state-of-the-art Diffusion Transformer (DiT) model, while using only around 30% of its original training time. Our method thus offers a promising way to efficiently train large transformer-based diffusion models without sacrificing generative performance.
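The masked-training bookkeeping described above can be illustrated with a small, framework-free Python sketch: uniformly sampling which patches to drop at a given mask ratio, selecting the visible subset that an encoder would receive, and combining a denoising loss over unmasked patches with a weighted reconstruction loss over masked patches. This is not the authors' implementation. The function names, the plain-list patch representation, the squared-error form of both loss terms, and the weight `lam` (a hypothetical default of 0.1) are all assumptions for illustration; the actual denoising and reconstruction targets depend on the diffusion formulation and are simply passed in here.

```python
import random
from typing import List, Sequence, Tuple

Patch = List[float]  # one flattened image patch (assumption: plain lists, no tensors)


def random_patch_mask(num_patches: int, mask_ratio: float,
                      rng: random.Random) -> Tuple[List[int], List[int]]:
    """Uniformly choose which patch indices to mask; return (visible, masked)."""
    if num_patches < 1:
        raise ValueError("num_patches must be positive")
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    # Keep at least one visible patch so the encoder always has input.
    num_masked = min(int(round(num_patches * mask_ratio)), num_patches - 1)
    order = list(range(num_patches))
    rng.shuffle(order)  # fresh random permutation for each training example
    masked = sorted(order[:num_masked])
    visible = sorted(order[num_masked:])
    return visible, masked


def gather(patches: Sequence[Patch], indices: Sequence[int]) -> List[Patch]:
    """Select patches by index, e.g. the visible subset the encoder sees."""
    return [patches[i] for i in indices]


def patch_mse(pred: Patch, target: Patch) -> float:
    """Mean squared error between two flattened patches."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def masked_training_loss(decoder_out: Sequence[Patch],
                         denoise_target: Sequence[Patch],
                         recon_target: Sequence[Patch],
                         visible: Sequence[int],
                         masked: Sequence[int],
                         lam: float = 0.1) -> float:
    """Denoising loss on visible patches plus lam * reconstruction loss on masked ones.

    decoder_out holds one prediction per patch, since the decoder runs on all patches.
    lam = 0.1 is a hypothetical default, not a value taken from the paper.
    """
    denoise = sum(patch_mse(decoder_out[i], denoise_target[i])
                  for i in visible) / len(visible)
    recon = (sum(patch_mse(decoder_out[i], recon_target[i]) for i in masked) / len(masked)
             if masked else 0.0)
    return denoise + lam * recon


if __name__ == "__main__":
    rng = random.Random(0)
    visible, masked = random_patch_mask(num_patches=256, mask_ratio=0.5, rng=rng)
    print(len(visible), len(masked))  # 128 128: the encoder processes half the tokens
```

Running the script prints `128 128`. With a 50% mask ratio on an illustrative grid of 256 patches, the encoder processes only half of the tokens, which in this scheme is the main source of the compute savings, while the lightweight decoder still produces an output for every patch so that both loss terms can be computed.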

Hongkai Zheng, Weili Nie, Arash Vahdat, Anima Anandkumar · 2023

Related benchmarks

Task | Dataset | Metric | Result | Rank
Class-conditional Image Generation | ImageNet 256x256 | Inception Score (IS) | 276.6 | 441
Image Generation | ImageNet 256x256 (val) | FID | 2.28 | 307
Class-conditional Image Generation | ImageNet 256x256 (train) | IS | 276.6 | 305
Image Generation | ImageNet 256x256 | FID | 2.28 | 243
Class-conditional Image Generation | ImageNet 256x256 (train val) | FID | 2.28 | 178
Image Reconstruction | ImageNet 256x256 | rFID | 0.61 | 93
Image Generation | ImageNet 256x256 (train) | FID | 2.28 | 91
Class-conditional Image Generation | ImageNet 512x512 | FID | 2.5 | 72
Image Generation | ImageNet 512x512 (test) | FID | 2.5 | 57
Class-conditional Image Generation | ImageNet-1K 256x256 (test) | FID | 5.69 | 50
Showing 10 of 13 rows
