
Return-Aligned Decision Transformer

About

Traditional approaches in offline reinforcement learning aim to learn an optimal policy that maximizes the cumulative reward, also known as the return. It is increasingly important to adjust the performance of AI agents to meet human requirements, for example in applications such as video games and educational tools. Decision Transformer (DT) uses supervised learning to optimize a policy that generates actions conditioned on a target return, which gives it a mechanism for controlling the agent's performance through that target. However, action generation is barely influenced by the target return, because DT's self-attention assigns only small attention scores to the return tokens. In this paper, we propose Return-Aligned Decision Transformer (RADT), designed to align the actual return with the target return more effectively. RADT leverages features extracted by attending solely to the return, so that action generation consistently depends on the target return. Extensive experiments show that RADT significantly reduces the discrepancy between the actual return and the target return compared to DT-based methods. Our code is available at https://github.com/CyberAgentAILab/radt
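The two ingredients the abstract relies on can be made concrete in a few lines. The sketch below, in plain Python, computes the return-to-go sequence that DT uses as its conditioning token, applies DT's evaluation-time rule of subtracting each received reward from the target, and then compares how much softmax attention a single query gives to return tokens under joint self-attention versus an attention step whose keys are restricted to return tokens. The logits are made up, the return-only step is a simplified illustration of "attending solely to the return" rather than RADT's actual architecture, and all function names are hypothetical.

```python
import math


def returns_to_go(rewards):
    """R_t = r_t + r_{t+1} + ... + r_T: the return still to be collected from step t."""
    rtg, running = [], 0.0
    for r in reversed(rewards):
        running += r
        rtg.append(running)
    return rtg[::-1]


def next_target(target, reward):
    """At evaluation time, DT lowers the target return by each reward it receives."""
    return target - reward


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_by_kind(kinds, scores):
    """Total softmax weight that one query places on each token kind."""
    split = {}
    for kind, w in zip(kinds, softmax(scores)):
        split[kind] = split.get(kind, 0.0) + w
    return split


# Token order follows DT: (R_1, s_1, a_1, R_2, s_2, a_2, ...).
kinds = ["return", "state", "action"] * 3
# Hypothetical logits: return tokens score lower than state/action tokens.
scores = [-1.0 if k == "return" else 1.0 for k in kinds]

# One self-attention step over all tokens: returns compete for probability mass.
joint = attention_by_kind(kinds, scores)
# Keys restricted to return tokens only: their weights must sum to 1.
return_only = attention_by_kind(
    [k for k in kinds if k == "return"],
    [s for k, s in zip(kinds, scores) if k == "return"],
)

print(returns_to_go([1.0, 0.0, 2.0]))      # [3.0, 2.0, 2.0]
print(next_target(3.0, 1.0))               # 2.0
print(round(joint["return"], 3))           # 0.063
print(round(return_only["return"], 3))     # 1.0
```

Under joint self-attention, the return tokens share one softmax with the state and action tokens, so their combined weight can be small (about 6% with these toy logits). When only return tokens serve as keys, their weights sum to 1 by construction, so the extracted feature always carries the target-return signal.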

Tsunehiko Tanaka, Kenshi Abe, Kaito Ariu, Tetsuro Morimura, Edgar Simo-Serra • 2024

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
| --- | --- | --- | --- | --- |
| Offline Reinforcement Learning | walker2d medium-replay | Normalized Score | 75.9 | 61 |
| Offline Reinforcement Learning | hopper medium-replay | Normalized Score | 95.7 | 55 |
| Offline Reinforcement Learning | halfcheetah medium-replay | Normalized Score | 41.3 | 54 |
| Offline Reinforcement Learning | Walker2d medium-expert | Normalized Score | 109.7 | 42 |
| Offline Reinforcement Learning | HalfCheetah Vel | Maximum Episode Return | -1.03e+3 | 40 |
| Offline Reinforcement Learning | Hopper medium-expert | Normalized Score | 110.4 | 35 |
| Offline Reinforcement Learning | Halfcheetah medium-expert | Normalized Score | 93.1 | 26 |
| Reinforcement Learning | HalfCheetah Vel | Average Episode Return | -754.7 | 10 |
| Return Alignment | hopper medium-replay | RMSE (Return Alignment) | 6.49 | 6 |
| Return Alignment | Hopper medium-expert | RMSE (Return Alignment) | 8.18 | 6 |

(10 of 17 rows shown.)
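The table reports two kinds of metric. For the D4RL locomotion tasks, "Normalized Score" follows the usual D4RL convention of rescaling a raw episode return so that 0 corresponds to a random policy and 100 to an expert policy. For the "RMSE (Return Alignment)" rows, a natural reading is the root-mean-square difference between each target return and the return the agent actually achieves, so lower is better; that reading is an assumption here, and this page does not say whether it is computed on raw or normalized returns. The reference returns and episode results below are placeholders, and the function names are hypothetical.

```python
import math


def normalized_score(raw_return, random_return, expert_return):
    """D4RL-style normalization: 0 matches a random policy, 100 an expert policy."""
    return 100.0 * (raw_return - random_return) / (expert_return - random_return)


def return_alignment_rmse(targets, actuals):
    """Root-mean-square gap between each target return and the return achieved."""
    if len(targets) != len(actuals) or not targets:
        raise ValueError("targets and actuals must be non-empty and equal in length")
    squared = [(t - a) ** 2 for t, a in zip(targets, actuals)]
    return math.sqrt(sum(squared) / len(squared))


# Hypothetical reference returns; real D4RL reference values differ per environment.
print(normalized_score(1500.0, 0.0, 3000.0))  # 50.0

# Hypothetical target returns and the returns an agent achieved for each target.
targets = [10.0, 20.0, 30.0, 40.0]
actuals = [13.0, 16.0, 30.0, 40.0]
print(return_alignment_rmse(targets, actuals))  # 2.5
```

Evaluating a sweep of target returns, rather than a single maximal target, is what distinguishes this alignment metric from the normalized-score rows: an agent can score highly on the latter while still failing to hit lower targets on request.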
