DiffStitch: Boosting Offline Reinforcement Learning with Diffusion-based Trajectory Stitching

About

In offline reinforcement learning (RL), the performance of the learned policy depends heavily on the quality of the offline dataset. In many cases, however, the offline dataset contains very few optimal trajectories. This poses a challenge for offline RL algorithms, because agents must learn to transition into high-reward regions. To address this issue, we introduce Diffusion-based Trajectory Stitching (DiffStitch), a novel diffusion-based data augmentation pipeline that systematically generates stitching transitions between trajectories. DiffStitch connects low-reward trajectories with high-reward trajectories, forming globally optimal trajectories that address the challenges faced by offline RL algorithms. Experiments on D4RL datasets show that DiffStitch is effective across RL methodologies. Notably, DiffStitch substantially improves the performance of one-step methods (IQL), imitation learning methods (TD3+BC), and trajectory optimization methods (DT).
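The abstract describes stitching only at a high level, so the following is a toy sketch of the general idea under stated assumptions, not the authors' pipeline. Assumptions: states are scalars, transitions are `(state, action, reward)` tuples, the dynamics are the trivial `s' = s + a`, the hypothetical `generate_bridge` function stands in for the diffusion model by linearly interpolating between the two endpoint states, and the hypothetical `reward_fn` stands in for a reward model. The sketch keeps a prefix of a low-reward trajectory, generates intermediate states toward a chosen state on a high-reward trajectory, relabels actions and rewards on the generated segment, and then appends the rest of the high-reward trajectory.

```python
from typing import Callable, List, Tuple

# A transition is (state, action, reward); states are scalars in this toy setting.
Transition = Tuple[float, float, float]
Trajectory = List[Transition]


def trajectory_return(traj: Trajectory) -> float:
    """Undiscounted return: the sum of rewards along a trajectory."""
    return sum(r for _, _, r in traj)


def generate_bridge(start: float, goal: float, horizon: int) -> List[float]:
    """HYPOTHETICAL stand-in for the diffusion model.

    Returns `horizon` intermediate states between `start` and `goal` by linear
    interpolation; a learned generative model conditioned on both endpoints
    would replace this in a real pipeline.
    """
    step = (goal - start) / (horizon + 1)
    return [start + step * (k + 1) for k in range(horizon)]


def stitch(
    low: Trajectory,
    high: Trajectory,
    cut: int,
    target: int,
    horizon: int,
    reward_fn: Callable[[float], float],
) -> Trajectory:
    """Join the first `cut` states of `low` to `high[target:]` via a generated bridge.

    Assumes toy dynamics s' = s + a, so the action between consecutive states is
    their difference. `reward_fn` is a HYPOTHETICAL placeholder for a reward model.
    """
    if not 1 <= cut <= len(low) or not 0 <= target < len(high):
        raise ValueError("cut or target index out of range")
    start_state = low[cut - 1][0]  # last state kept from the low-return trajectory
    goal_state = high[target][0]   # state on the high-return trajectory to reach
    path = [start_state] + generate_bridge(start_state, goal_state, horizon) + [goal_state]

    stitched: Trajectory = list(low[: cut - 1])  # transitions leading into start_state
    for s, s_next in zip(path[:-1], path[1:]):
        # Relabel action and reward for every transition on the generated segment.
        stitched.append((s, s_next - s, reward_fn(s)))
    stitched.extend(high[target:])  # continue along the high-return trajectory
    return stitched


low = [(0.0, 1.0, 0.0), (1.0, 1.0, 0.0), (2.0, 1.0, 0.0)]   # return 0.0
high = [(5.0, 1.0, 1.0), (6.0, 1.0, 1.0), (7.0, 0.0, 1.0)]  # return 3.0
new = stitch(low, high, cut=2, target=1, horizon=3, reward_fn=lambda s: 0.0)
print([s for s, _, _ in new])  # [0.0, 1.0, 2.25, 3.5, 4.75, 6.0, 7.0]
print(trajectory_return(new))  # 2.0
```

Because each generated action is set to the difference between consecutive states, the bridge is consistent with the assumed toy dynamics by construction; with learned models, checking that generated transitions are plausible would be a separate concern.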

Guanghe Li, Yixiang Shan, Zhengbang Zhu, Ting Long, Weinan Zhang · 2024

Related benchmarks

Task | Dataset | Metric | Result | Rank
Locomotion | D4RL walker2d-medium-expert | Normalized Score | 109.4 | 90
walker2d locomotion | D4RL walker2d-medium-replay | Normalized Score | 82.4 | 78
Offline Reinforcement Learning | D4RL antmaze-umaze (diverse) | Normalized Score | 42.6 | 74
Locomotion | D4RL halfcheetah-medium | Normalized Score | 48 | 70
Locomotion | D4RL walker2d-medium | Normalized Score | 82.4 | 70
Locomotion | D4RL halfcheetah-medium-replay | -- | -- | 68
Offline Reinforcement Learning | D4RL AntMaze | -- | -- | 65
Offline Reinforcement Learning | D4RL antmaze-large (diverse) | Normalized Score | 2.6 | 47
Offline Reinforcement Learning | D4RL maze2d-large | Normalized Performance | 65.2 | 31
Locomotion | D4RL hopper-medium-expert | Normalized Score (100k Steps) | 108.8 | 28

Showing 10 of 20 rows.
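The Normalized Score values above presumably follow the D4RL convention of rescaling raw episode returns against per-dataset reference returns, so that a random policy scores 0 and an expert policy scores 100; under that convention, values above 100, such as 109.4 on walker2d-medium-expert, exceed the expert reference. A minimal sketch of that arithmetic follows. The function names are hypothetical, and the reference returns in the usage lines are made-up numbers for illustration, not D4RL's actual reference values.

```python
from statistics import mean
from typing import Iterable


def d4rl_normalized_score(raw_return: float, random_return: float, expert_return: float) -> float:
    """Map a raw return onto a scale where the random reference is 0 and the expert reference is 100."""
    if expert_return == random_return:
        raise ValueError("random and expert reference returns must differ")
    return 100.0 * (raw_return - random_return) / (expert_return - random_return)


def mean_normalized_score(raw_returns: Iterable[float], random_return: float, expert_return: float) -> float:
    """Average the normalized score over several evaluation episodes (or seeds)."""
    return mean(d4rl_normalized_score(r, random_return, expert_return) for r in raw_returns)


# Illustrative reference returns only (not real D4RL values).
print(d4rl_normalized_score(2000.0, random_return=0.0, expert_return=4000.0))  # 50.0
print(mean_normalized_score([1000.0, 3000.0], random_return=0.0, expert_return=4000.0))  # 50.0
```

Averaging normalized scores per episode gives the same result as normalizing the mean raw return, since the mapping is affine; the per-episode form is shown only because it makes per-seed reporting straightforward.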
