
Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning

About

Guided sampling, which embeds human-defined guidance into the sampling procedure, is a vital approach for applying diffusion models to real-world tasks. This paper considers a general setting where the guidance is defined by an (unnormalized) energy function. The main challenge in this setting is that the intermediate guidance during the diffusion sampling procedure, which is jointly defined by the sampling distribution and the energy function, is unknown and hard to estimate. To address this challenge, we propose an exact formulation of the intermediate guidance as well as a novel training objective named contrastive energy prediction (CEP) to learn the exact guidance. Our method is guaranteed to converge to the exact guidance given unlimited model capacity and data samples, whereas previous methods are not. We demonstrate the effectiveness of our method by applying it to offline reinforcement learning (RL). Extensive experiments on D4RL benchmarks show that our method outperforms existing state-of-the-art algorithms. We also provide examples of applying CEP to image synthesis to demonstrate its scalability to high-dimensional data.
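The abstract describes CEP only at a high level. The sketch below illustrates the general shape of a contrastive, soft-label cross-entropy objective of the kind the name suggests, under several assumptions that are not stated above: K clean samples x_0^i drawn from the base distribution, their true energies E(x_0^i), a variance-preserving forward perturbation x_t = alpha_t * x_0 + sigma_t * eps, and a guidance model f_phi(x_t, t) whose softmax over the K noised samples is trained to match the softmax of -beta * E over the clean samples. The function names, the toy energy, the toy model, and the sign and temperature conventions are illustrative assumptions, not the paper's reference implementation.

```python
import math
import random
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a list of scalars."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def perturb(x0: float, alpha_t: float, sigma_t: float, rng: random.Random) -> float:
    """Assumed forward diffusion step: x_t = alpha_t * x_0 + sigma_t * eps, eps ~ N(0, 1)."""
    return alpha_t * x0 + sigma_t * rng.gauss(0.0, 1.0)


def cep_style_loss(
    clean_energies: Sequence[float],      # E(x_0^i) for the K clean samples
    predicted_energies: Sequence[float],  # f_phi(x_t^i, t) on the matching noised samples
    beta: float,                          # inverse temperature of the energy guidance
) -> float:
    """Soft-label cross-entropy between softmax(-beta * E) and softmax(-f_phi) (a sketch)."""
    # Soft labels: how strongly each clean sample is favoured by the true energy.
    targets = [math.exp(v) for v in log_softmax([-beta * e for e in clean_energies])]
    # Model's log-probabilities over the same K samples, computed at noise level t.
    log_preds = log_softmax([-f for f in predicted_energies])
    return -sum(p * lp for p, lp in zip(targets, log_preds))


if __name__ == "__main__":
    rng = random.Random(0)
    energy = lambda x: (x - 1.0) ** 2         # hypothetical energy E(x_0)
    f_phi = lambda x_t, t: 0.5 * x_t ** 2     # hypothetical (untrained) guidance model
    x0s = [rng.gauss(0.0, 1.0) for _ in range(8)]  # K = 8 samples from a toy base distribution
    alpha_t, sigma_t, t = 0.8, 0.6, 0.5       # alpha_t^2 + sigma_t^2 = 1 (variance preserving)
    x_ts = [perturb(x, alpha_t, sigma_t, rng) for x in x0s]
    loss = cep_style_loss([energy(x) for x in x0s], [f_phi(x, t) for x in x_ts], beta=1.0)
    print(f"CEP-style loss on one toy batch: {loss:.4f}")
```

Because this objective only compares the K samples against each other, adding a constant to f_phi leaves the loss unchanged, so in this sketch the guidance is learned only up to an additive constant. That is harmless for gradient-based guided sampling, since taking the gradient with respect to x_t removes the constant.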

Cheng Lu, Huayu Chen, Jianfei Chen, Hang Su, Chongxuan Li, Jun Zhu • 2023

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
| --- | --- | --- | --- | --- |
| Offline Reinforcement Learning | D4RL halfcheetah-medium-expert | -- | -- | 117 |
| Offline Reinforcement Learning | D4RL hopper-medium-expert | -- | -- | 115 |
| Offline Reinforcement Learning | D4RL Hopper medium | Reward | 98 | 35 |
| Offline Reinforcement Learning | D4RL hopper medium-replay | Reward | 96.9 | 30 |
| Offline Reinforcement Learning | D4RL Halfcheetah medium | Reward | 54.1 | 28 |
| Offline Reinforcement Learning | D4RL HalfCheetah Medium-Replay | Reward | 47.6 | 17 |
| Offline Reinforcement Learning | D4RL v2 (various) | Average Score | 86.6 | 17 |
| Offline Reinforcement Learning | D4RL v2 | Score (HalfCheetah-M) | 54.1 | 9 |
| Offline Reinforcement Learning | D4RL walker2d-medium-expert | Average Reward | 110.7 | 9 |
