Aligning Text-to-Image Diffusion Models with Reward Backpropagation

About

Text-to-image diffusion models have recently emerged at the forefront of image generation, powered by very large-scale unsupervised or weakly supervised text-to-image training datasets. Due to their unsupervised training, controlling their behavior in downstream tasks, such as maximizing human-perceived image quality, image-text alignment, or ethical image generation, is difficult. Recent works finetune diffusion models to downstream reward functions using vanilla reinforcement learning, which is notorious for the high variance of its gradient estimators. In this paper, we propose AlignProp, a method that aligns diffusion models to downstream reward functions using end-to-end backpropagation of the reward gradient through the denoising process. While a naive implementation of such backpropagation would require prohibitive memory resources for storing the partial derivatives of modern text-to-image models, AlignProp finetunes low-rank adapter weight modules and uses gradient checkpointing to make its memory usage viable. We test AlignProp in finetuning diffusion models to various objectives, such as image-text semantic alignment, aesthetics, compressibility, and controllability of the number of objects present, as well as their combinations. We show that AlignProp achieves higher rewards in fewer training steps than alternatives while being conceptually simpler, making it a straightforward choice for optimizing diffusion models for differentiable reward functions of interest. Code and visualization results are available at https://align-prop.github.io/.
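The core mechanism described above (backpropagating a differentiable reward through every denoising step into low-rank adapter weights, while recomputing intermediate states instead of storing them) can be illustrated on a toy problem. This is a minimal sketch under stated assumptions, not the authors' implementation: the "denoiser" is a hypothetical one-layer update x_{t-1} = x_t - 0.1·tanh(W x_t) on a small vector, the reward is a made-up negative squared distance to a target, and every name (`reward_grad`, `denoise`, `STEP`, `every`) is illustrative. Only the adapter factors A and B of W = W0 + BA receive gradients, and the forward pass keeps one latent every `every` steps, recomputing the rest on the way back, which is the memory-for-compute trade that gradient checkpointing makes. Gradients are derived by hand in pure Python so the example has no dependencies; a real model would rely on an autodiff framework.

```python
import math
import random

# Toy, hypothetical setup (not the AlignProp codebase): a d-dim latent is "denoised"
# for T steps by x_{t-1} = x_t - STEP * tanh(W x_t), where W = W0 + B @ A.
# W0 is frozen; only the low-rank adapter factors A (r x d) and B (d x r) are trained.
STEP = 0.1

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def matmul(P, Q):
    cols = list(zip(*Q))
    return [[sum(p * q for p, q in zip(row, col)) for col in cols] for row in P]

def transpose(M):
    return [list(col) for col in zip(*M)]

def add(P, Q):
    return [[p + q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

def step(W, x):
    """One toy denoising step; returns (next latent, tanh activations)."""
    y = [math.tanh(z) for z in matvec(W, x)]
    return [xi - STEP * yi for xi, yi in zip(x, y)], y

def reward(x0, target):
    """Hypothetical differentiable reward: negative squared distance to a target."""
    return -sum((a - b) ** 2 for a, b in zip(x0, target))

def denoise(W0, A, B, xT, T):
    W = add(W0, matmul(B, A))
    x = xT
    for _ in range(T):
        x, _ = step(W, x)
    return x

def reward_grad(W0, A, B, xT, target, T, every=4):
    """Return (dR/dA, dR/dB) by backpropagating through all T denoising steps.

    Gradient checkpointing: the forward pass keeps only the latent entering every
    `every`-th step; latents inside each segment are recomputed on the way back.
    """
    W = add(W0, matmul(B, A))
    d = len(xT)
    checkpoints = {0: xT}  # step index -> latent entering that step
    x = xT
    for t in range(T):
        x, _ = step(W, x)
        if (t + 1) % every == 0 and t + 1 < T:
            checkpoints[t + 1] = x
    g = [-2.0 * (a - b) for a, b in zip(x, target)]  # dR/dx0
    dW = [[0.0] * d for _ in range(d)]
    Wt = transpose(W)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, T)
        xs = [checkpoints[start]]  # recompute inputs of steps start .. end-1
        for _ in range(start, end - 1):
            xs.append(step(W, xs[-1])[0])
        for x_in in reversed(xs):
            _, y = step(W, x_in)
            # Chain rule through x_out = x_in - STEP * tanh(W x_in).
            dz = [-STEP * gi * (1.0 - yi * yi) for gi, yi in zip(g, y)]
            for i in range(d):
                for j in range(d):
                    dW[i][j] += dz[i] * x_in[j]
            g = [gi + v for gi, v in zip(g, matvec(Wt, dz))]
    # W = W0 + B A, so dR/dA = B^T dR/dW and dR/dB = dR/dW A^T; W0 is never updated.
    return matmul(transpose(B), dW), matmul(dW, transpose(A))

if __name__ == "__main__":
    rng = random.Random(0)
    d, r, T = 4, 2, 10
    W0 = [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(d)]
    A = [[rng.gauss(0, 0.1) for _ in range(d)] for _ in range(r)]
    B = [[0.0] * r for _ in range(d)]  # LoRA-style init: B = 0, so W starts as W0
    xT = [rng.gauss(0, 1) for _ in range(d)]  # one fixed noise sample, for simplicity
    target = [1.0, -1.0, 0.5, 0.0]
    print("before:", reward(denoise(W0, A, B, xT, T), target))
    for _ in range(50):
        dA, dB = reward_grad(W0, A, B, xT, target, T)
        # Gradient ascent on the reward, touching only the adapter factors.
        A = [[a + 0.02 * g for a, g in zip(ra, rg)] for ra, rg in zip(A, dA)]
        B = [[b + 0.02 * g for b, g in zip(rb, rg)] for rb, rg in zip(B, dB)]
    print("after:", reward(denoise(W0, A, B, xT, T), target))
```

With `every=k`, the forward pass stores roughly T/k latents and each segment is recomputed once during the backward pass. Because the recomputed states are bit-for-bit identical to the originals, the gradient does not depend on `k`; only peak memory and compute change. The demo reuses a single fixed noise sample, which is a simplification of sampling fresh noise at each update.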

Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki • 2023

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Text-to-motion generation | HumanML3D (test) | FID | 0.266 | 331 |
| Text-to-Image Generation | GenEval 1.0 (test) | Overall Score | 39 | 63 |
| Text-to-Image Generation | Pick-a-Pic (val) | PickScore | 20.56 | 20 |
| Text-to-Image Generation | Pick-a-Pic (500), HPSv2 (500), and PartiPrompts (1000) (test) | PickScore | 20.56 | 10 |
| Text-to-Image Synthesis | GenEval SD V1.5 | Overall Score | 42 | 9 |
