REPA-E: Unlocking VAE for End-to-End Tuning with Latent Diffusion Transformers
About
In this paper we tackle a fundamental question: "Can we train latent diffusion models together with the variational auto-encoder (VAE) tokenizer in an end-to-end manner?" Traditional deep-learning wisdom holds that end-to-end training is preferable when possible. For latent diffusion transformers, however, we observe that training both the VAE and the diffusion model end-to-end with the standard diffusion loss is ineffective and even degrades final performance. We show that, while the diffusion loss is ineffective, end-to-end training can be unlocked through the representation-alignment (REPA) loss, which allows the VAE and the diffusion model to be tuned jointly during training. Despite its simplicity, the proposed training recipe (REPA-E) performs remarkably well, speeding up diffusion model training by over 17x and 45x relative to the REPA and vanilla training recipes, respectively. Interestingly, end-to-end tuning with REPA-E also improves the VAE itself, leading to better latent space structure and downstream generation performance. In terms of final performance, our approach sets a new state of the art, achieving an FID of 1.12 with classifier-free guidance and 1.69 without it on ImageNet 256 x 256. Code is available at https://end2end-diffusion.github.io.
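The abstract does not spell out the form of the REPA loss. As a rough illustration only, the sketch below assumes the commonly described form of representation alignment: a negative mean cosine similarity between projected diffusion-model hidden states and patch features from a frozen pretrained vision encoder. The function names and the toy feature vectors are hypothetical, and the sketch only computes the loss value; it does not show how gradients would reach the VAE or the diffusion model.

```python
import math

def cosine_similarity(u, v):
    """Cosine similarity between two equal-length feature vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def alignment_loss(projected_hidden, target_features):
    """Negative mean patch-wise cosine similarity (assumed REPA-style form).

    projected_hidden: one vector per patch, standing in for projected
        diffusion-model hidden states (hypothetical input).
    target_features: one vector per patch, standing in for features from a
        frozen pretrained encoder (hypothetical input).
    """
    sims = [cosine_similarity(h, t)
            for h, t in zip(projected_hidden, target_features)]
    return -sum(sims) / len(sims)

# Toy usage: each patch pair points in the same direction, so the
# similarity is maximal and the loss reaches its minimum.
hidden = [[1.0, 0.0], [0.0, 2.0]]
target = [[2.0, 0.0], [0.0, 1.0]]
print(alignment_loss(hidden, target))  # -1.0
```

Under this assumed form, the loss depends only on the direction of each patch feature and not on its scale, so minimizing it pulls the intermediate representation toward the pretrained encoder's features without constraining their magnitude.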
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Class-conditional Image Generation | ImageNet 256x256 | Inception Score (IS) | 302.9 | 441 |
| Image Generation | ImageNet 256x256 (val) | FID | 1.12 | 307 |
| Class-conditional Image Generation | ImageNet 256x256 (train) | IS | 314.9 | 305 |
| Image Generation | ImageNet 256x256 | -- | -- | 243 |
| Class-conditional Image Generation | ImageNet 256x256 (train val) | FID | 1.26 | 178 |
| Class-conditional Image Generation | ImageNet 256x256 (test) | FID | 1.15 | 167 |
| Image Reconstruction | ImageNet 256x256 | rFID | 0.28 | 93 |
| Conditional Image Generation | ImageNet 256px 2012 (val) | FID | 1.12 | 50 |
| Image Generation | ImageNet 256x256 (test val) | -- | -- | 35 |