Elastic Decision Transformer
About
This paper introduces Elastic Decision Transformer (EDT), a significant advancement over the existing Decision Transformer (DT) and its variants. Although DT purports to generate an optimal trajectory, empirical evidence suggests it struggles with trajectory stitching, a process involving the generation of an optimal or near-optimal trajectory from the best parts of a set of sub-optimal trajectories. The proposed EDT differentiates itself by facilitating trajectory stitching during action inference at test time, achieved by adjusting the history length maintained in DT. Further, the EDT optimizes the trajectory by retaining a longer history when the previous trajectory is optimal and a shorter one when it is sub-optimal, enabling it to "stitch" with a more optimal trajectory. Extensive experimentation demonstrates EDT's ability to bridge the performance gap between DT-based and Q Learning-based approaches. In particular, the EDT outperforms Q Learning-based methods in a multi-task regime on the D4RL locomotion benchmark and Atari games. Videos are available at: https://kristery.github.io/edt/
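The mechanism described above — keeping a longer history when the recent trajectory is good and a shorter one when it is not — can be sketched as a search over candidate context lengths at inference time. The snippet below is a minimal illustration, not the paper's implementation: `estimate_max_return` stands in for EDT's learned estimator of the maximal achievable return given a truncated history, and here is replaced by a toy reward sum purely for demonstration.

```python
def estimate_max_return(history):
    # Toy stand-in (assumption): score a candidate context by its reward sum.
    # EDT instead uses a learned head that estimates the maximal achievable
    # return conditioned on the truncated history.
    return sum(step["reward"] for step in history)

def select_context(trajectory, candidate_lengths):
    """Pick the history length whose truncated context has the highest
    estimated achievable return, mimicking EDT's elastic history length."""
    best_len, best_val = None, float("-inf")
    for length in candidate_lengths:
        ctx = trajectory[-length:]
        val = estimate_max_return(ctx)
        if val > best_val:
            best_val, best_len = val, length
    # The transformer would then condition its action prediction
    # on this selected context instead of a fixed-length window.
    return trajectory[-best_len:]

# Example: a short trajectory of per-step rewards.
traj = [{"reward": r} for r in [1.0, -0.5, 2.0, 0.3]]
ctx = select_context(traj, candidate_lengths=[1, 2, 4])
```

With the toy estimator above, the full four-step history scores highest, so the longest candidate context is kept; a sub-optimal prefix would instead push the selection toward a shorter context, which is the intuition behind stitching.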
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Offline Reinforcement Learning | D4RL halfcheetah-medium-expert | Normalized Score | 89.7 | 117 |
| Offline Reinforcement Learning | D4RL hopper-medium-expert | Normalized Score | 104.7 | 115 |
| Offline Reinforcement Learning | D4RL walker2d-medium-expert | Normalized Score | 107.8 | 86 |
| Offline Reinforcement Learning | D4RL Medium-Replay Hopper | Normalized Score | 89.4 | 72 |
| Offline Reinforcement Learning | D4RL Medium HalfCheetah | Normalized Score | 43 | 59 |
| Offline Reinforcement Learning | D4RL Medium-Replay HalfCheetah | Normalized Score | 37.8 | 59 |
| Offline Reinforcement Learning | D4RL Medium Walker2d | Normalized Score | 75.8 | 58 |
| Offline Reinforcement Learning | D4RL walker2d medium-replay | Normalized Score | 73.3 | 45 |
| Offline Reinforcement Learning | D4RL Hopper Medium v2 | Normalized Score | 70.4 | 26 |
| Offline Reinforcement Learning | MuJoCo hopper D4RL (medium-replay) | Normalized Return | 89 | 26 |