Fast Training of Diffusion Models with Masked Transformers

About

We propose an efficient approach to training large diffusion models with masked transformers. While masked transformers have been extensively explored for representation learning, their application to generative learning in the vision domain remains less explored. Our work is the first to exploit masked training to significantly reduce the training cost of diffusion models. Specifically, we randomly mask out a high proportion (e.g., 50%) of patches in the diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder architecture consisting of a transformer encoder that operates only on the unmasked patches and a lightweight transformer decoder that operates on the full set of patches. To promote long-range understanding across the full set of patches, we add an auxiliary task of reconstructing the masked patches to the denoising score matching objective, which learns the score of the unmasked patches. Experiments on ImageNet-256x256 and ImageNet-512x512 show that our approach achieves competitive or even better generative performance than the state-of-the-art Diffusion Transformer (DiT) model while using only around 30% of its original training time. Our method thus offers a promising way to efficiently train large transformer-based diffusion models without sacrificing generative performance.
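The masking step and the two-part objective described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: patches are plain lists of floats, the helper names (`random_patch_mask`, `gather`, `fill_with_mask_tokens`, `masked_training_loss`) are hypothetical, the decoder input is assumed to be padded with a shared mask token in the style of MAE, the reconstruction targets for masked patches are passed in by the caller because the abstract does not say exactly what they are, and `recon_weight=0.1` is an illustrative placeholder rather than a value from the paper. Only the 50% mask ratio comes from the text.

```python
import random
from typing import List, Sequence, Tuple

Patch = List[float]  # one flattened patch or token vector (assumed representation)


def random_patch_mask(num_patches: int, mask_ratio: float,
                      rng: random.Random) -> Tuple[List[int], List[int]]:
    """Randomly split patch indices into (kept, masked), e.g. mask_ratio=0.5."""
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    num_masked = int(num_patches * mask_ratio)
    order = list(range(num_patches))
    rng.shuffle(order)
    return sorted(order[num_masked:]), sorted(order[:num_masked])


def gather(patches: Sequence[Patch], idx: Sequence[int]) -> List[Patch]:
    """Select only the patches the encoder sees (the unmasked ones)."""
    return [patches[i] for i in idx]


def fill_with_mask_tokens(encoded_kept: Sequence[Patch], kept: Sequence[int],
                          num_patches: int, mask_token: Patch) -> List[Patch]:
    """Rebuild the full token sequence for the lightweight decoder.

    Assumption (MAE-style): masked positions are filled with one shared token.
    """
    full = [list(mask_token) for _ in range(num_patches)]
    for pos, tok in zip(kept, encoded_kept):
        full[pos] = list(tok)
    return full


def mean_sq_error(pred: Sequence[Patch], target: Sequence[Patch],
                  idx: Sequence[int]) -> float:
    """Mean squared error over the patches listed in idx (0.0 if none)."""
    total, count = 0.0, 0
    for i in idx:
        for p, t in zip(pred[i], target[i]):
            total += (p - t) ** 2
            count += 1
    return total / count if count else 0.0


def masked_training_loss(denoise_pred: Sequence[Patch],
                         denoise_target: Sequence[Patch],
                         recon_pred: Sequence[Patch],
                         recon_target: Sequence[Patch],
                         kept: Sequence[int], masked: Sequence[int],
                         recon_weight: float = 0.1) -> float:
    """Denoising term on unmasked patches plus weighted reconstruction on masked ones.

    recon_weight=0.1 is a hypothetical placeholder, not the paper's setting.
    """
    dsm = mean_sq_error(denoise_pred, denoise_target, kept)
    recon = mean_sq_error(recon_pred, recon_target, masked)
    return dsm + recon_weight * recon


if __name__ == "__main__":
    rng = random.Random(0)
    kept, masked = random_patch_mask(num_patches=8, mask_ratio=0.5, rng=rng)
    print(len(kept), len(masked))  # 4 4
```

In this sketch the encoder would only process `gather(patches, kept)`, so with a 50% mask ratio it handles half as many tokens per image; the decoder, which sees the full sequence from `fill_with_mask_tokens`, is the part the abstract describes as lightweight.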

Hongkai Zheng, Weili Nie, Arash Vahdat, Anima Anandkumar (2023)

Related benchmarks

Task | Dataset | Metric | Value | Rank
--- | --- | --- | --- | ---
Class-conditional Image Generation | ImageNet 256x256 | IS | 276.6 | 815
Image Generation | ImageNet 256x256 | IS | 276.6 | 359
Class-conditional Image Generation | ImageNet 256x256 (train) | IS | 276.6 | 345
Image Generation | ImageNet 256x256 (val) | FID | 2.28 | 340
Image Generation | ImageNet 512x512 (val) | FID-50K | 2.5 | 219
Class-conditional Image Generation | ImageNet 256x256 (train val) | FID | 2.28 | 178
Image Generation | ImageNet 256x256 (train) | FID | 2.28 | 164
Image Reconstruction | ImageNet 256x256 | rFID | 0.61 | 150
Class-conditional Image Generation | ImageNet 512x512 | FID | 2.5 | 111
Class-conditional generation | ImageNet 256x256 1k (val) | IS | 276.6 | 102

IS = Inception Score (higher is better); FID = Fréchet Inception Distance (lower is better); FID-50K = FID computed on 50,000 samples; rFID = reconstruction FID. Showing 10 of 20 rows.
