
Training Diffusion Models with Reinforcement Learning

About

Diffusion models are a class of flexible generative models trained with an approximation to the log-likelihood objective. However, most use cases of diffusion models are not concerned with likelihoods, but instead with downstream objectives such as human-perceived image quality or drug effectiveness. In this paper, we investigate reinforcement learning methods for directly optimizing diffusion models for such objectives. We describe how posing denoising as a multi-step decision-making problem enables a class of policy gradient algorithms, which we refer to as denoising diffusion policy optimization (DDPO), that are more effective than alternative reward-weighted likelihood approaches. Empirically, DDPO is able to adapt text-to-image diffusion models to objectives that are difficult to express via prompting, such as image compressibility, and those derived from human feedback, such as aesthetic quality. Finally, we show that DDPO can improve prompt-image alignment using feedback from a vision-language model without the need for additional data collection or human annotation. The project's website can be found at http://rl-diffusion.github.io.
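The abstract's central idea is that each reverse denoising step x_t → x_{t-1} can be treated as an action sampled from a Gaussian policy p_θ(x_{t-1} | x_t), with a reward assigned only to the final sample x_0. That makes a score-function (REINFORCE-style) policy gradient applicable: ∇_θ J = E[ r(x_0) · Σ_t ∇_θ log p_θ(x_{t-1} | x_t) ]. The sketch below illustrates that estimator under heavy simplifying assumptions, all hypothetical and not from the paper: a one-dimensional chain whose per-step mean is a[t]·x_t + b[t] with fixed noise σ, a made-up reward −(x_0 − 2)², and batch-normalized advantages for variance reduction. It is not the authors' DDPO implementation, which fine-tunes a full text-to-image denoising network.

```python
import numpy as np


def sample_trajectories(a, b, sigma, batch, rng):
    """Roll out the toy reverse chain x_T -> ... -> x_0.

    Each step is a Gaussian "action": x_next ~ N(a[t] * x + b[t], sigma^2).
    Returns the pre-step states, the sampled next states, and x_0.
    """
    x = rng.standard_normal(batch)  # x_T ~ N(0, 1)
    states, actions = [], []
    for t in range(len(a)):
        mean = a[t] * x + b[t]  # hypothetical linear "denoiser" mean
        x_next = mean + sigma * rng.standard_normal(batch)
        states.append(x)
        actions.append(x_next)
        x = x_next
    return np.array(states), np.array(actions), x


def reward(x0, target=2.0):
    # Hypothetical terminal reward: only the final sample x_0 is scored.
    return -(x0 - target) ** 2


def policy_gradient_step(a, b, sigma, batch, lr, rng):
    states, actions, x0 = sample_trajectories(a, b, sigma, batch, rng)
    r = reward(x0)
    # Batch-normalized advantage (a variance-reduction assumption).
    adv = (r - r.mean()) / (r.std() + 1e-8)
    means = a[:, None] * states + b[:, None]  # shape (T, batch)
    # Gradient of log N(x'; mu, sigma^2) with respect to mu.
    dlogp_dmu = (actions - means) / sigma**2
    # Score-function estimate: batch mean of adv * d log p / d theta,
    # using d mu / d a[t] = x_t and d mu / d b[t] = 1.
    grad_a = (adv[None, :] * dlogp_dmu * states).mean(axis=1)
    grad_b = (adv[None, :] * dlogp_dmu).mean(axis=1)
    # Gradient ascent on expected reward; also report this batch's mean reward.
    return a + lr * grad_a, b + lr * grad_b, r.mean()


def train(T=10, sigma=0.3, batch=512, lr=0.01, iters=400, seed=0):
    rng = np.random.default_rng(seed)
    a, b = np.ones(T), np.zeros(T)  # start as a pass-through chain
    history = []
    for _ in range(iters):
        a, b, mean_r = policy_gradient_step(a, b, sigma, batch, lr, rng)
        history.append(mean_r)
    return a, b, history
```

Calling `a, b, history = train()` returns the learned per-step coefficients and the mean batch reward at each iteration. In this toy setting the mean reward should rise as the chain learns to steer x_0 toward the target, even though no step-level supervision is ever provided; only the terminal reward drives the update.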

Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine (2023)

Related benchmarks

Task | Dataset | Metric | Result | Rank
Text-to-Image Generation | GenAI-Bench | -- | -- | 41
EEG-to-Image Generation | Subject-01 EEG (test) | EEG Score | 0.5154 | 28
Text-to-Image Alignment | Pick-a-Pic v2 | Image Reward | 0.6051 | 27
Video-to-Audio | VGGSound (test) | FD (PaSST) | 54.81 | 20
Text-to-Image Generation | T2ICompBench++ (val) | VQAScore | 69.05 | 17
Text-to-Image Generation | Text-to-Image Preference Evaluation Suite (HPSv2.1, ImageReward, PickScore, Aes.Pred.v2.5, CLIP, Unified Reward) v2.1 | HPSv2.1 | 0.313 | 14
Text-to-Image Generation | GenEval | VQAScore | 72.13 | 14
Text-to-Image Generation | Dataset D1 (test) | -- | -- | 14
Text-to-Image Generation | Dataset D2 (test) | -- | -- | 14
Text-to-Image Generation | Out-of-Domain T2I Dataset | Laplacian Variance | 3.90e+3 | 13

Showing 10 of 41 rows.
