Diffusion Policies as an Expressive Policy Class for Offline Reinforcement Learning
About
Offline reinforcement learning (RL), which aims to learn an optimal policy from a previously collected static dataset, is an important paradigm of RL. Standard RL methods often perform poorly in this regime because of function approximation errors on out-of-distribution actions. While a variety of regularization methods have been proposed to mitigate this issue, they are often constrained to policy classes with limited expressiveness, which can lead to highly suboptimal solutions. In this paper, we propose representing the policy as a diffusion model, a recent class of highly expressive deep generative models. We introduce Diffusion Q-learning (Diffusion-QL), which uses a conditional diffusion model to represent the policy. In our approach, we learn an action-value function and add a term that maximizes action-values to the training loss of the conditional diffusion model; the resulting loss seeks optimal actions that stay near the behavior policy. We show that both the expressiveness of the diffusion-model-based policy and the coupling of behavior cloning and policy improvement under the diffusion model contribute to the outstanding performance of Diffusion-QL. We illustrate the superiority of our method over prior work on a simple 2D bandit example with a multimodal behavior policy. We then show that our method achieves state-of-the-art performance on the majority of the D4RL benchmark tasks.
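The objective described above (a diffusion behavior-cloning loss plus a term that maximizes action-values) can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the linear noise-prediction model `eps_model`, the toy critic `q_value`, the schedule, `T`, and the fixed weight `ETA` are all hypothetical. The total loss is written as L(θ) = L_d(θ) − η · E[Q(s, a_0)], where L_d is the standard DDPM noise-prediction loss on dataset actions and a_0 is produced by running the reverse diffusion chain conditioned on the state. Only the loss value is computed here; actual training would need an autodiff framework so the Q term's gradient can flow back through the sampling chain.

```python
import numpy as np

# Hypothetical hyperparameters for this sketch (not taken from the paper).
T = 5        # number of diffusion steps
ETA = 1.0    # assumed weight on the Q-maximization term

# Linear variance schedule beta_1..beta_T and cumulative products alpha_bar_t.
betas = np.linspace(1e-4, 0.1, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)


def eps_model(a_t, s, t, W):
    """Hypothetical noise predictor eps_theta(a_t, s, t): a single linear map."""
    t_feat = np.broadcast_to(np.asarray(t, dtype=float).reshape(-1, 1) / T,
                             (a_t.shape[0], 1))
    feats = np.concatenate([a_t, s, t_feat], axis=1)
    return feats @ W


def q_value(s, a):
    """Hypothetical critic Q(s, a): prefers actions near 0.5 in every dimension."""
    return -np.sum((a - 0.5) ** 2, axis=1)


def diffusion_bc_loss(a0, s, W, rng):
    """DDPM denoising loss: predict the noise added to dataset actions a0."""
    n = a0.shape[0]
    t = rng.integers(1, T + 1, size=n)                  # random timestep per sample
    ab = alpha_bars[t - 1][:, None]
    eps = rng.standard_normal(a0.shape)
    a_t = np.sqrt(ab) * a0 + np.sqrt(1.0 - ab) * eps    # forward noising q(a_t | a_0)
    return np.mean(np.sum((eps - eps_model(a_t, s, t, W)) ** 2, axis=1))


def sample_actions(s, W, rng, act_dim):
    """Reverse (ancestral) DDPM sampling of actions conditioned on states s."""
    a = rng.standard_normal((s.shape[0], act_dim))      # a_T ~ N(0, I)
    for t in range(T, 0, -1):
        beta, alpha, ab = betas[t - 1], alphas[t - 1], alpha_bars[t - 1]
        eps_hat = eps_model(a, s, t, W)
        mean = (a - beta / np.sqrt(1.0 - ab) * eps_hat) / np.sqrt(alpha)
        noise = rng.standard_normal(a.shape) if t > 1 else 0.0
        a = mean + np.sqrt(beta) * noise
    return np.clip(a, -1.0, 1.0)                        # assumes actions lie in [-1, 1]


def diffusion_ql_loss(s, a0, W, rng):
    """Behavior-cloning term plus a term that pushes sampled actions toward high Q."""
    l_bc = diffusion_bc_loss(a0, s, W, rng)
    a_pi = sample_actions(s, W, rng, a0.shape[1])
    l_q = -np.mean(q_value(s, a_pi))                    # minimizing this maximizes Q
    return l_bc + ETA * l_q, l_bc, l_q


rng = np.random.default_rng(0)
state_dim, act_dim, n = 3, 2, 64
s = rng.standard_normal((n, state_dim))
a0 = np.clip(0.3 * rng.standard_normal((n, act_dim)), -1.0, 1.0)  # toy "dataset" actions
W = 0.1 * rng.standard_normal((act_dim + state_dim + 1, act_dim))
total, l_bc, l_q = diffusion_ql_loss(s, a0, W, rng)
print(f"total={total:.3f} bc={l_bc:.3f} q_term={l_q:.3f}")
```

The behavior-cloning term keeps the sampled actions close to the dataset, while the Q term pulls them toward high-value regions; `ETA` trades the two off. How the paper actually weights or normalizes this term is not reproduced here.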
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Offline Reinforcement Learning | D4RL halfcheetah-medium-expert | Normalized Score | 96.8 | 155 |
| Offline Reinforcement Learning | D4RL hopper-medium-expert | Normalized Score | 111.1 | 153 |
| Offline Reinforcement Learning | D4RL walker2d-medium-expert | Normalized Score | 110.1 | 124 |
| Offline Reinforcement Learning | D4RL Medium HalfCheetah | Normalized Score | 69.1 | 97 |
| Offline Reinforcement Learning | D4RL Medium Walker2d | Normalized Score | 87 | 96 |
| hopper locomotion | D4RL hopper medium-replay | Normalized Score | 101.3 | 66 |
| Offline Reinforcement Learning | D4RL AntMaze | AntMaze Umaze Return | 94 | 65 |
| Offline Reinforcement Learning | D4RL Medium Hopper | Normalized Score | 90.5 | 64 |
| walker2d locomotion | D4RL walker2d medium-replay | Normalized Score | 95.5 | 63 |
| Locomotion | D4RL walker2d-medium-expert | Normalized Score | 110.1 | 63 |