World Models via Policy-Guided Trajectory Diffusion
About
World models are a powerful tool for developing intelligent agents. By predicting the outcome of a sequence of actions, world models enable policies to be optimised via on-policy reinforcement learning (RL) using synthetic data, i.e. "in imagination". Existing world models are autoregressive in that they interleave predicting the next state with sampling the next action from the policy. Prediction error inevitably compounds as the trajectory length grows. In this work, we propose a novel world modelling approach that is not autoregressive and generates entire on-policy trajectories in a single pass through a diffusion model. Our approach, Policy-Guided Trajectory Diffusion (PolyGRAD), leverages a denoising model in addition to the gradient of the action distribution of the policy to diffuse a trajectory of initially random states and actions into an on-policy synthetic trajectory. We analyse the connections between PolyGRAD, score-based generative models, and classifier-guided diffusion models. Our results demonstrate that PolyGRAD outperforms state-of-the-art baselines in terms of trajectory prediction error for short trajectories, with the exception of autoregressive diffusion. For short trajectories, PolyGRAD obtains similar errors to autoregressive diffusion, but with lower computational requirements. For long trajectories, PolyGRAD obtains comparable performance to baselines. Our experiments demonstrate that PolyGRAD enables performant policies to be trained via on-policy RL in imagination for MuJoCo continuous control domains. Thus, PolyGRAD introduces a new paradigm for accurate on-policy world modelling without autoregressive sampling.
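The mechanism described above, refining a whole trajectory of random states and actions by alternating a denoising step with a step along the gradient of the policy's action distribution, can be illustrated with a toy sketch. This is not the PolyGRAD implementation. Everything specific here is an assumption made for illustration: the linear dynamics `A`, `B`, the Gaussian policy with gain `K` and standard deviation `sigma`, the step size, the linear noise schedule, and above all `toy_denoiser`, which rolls out known dynamics in place of a learned denoising network.

```python
import numpy as np

def policy_score(states, actions, K, sigma):
    """Gradient of log N(a; K s, sigma^2 I) with respect to the action a."""
    return (states @ K.T - actions) / sigma**2

def toy_denoiser(s0, actions, A, B):
    """Hypothetical stand-in for a learned denoiser: predicts clean states
    by rolling the known dynamics s' = A s + B a forward from s0."""
    states = [s0]
    for a in actions[:-1]:
        states.append(A @ states[-1] + B @ a)
    return np.stack(states)

def guided_trajectory(s0, A, B, K, sigma=0.5, horizon=10,
                      n_steps=200, step=0.05, seed=0):
    """Refine a whole random trajectory at once; no per-step action sampling."""
    rng = np.random.default_rng(seed)
    state_dim, action_dim = B.shape
    # Start from pure noise for both states and actions.
    states = rng.normal(size=(horizon, state_dim))
    actions = rng.normal(size=(horizon, action_dim))
    for i in range(n_steps, 0, -1):
        noise_scale = (i - 1) / n_steps  # assumed schedule, reaches 0 on the last step
        # Denoise the states given the current actions, then re-noise them.
        states = toy_denoiser(s0, actions, A, B)
        states = states + noise_scale * rng.normal(size=states.shape)
        # Policy guidance: move actions towards high policy likelihood.
        actions = actions + step * policy_score(states, actions, K, sigma)
    return states, actions

# Example: a double integrator with a linear feedback policy (assumed values).
A = np.array([[1.0, 0.1], [0.0, 1.0]])
B = np.array([[0.0], [0.1]])
K = np.array([[-1.0, -1.0]])
states, actions = guided_trajectory(np.array([1.0, 0.0]), A, B, K)
print(states.shape, actions.shape)
```

In this toy setting the actions drift towards the policy mean `K s` while the states stay consistent with the dynamics, so the final trajectory looks like one the policy could have produced, even though no action was ever sampled step by step.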
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Continuous Control | MuJoCo HalfCheetah | Average Reward | 4.19e+3 | 25 |
| Reinforcement Learning | Halfcheetah | Average Return | 2.56e+3 | 22 |
| Continuous Control | MuJoCo Reacher | Average Reward | 4.48 | 18 |
| Continuous Control | MuJoCo Hopper | Maximum Average Return | 3.35e+3 | 13 |
| Continuous Control | MuJoCo Walker2d | Max Return | 3.78e+3 | 13 |
| Reinforcement Learning | Humanoid-ET | Average Return | 1.03e+3 | 12 |
| Reinforcement Learning | Ant-ET | Average Return | 433.6 | 12 |
| Reinforcement Learning | Walker-ET | Average Return | 268.4 | 12 |
| Reinforcement Learning | reacher | Average Return | -20.7 | 12 |
| Continuous Control | OpenAI Gym MuJoCo Pendulum POMDP (test) | Average Return | 174 | 8 |