Waypoint Transformer: Reinforcement Learning via Supervised Learning with Intermediate Targets

About

Despite recent advances in offline reinforcement learning via supervised learning (RvS) and the success of the decision transformer (DT) architecture across various domains, DTs have fallen short on several challenging benchmarks. The root cause of this underperformance is their inability to seamlessly connect segments of suboptimal trajectories. To overcome this limitation, we present a novel approach that enhances RvS methods by integrating intermediate targets. We introduce the Waypoint Transformer (WT), which builds upon the DT architecture and is conditioned on automatically generated waypoints. The results show a significant increase in final return compared to existing RvS methods, with performance on par with or greater than existing state-of-the-art temporal difference (TD) learning-based methods. Additionally, the performance and stability improvements are largest in the most challenging environments and data configurations, including AntMaze Large Play/Diverse and Kitchen Mixed/Partial.
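The abstract says WT swaps DT-style conditioning for automatically generated waypoints, but it does not say how those waypoints are produced. As a rough illustration of the difference, the sketch below contrasts DT's return-to-go target with one simple, assumed kind of intermediate target: the state k steps later in the same logged trajectory. `state_waypoints`, `conditioned_examples`, and the horizon `k` are hypothetical names chosen for this example, not the authors' code. The resulting (state, waypoint) to action pairs are the kind of data an RvS policy is trained on with an ordinary supervised loss.

```python
from typing import List, Sequence, Tuple

# A state is represented as a tuple of floats (e.g. an (x, y) position in a maze).
State = Tuple[float, ...]


def returns_to_go(rewards: Sequence[float]) -> List[float]:
    """DT-style conditioning target: undiscounted sum of rewards from step t to the end."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running += rewards[t]
        rtg[t] = running
    return rtg


def state_waypoints(states: Sequence[State], k: int) -> List[State]:
    """Hypothetical intermediate target: the state k steps ahead, clipped at the last state."""
    if k < 1:
        raise ValueError("k must be at least 1")
    last = len(states) - 1
    return [states[min(t + k, last)] for t in range(len(states))]


def conditioned_examples(
    states: Sequence[State], actions: Sequence[int], k: int
) -> List[Tuple[State, State, int]]:
    """Supervised (RvS-style) dataset: predict the logged action from (state, waypoint)."""
    if len(states) != len(actions):
        raise ValueError("states and actions must have the same length")
    return list(zip(states, state_waypoints(states, k), actions))


if __name__ == "__main__":
    states = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0), (2.0, 1.0), (2.0, 2.0)]
    actions = [0, 0, 1, 1, 1]  # hypothetical discrete actions: 0 = right, 1 = up
    print(returns_to_go([1.0, 2.0, 3.0]))  # [6.0, 5.0, 3.0]
    for example in conditioned_examples(states, actions, k=2):
        print(example)
```

The design point the sketch tries to capture: a return-to-go only says how much reward is left, whereas a state waypoint says where the agent should be shortly, which is a more local target when stitching together pieces of different suboptimal trajectories.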

Anirudhan Badrinath, Yannis Flet-Berliac, Allen Nie, Emma Brunskill • 2023

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Locomotion | D4RL Halfcheetah medium | Normalized Score | 43 | 44 |
| Offline multitask Reinforcement Learning | Franka Kitchen kitchen-mixed | Average Episodic Return | 70.9 | 23 |
| Robot Manipulation | D4RL FrankaKitchen kitchen-complete v0 | Normalized Score | 49.2 | 14 |
| Offline Reinforcement Learning (Robotic Manipulation) | Franka Kitchen partial D4RL (pt) | Expert Normalized Return | 63.8 | 10 |
| Locomotion | D4RL Gym-MuJoCo hopper-medium v2 | Normalized Score | 63.1 | 10 |
| Offline Reinforcement Learning (Navigation) | AntMaze umaze-diverse D4RL (ud) | Expert Normalized Return | 71.5 | 10 |
| Locomotion | D4RL Gym-MuJoCo walker2d-medium v2 | Normalized Score | 74.8 | 10 |
| Navigation | D4RL AntMaze umaze v2 | Normalized Score | 64.1 | 10 |
| Offline Reinforcement Learning (Navigation) | AntMaze umaze D4RL | Expert Normalized Return | 64.9 | 10 |
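The "Normalized Score" rows presumably follow the usual D4RL convention, which rescales a raw episode return so that a random policy scores 0 and an expert policy scores 100. A minimal sketch of that rescaling is below; `d4rl_normalized_score` is a hypothetical helper, and the reference returns in the usage line are made up for illustration (the real per-environment constants ship with D4RL). The listing does not say how "Expert Normalized Return" or "Average Episodic Return" are computed, so this sketch should not be assumed to apply to those rows.

```python
def d4rl_normalized_score(score: float, random_score: float, expert_score: float) -> float:
    """Map a raw episode return onto a 0-100 scale: 0 = random policy, 100 = expert policy."""
    if expert_score == random_score:
        raise ValueError("expert_score and random_score must differ")
    return 100.0 * (score - random_score) / (expert_score - random_score)


if __name__ == "__main__":
    # Hypothetical reference returns, not the official D4RL constants.
    print(d4rl_normalized_score(55.0, random_score=10.0, expert_score=100.0))  # 50.0
```

Because the scale is linear, scores above 100 are possible when a policy beats the expert reference, and negative scores when it does worse than random.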
