Efficient Online Reinforcement Learning for Diffusion Policy
About
Diffusion policies have achieved superior performance in imitation learning and offline reinforcement learning (RL) due to their rich expressiveness. However, conventional diffusion training requires samples from the target distribution, which is impossible in online RL because we cannot sample from the optimal policy. Backpropagating the policy gradient through the diffusion process, on the other hand, is computationally expensive and unstable, and therefore does not scale. To enable efficient training of diffusion policies in online RL, we generalize conventional denoising score matching by reweighting its loss function. The resulting Reweighted Score Matching (RSM) preserves the optimal solution and the low computational cost of denoising score matching, while eliminating the need to sample from the target distribution and allowing the learned policy to optimize value functions. We introduce two tractable reweighted loss functions for two commonly used policy optimization problems, policy mirror descent and maximum-entropy policy optimization, yielding two practical algorithms: Diffusion Policy Mirror Descent (DPMD) and Soft Diffusion Actor-Critic (SDAC). We conducted comprehensive comparisons on MuJoCo benchmarks. The empirical results show that the proposed algorithms outperform recent online RL methods based on diffusion policies on most tasks, and that DPMD improves on soft actor-critic (SAC) by more than 120% on Humanoid and Ant.
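As a rough illustration of the reweighting idea, the sketch below computes a weighted denoising score matching loss in the common noise-prediction form. It relies on a standard identity: minimizing E_{a0 ~ p}[w(a0) * l(a0)] is, up to the constant E_p[w], the same as minimizing E_{a0 ~ q}[l(a0)] with q(a0) proportional to p(a0) w(a0). So actions sampled from the current policy, reweighted, can stand in for samples from a target you cannot sample directly. The weight exp(Q(s, a)/lam), whose implied target pi_old * exp(Q/lam) is the usual KL-regularized mirror-descent update, is an assumption chosen for illustration. This is not necessarily the exact loss used by DPMD or SDAC. The names `make_alpha_bar`, `reweighted_dsm_loss`, `eps_model`, and `lam` are hypothetical, as is the linear noise schedule.

```python
import numpy as np


def make_alpha_bar(num_steps=20, beta_min=1e-4, beta_max=0.02):
    """Cumulative products of (1 - beta_t) for an ASSUMED linear beta schedule."""
    betas = np.linspace(beta_min, beta_max, num_steps)
    return np.cumprod(1.0 - betas)


def reweighted_dsm_loss(eps_model, states, actions, q_values, alpha_bar, lam=1.0, rng=None):
    """Weighted denoising score matching (epsilon-prediction form), a hypothetical sketch.

    states:   (B, state_dim) batch of states
    actions:  (B, action_dim) actions a0 sampled from the current policy
    q_values: (B,) critic estimates Q(s, a0)
    The weight exp(Q / lam) is an ASSUMED choice; it targets pi_old * exp(Q / lam).
    """
    rng = np.random.default_rng() if rng is None else rng
    batch = actions.shape[0]
    # Sample a diffusion step per example and corrupt the action with Gaussian noise
    t = rng.integers(0, len(alpha_bar), size=batch)
    ab = alpha_bar[t][:, None]
    noise = rng.standard_normal(actions.shape)
    noisy_actions = np.sqrt(ab) * actions + np.sqrt(1.0 - ab) * noise
    # Self-normalized weights exp(Q / lam); subtract the max for numerical stability
    logits = q_values / lam
    w = np.exp(logits - logits.max())
    w = w / w.sum()
    # Per-sample squared error between predicted and true noise
    per_sample = np.sum((eps_model(states, noisy_actions, t) - noise) ** 2, axis=1)
    return float(np.sum(w * per_sample))
```

With equal Q-values, the weights are uniform and the loss reduces to ordinary denoising score matching on the sampled actions. As lam shrinks, the loss concentrates on high-value actions. Normalizing weights within the batch is a convenient but biased estimator of the reweighted objective, and is a simplification made here.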
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Online Reinforcement Learning | OpenAI Gym MuJoCo Normalized v4 | Normalized Mean Return | 84.7 | 50 |
| Continuous Control | MuJoCo Ant v4 | Average Return | 5.93e+3 | 46 |
| Continuous Control | MuJoCo Walker2d v4 | -- | -- | 39 |
| Continuous Control | MuJoCo HalfCheetah v4 | Average Return | 1.29e+4 | 36 |
| Continuous Control | MuJoCo Swimmer v4 | Total Reward | 80.2 | 19 |
| Continuous Control | MuJoCo HalfCheetah v5 | Max Return | 1.07e+4 | 13 |
| Continuous Control | MuJoCo Walker2d v5 | Max Average Return | 4.91e+3 | 13 |
| Continuous Control | MuJoCo Humanoid v5 | Maximum Average Return | 5.10e+3 | 13 |
| Locomotion | Humanoid-Bench Stand (test) | Return | 7.7 | 11 |
| Locomotion | HalfCheetah v4 | Mean Episode Return | 1.48e+4 | 10 |