Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models
About
We present Deep Compression Autoencoder (DC-AE), a new family of autoencoder models for accelerating high-resolution diffusion models. Existing autoencoder models achieve impressive results at a moderate spatial compression ratio (e.g., 8x) but fail to maintain satisfactory reconstruction accuracy at high spatial compression ratios (e.g., 64x). We address this challenge with two key techniques: (1) Residual Autoencoding, where our models learn residuals on top of space-to-channel transformed features to ease the optimization of high spatial-compression autoencoders; (2) Decoupled High-Resolution Adaptation, an efficient, decoupled three-phase training strategy that mitigates the generalization penalty of high spatial-compression autoencoders. With these designs, we raise the autoencoder's spatial compression ratio to as high as 128 while maintaining reconstruction quality. Applying DC-AE to latent diffusion models yields significant speedups without a drop in accuracy. For example, on ImageNet 512x512, compared with the widely used SD-VAE-f8 autoencoder, DC-AE provides a 19.1x inference speedup and a 17.9x training speedup for UViT-H on an H100 GPU while achieving a better FID. Our code is available at https://github.com/mit-han-lab/efficientvit.
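The Residual Autoencoding idea can be illustrated with a small NumPy sketch of a downsampling block whose shortcut is a space-to-channel rearrangement. This is not the authors' implementation (the linked repository has that). It assumes channel-first (C, H, W) arrays, a pixel-unshuffle-style channel ordering, and that the shortcut is matched to the block's output width by averaging consecutive channel groups; `downsample_block`, `residual_shortcut`, and `learned_branch` are hypothetical names, with `learned_branch` standing in for the block's trainable layers.

```python
import numpy as np
from typing import Callable


def space_to_channel(x: np.ndarray, r: int) -> np.ndarray:
    """Rearrange a (C, H, W) array into (C*r*r, H//r, W//r) without dropping values.

    Channel ordering follows the pixel-unshuffle convention:
    out[c*r*r + i*r + j, h, w] == x[c, h*r + i, w*r + j].
    """
    c, h, w = x.shape
    if h % r or w % r:
        raise ValueError("H and W must be divisible by r")
    x = x.reshape(c, h // r, r, w // r, r)  # split each spatial axis into (block, offset)
    x = x.transpose(0, 2, 4, 1, 3)          # (C, r, r, H//r, W//r)
    return x.reshape(c * r * r, h // r, w // r)


def residual_shortcut(x: np.ndarray, r: int, out_channels: int) -> np.ndarray:
    """Parameter-free shortcut (assumed form): space-to-channel, then average consecutive channel groups."""
    y = space_to_channel(x, r)
    c = y.shape[0]
    if c % out_channels:
        raise ValueError("C*r*r must be divisible by out_channels")
    return y.reshape(out_channels, c // out_channels, *y.shape[1:]).mean(axis=1)


def downsample_block(
    x: np.ndarray,
    r: int,
    out_channels: int,
    learned_branch: Callable[[np.ndarray], np.ndarray],
) -> np.ndarray:
    """Hypothetical block: the learned branch only has to model a residual on top of the shortcut."""
    return learned_branch(x) + residual_shortcut(x, r, out_channels)


# Usage: a zero "learned branch" makes the output equal to the shortcut alone.
x = np.arange(2 * 4 * 4, dtype=np.float64).reshape(2, 4, 4)
out = downsample_block(x, r=2, out_channels=4, learned_branch=lambda t: np.zeros((4, 2, 2)))
print(out.shape)  # (4, 2, 2)
```

The space-to-channel step only rearranges values, and the channel averaging then reduces them to the target width, so the block always has a parameter-free path from input to output even at large downsampling factors.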
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Class-conditional Image Generation | ImageNet 256x256 | -- | -- | 967 |
| Image Generation | ImageNet 256x256 | IS | 75.73 | 517 |
| Text-to-Image Generation | GenEval | Overall Score | 83 | 218 |
| Image Reconstruction | ImageNet 256x256 | rFID | 0.26 | 202 |
| Image Reconstruction | ImageNet (val) | rFID | 0.22 | 143 |
| Conditional Image Generation | ImageNet 512x512 (val) | gFID | 2.25 | 92 |
| Image Generation | ImageNet 512x512 | IS | 187.7 | 83 |
| Text-to-Image Generation | DPG-Bench | Average Score | 87.65 | 77 |
| Class-conditional Image Generation | ImageNet 512x512 (val test) | FID | 1.72 | 40 |
| Image Reconstruction | ImageNet 256p | PSNR (dB) | 24.82 | 38 |