Learning a Diffusion Model Policy from Rewards via Q-Score Matching

About

Diffusion models have become a popular choice for representing actor policies in behavior cloning and offline reinforcement learning. This is due to their natural ability to optimize an expressive class of distributions over a continuous space. However, previous works fail to exploit the score-based structure of diffusion models, and instead utilize a simple behavior cloning term to train the actor, limiting their ability in the actor-critic setting. In this paper, we present a theoretical framework that links the structure of diffusion model policies to a learned Q-function by relating the score of the policy to the action gradient of the Q-function. We focus on off-policy reinforcement learning and propose a new policy update method from this theory, which we call Q-score matching. Notably, this algorithm only needs to differentiate through the denoising model rather than the entire diffusion model evaluation, and policies that converge under Q-score matching are implicitly multi-modal and explorative in continuous domains. We conduct experiments in simulated environments to demonstrate the viability of our proposed method and compare it to popular baselines. Source code is available from the project website: https://michaelpsenka.io/qsm.

Michael Psenka, Alejandro Escontrela, Pieter Abbeel, Yi Ma • 2023
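
The idea in the abstract of relating the policy's score to the action gradient of the Q-function can be illustrated with a toy sketch. This is not the authors' implementation (see the project website for that). It assumes a one-dimensional action, a hand-written quadratic critic in place of a learned Q-function, a linear score model, and a squared-error matching loss with a scaling constant. All names (`ALPHA`, `q_value`, `score`, `fit_score`, `follow_score`) are hypothetical and chosen for illustration only.

```python
import random

ALPHA = 0.5  # hypothetical scaling between the score and the action gradient


def q_value(state, action):
    """Toy stand-in for a learned critic: largest when action == state."""
    return -(action - state) ** 2


def action_gradient(q_fn, state, action, eps=1e-4):
    """Central-difference estimate of dQ/da (a real critic would use autodiff)."""
    return (q_fn(state, action + eps) - q_fn(state, action - eps)) / (2 * eps)


def score(params, state, action):
    """Toy linear score model psi(s, a) = w_s * s + w_a * a + b."""
    w_s, w_a, b = params
    return w_s * state + w_a * action + b


def fit_score(steps=5000, lr=0.05, seed=0):
    """SGD on the squared gap between psi(s, a) and ALPHA * dQ/da."""
    rng = random.Random(seed)
    params = [0.0, 0.0, 0.0]
    for _ in range(steps):
        s = rng.uniform(-1.0, 1.0)
        a = rng.uniform(-1.0, 1.0)
        target = ALPHA * action_gradient(q_value, s, a)
        err = score(params, s, a) - target
        # Gradient of 0.5 * err**2 with respect to (w_s, w_a, b).
        grads = (err * s, err * a, err)
        params = [p - lr * g for p, g in zip(params, grads)]
    return params


def follow_score(params, state, action, step=0.1, n_steps=100):
    """Noise-free simplification of sampling: repeatedly step along the score."""
    for _ in range(n_steps):
        action += step * score(params, state, action)
    return action


params = fit_score()
final_action = follow_score(params, state=0.3, action=-1.0)
```

In this toy setting the matching target is `ALPHA * dQ/da = s - a`, so the fitted score points toward the action that maximizes the toy critic, and following it moves an arbitrary starting action toward that maximizer. Note that the update only queries the score model and the critic's action gradient; nothing is differentiated through the multi-step sampling loop. The sampling loop here deliberately omits the noise term a diffusion policy would include.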

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Offline Reinforcement Learning | OGBench | AntMaze Giant Navigate: 13 | 56 |
| Online Reinforcement Learning | OpenAI Gym MuJoCo Normalized v4 | Normalized Mean Return: 55.3 | 50 |
| Continuous Control | MuJoCo Ant v4 | Average Return: 4.21e+3 | 46 |
| Continuous Control | MuJoCo Walker2d v4 | -- | 39 |
| Continuous Control | MuJoCo HalfCheetah v4 | Average Return: 3.89e+3 | 36 |
| Reinforcement Learning | MuJoCo Half-Cheetah | Average Return: 12 | 28 |
| Offline Reinforcement Learning | OGBench puzzle-4x4 | Success Rate: 0.00e+0 | 26 |
| Offline Reinforcement Learning | OGBench cube-triple (ct) | Success Rate: 3 | 25 |
| Reinforcement Learning | Swimmer | Average Returns: 108 | 24 |
| Reinforcement Learning | MuJoCo Hopper | Average Return: 1.59e+3 | 24 |
Showing 10 of 38 rows
