Analyzing and Improving the Training Dynamics of Diffusion Models

About

Diffusion models currently dominate the field of data-driven image synthesis with their unparalleled scaling to large datasets. In this paper, we identify and rectify several causes for uneven and ineffective training in the popular ADM diffusion model architecture, without altering its high-level structure. Observing uncontrolled magnitude changes and imbalances in both the network activations and weights over the course of training, we redesign the network layers to preserve activation, weight, and update magnitudes on expectation. We find that systematic application of this philosophy eliminates the observed drifts and imbalances, resulting in considerably better networks at equal computational complexity. Our modifications improve the previous record FID of 2.41 in ImageNet-512 synthesis to 1.81, achieved using fast deterministic sampling. As an independent contribution, we present a method for setting the exponential moving average (EMA) parameters post-hoc, i.e., after completing the training run. This allows precise tuning of EMA length without the cost of performing several training runs, and reveals its surprising interactions with network architecture, training time, and guidance.
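The abstract's central idea is to keep activation and weight magnitudes fixed "on expectation". The small NumPy sketch below illustrates that idea under stated assumptions. It assumes zero-mean, unit-variance inputs, and all helper names (`normalize_rows`, `mp_linear`, `mp_sum`, `forced_weight_norm_step`) are hypothetical. It is a simplified illustration, not the authors' released implementation, and the paper's actual layer definitions may differ in detail.

```python
import numpy as np

def normalize_rows(w, eps=1e-4):
    """Scale each output row of a weight matrix to (almost exactly) unit L2 norm."""
    # Hypothetical helper; eps only guards against dividing by zero for an all-zero row.
    norms = np.linalg.norm(w, axis=1, keepdims=True)
    return w / (norms + eps)

def mp_linear(x, w):
    """Magnitude-preserving linear layer (sketch)."""
    # If x has i.i.d. zero-mean, unit-variance entries and row w_i has unit norm,
    # then Var(w_i . x) = ||w_i||^2 = 1, whatever scale the raw weights have.
    return x @ normalize_rows(w).T

def mp_sum(a, b, t=0.5):
    """Blend two independent unit-variance signals while keeping unit variance."""
    # Var((1 - t) a + t b) = (1 - t)^2 + t^2 for independent a, b, so divide by its sqrt.
    return ((1 - t) * a + t * b) / np.sqrt((1 - t) ** 2 + t ** 2)

def forced_weight_norm_step(w, grad, lr):
    """One plain SGD step, then re-project the stored weights to unit-norm rows."""
    # Assumption: a simplified stand-in for keeping weight magnitudes fixed during
    # training, so the weight norm cannot drift upward and silently shrink the
    # relative size of later updates.
    return normalize_rows(w - lr * grad)

rng = np.random.default_rng(0)
x = rng.standard_normal((10_000, 256))     # unit-variance input activations
w = 3.0 * rng.standard_normal((128, 256))  # deliberately badly scaled raw weights
y = mp_linear(x, w)
print(round(float(y.std()), 2))
```

The printed standard deviation lands close to 1.0 even though the raw weights were scaled by 3. A plain `x @ w.T` would instead give a standard deviation of about 3 × sqrt(256) = 48, which is the kind of uncontrolled magnitude drift the abstract describes.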

Tero Karras, Miika Aittala, Jaakko Lehtinen, Janne Hellsten, Timo Aila, Samuli Laine • 2023

Related benchmarks

Task | Dataset | Metric | Result | Rank
Image Generation | ImageNet 512x512 (val) | FID-50K | 1.81 | 184
Image Generation | ImageNet 64x64 resolution (test) | FID | 1.33 | 150
Class-conditional Image Generation | ImageNet 64x64 (test) | FID | 1.33 | 86
Image Generation | ImageNet 64x64 (train val) | FID | 1.01 | 83
Class-conditional Image Generation | ImageNet 512x512 | FID | 1.81 | 72
Class-conditional Image Generation | ImageNet 512x512 (val) | FID (Val) | 1.81 | 69
Class-conditional Image Generation | ImageNet 512x512 (train) | FID | 1.25 | 52
Conditional Image Generation | ImageNet 64x64 (val) | FID | 1.48 | 48
Class-conditional Image Generation | ImageNet 512x512 (val test) | FID | 1.25 | 40
Class-conditional Image Generation | ImageNet-1K 512x512 (val) | FID | 1.91 | 33

(10 of 20 rows shown.)
