Inference Time Policy Optimization for Offline RL with Differentiable World Models
About
Offline Reinforcement Learning (RL) learns optimal policies from fixed datasets, training a policy once and deploying it at inference time without further refinement. Inspired by model predictive control (MPC), we introduce an inference-time adaptation framework that uses a pretrained policy together with a learned world model. Existing world-model and diffusion-planning methods use learned dynamics to generate imagined trajectories during training, or to sample candidate plans at inference time, but they do not use inference-time information to *optimize* the policy parameters on the fly. In contrast, our design is a Differentiable World Model (DWM) pipeline that enables end-to-end gradient computation through imagined rollouts for inference-time policy optimization (ITPO). We evaluate our algorithm on D4RL continuous-control benchmarks (MuJoCo locomotion tasks and AntMaze) and show that exploiting inference-time information to optimize the policy parameters yields consistent gains over strong offline RL baselines. Inference-time adaptation, however, is expensive: rollout generation and backpropagation dominate per-step compute. We study this tradeoff explicitly, showing that a suitable tilted version of a one-step MeanFlow sampler recovers much of the gains at a fraction of the cost.
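The core idea of differentiating an imagined return with respect to the policy parameters can be illustrated with a deliberately tiny sketch. Everything in it is an assumption made for illustration and not the paper's pipeline: the "world model" is a scalar linear map `s' = a*s + b*u`, the policy is `u = theta*s`, the reward is a quadratic penalty, and the gradient is propagated by hand through the rollout (forward-mode sensitivities) rather than by an autodiff framework through a learned neural model. The names `rollout_return_and_grad` and `adapt_policy` are hypothetical.

```python
def rollout_return_and_grad(theta, s0, a=0.9, b=0.5, ctrl_cost=0.1, horizon=5):
    """Imagined return and its derivative w.r.t. theta for policy u = theta * s.

    Toy stand-in for a differentiable world model: s' = a*s + b*u (assumed).
    Reward per step (assumed): -(s^2) - ctrl_cost * u^2.
    """
    s, ds = s0, 0.0        # state and its sensitivity d(s)/d(theta)
    ret, dret = 0.0, 0.0   # return and d(return)/d(theta)
    for _ in range(horizon):
        u = theta * s
        du = s + theta * ds                      # product rule on theta * s
        ret += -(s * s) - ctrl_cost * u * u
        dret += -2.0 * s * ds - 2.0 * ctrl_cost * u * du
        # Step the model and push the sensitivity through the same dynamics.
        s, ds = a * s + b * u, a * ds + b * du
    return ret, dret


def adapt_policy(theta, s0, steps=10, lr=0.05, **model):
    """A few gradient-ascent steps on the imagined return from state s0."""
    for _ in range(steps):
        _, grad = rollout_return_and_grad(theta, s0, **model)
        theta += lr * grad
    return theta


if __name__ == "__main__":
    theta0, s0 = 0.0, 1.0
    theta1 = adapt_policy(theta0, s0)
    print("return before:", rollout_return_and_grad(theta0, s0)[0])
    print("return after: ", rollout_return_and_grad(theta1, s0)[0])
```

In this toy setting, each call to `adapt_policy` plays the role of the per-step adaptation described above: the cost grows with both `horizon` (rollout length) and `steps` (number of gradient updates), which is the compute tradeoff the abstract refers to.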
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Offline Reinforcement Learning | D4RL AntMaze | AntMaze Umaze Return | 99 | 65 |
| Offline Reinforcement Learning | D4RL Gym | Return (Hopper, Random) | 8.56 | 16 |
| Offline Reinforcement Learning | D4RL MuJoCo | HalfCheetah (m) | 70.05 | 13 |