A Tractable Inference Perspective of Offline RL
About
A popular paradigm for offline Reinforcement Learning (RL) tasks is to first fit the offline trajectories to a sequence model, and then prompt the model for actions that lead to high expected return. In addition to obtaining accurate sequence models, this paper highlights that tractability, the ability to exactly and efficiently answer various probabilistic queries, plays an important role in offline RL. Specifically, due to the fundamental stochasticity of the offline data-collection policies and the environment dynamics, highly non-trivial conditional/constrained generation is required to elicit rewarding actions. While it is still possible to approximate such queries, we observe that such crude estimates significantly undermine the benefits brought by expressive sequence models. To overcome this problem, this paper proposes Trifle (Tractable Inference for Offline RL), which leverages modern Tractable Probabilistic Models (TPMs) to bridge the gap between good sequence models and high expected returns at evaluation time. Empirically, Trifle achieves state-of-the-art scores on 9 Gym-MuJoCo benchmarks against strong baselines. Further, owing to its tractability, Trifle significantly outperforms prior approaches in stochastic environments and safe RL tasks (e.g. with action constraints) with minimal algorithmic modifications.
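To make the role of tractability concrete, the following is a minimal toy sketch (not the paper's actual Trifle implementation) of the kind of conditional query involved: given a joint model p(state, action, return), picking an action that leads to high return amounts to computing p(action | state, return ≥ v). In a tractable model this can be answered exactly by marginalization, whereas an intractable sequence model must approximate it. The discrete joint table, the `action_posterior` helper, and all numbers below are illustrative assumptions.

```python
# Toy discrete joint p(state, action, return): 2 states, 2 actions,
# 3 return levels (0 = low, 2 = high). Probabilities sum to 1.
p = {
    0: {0: [0.10, 0.05, 0.05], 1: [0.02, 0.08, 0.20]},
    1: {0: [0.15, 0.05, 0.00], 1: [0.05, 0.10, 0.15]},
}

def action_posterior(state, min_return):
    """Exact p(action | state, return >= min_return) via marginalization.

    Sums the joint probability over all return levels at or above the
    threshold, then renormalizes over actions. In a tractable model
    (e.g. a TPM) this marginalization is efficient and exact.
    """
    scores = {a: sum(probs[min_return:]) for a, probs in p[state].items()}
    z = sum(scores.values())
    return {a: s / z for a, s in scores.items()}

# Condition on the top return level in state 0 and pick the best action.
post = action_posterior(state=0, min_return=2)
best_action = max(post, key=post.get)
print(post, best_action)  # action 1 dominates: 0.20 vs 0.05 joint mass
```

The point of the sketch: the "prompt for high return" step is itself a probabilistic inference problem, and the quality of the resulting policy depends on how exactly that query can be answered.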
Related benchmarks
| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Offline Reinforcement Learning | D4RL Gym walker2d (medium-replay) | Normalized Return | 90.6 | 52 |
| Offline Reinforcement Learning | D4RL Gym halfcheetah-medium | Normalized Return | 49.5 | 44 |
| Offline Reinforcement Learning | D4RL MuJoCo Hopper medium standard | Normalized Score | 67.6 | 36 |
| Offline Reinforcement Learning | D4RL mujoco-hopper (med-replay) | Normalized Score | 97.8 | 23 |
| Offline Reinforcement Learning | D4RL Hopper Med-Expert | Normalized Average Return | 113 | 21 |
| Offline Reinforcement Learning | HalfCheetah Medium-Expert Gym-MuJoCo D4RL | Normalized Score | 95.1 | 18 |
| Offline Reinforcement Learning | Walker Gym-MuJoCo Medium-Expert D4RL | Normalized Score | 109.6 | 18 |
| Offline Reinforcement Learning | Walker Medium Gym-MuJoCo D4RL | Normalized Score | 84.7 | 16 |
| Offline Reinforcement Learning | HalfCheetah Gym-MuJoCo Medium-Replay D4RL | Normalized Score | 48.9 | 16 |
| Reinforcement Learning | Taxi environment (stochastic) | Episode Return | -57 | 6 |