# In-context Reinforcement Learning with Algorithm Distillation

## About
We propose Algorithm Distillation (AD), a method for distilling reinforcement learning (RL) algorithms into neural networks by modeling their training histories with a causal sequence model. Algorithm Distillation treats learning to reinforcement learn as an across-episode sequential prediction problem. A dataset of learning histories is generated by a source RL algorithm, and then a causal transformer is trained by autoregressively predicting actions given their preceding learning histories as context. Unlike sequential policy prediction architectures that distill post-learning or expert sequences, AD is able to improve its policy entirely in-context without updating its network parameters. We demonstrate that AD can reinforcement learn in-context in a variety of environments with sparse rewards, combinatorial task structure, and pixel-based observations, and find that AD learns a more data-efficient RL algorithm than the one that generated the source data.
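To make the data setup concrete, here is a minimal sketch of the across-episode sequence format described above: a multi-episode learning history is flattened into one (observation, action, reward) token stream, and the model's targets are the next actions given everything that came before. All names, shapes, and the `make_ad_sequence` helper are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical learning history produced by a source RL algorithm:
# a list of episodes, each a sequence of (observation, action, reward) steps.
# Dimensions here are illustrative only.
num_episodes, steps_per_episode, obs_dim, num_actions = 4, 5, 3, 2
history = [
    [(rng.standard_normal(obs_dim),      # observation o_t
      int(rng.integers(num_actions)),    # action a_t
      float(rng.random()))               # reward r_t
     for _ in range(steps_per_episode)]
    for _ in range(num_episodes)
]

def make_ad_sequence(history):
    """Flatten a multi-episode learning history into one across-episode
    token stream (o_1, a_1, r_1, o_2, ...) and collect the next-action
    targets a causal model would be trained to predict at each
    observation position."""
    tokens, targets = [], []
    for episode in history:            # episodes are concatenated in order,
        for obs, action, reward in episode:  # so context spans episode boundaries
            tokens.append(("obs", obs))
            targets.append(action)     # predict a_t from the history up to o_t
            tokens.append(("act", action))
            tokens.append(("rew", reward))
    return tokens, targets

tokens, targets = make_ad_sequence(history)
print(len(tokens), len(targets))  # 3 tokens per step; 1 action target per step
```

A causal transformer trained autoregressively on `tokens`, with a cross-entropy loss on the action positions against `targets`, would then be conditioned at evaluation time on its own growing interaction history, which is what allows the policy to improve in-context without parameter updates.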
## Related benchmarks
| Task | Dataset | Offline Training Time (hours) | Rank |
|---|---|---|---|
| Darkroom | Grid World | 0.21 | 6 |
| Dark Key-to-Door | Grid World | 0.65 | 3 |
| Darkroom Hard | Grid World | 0.22 | 3 |
| HalfCheetah | D4RL | 28.56 | 3 |
| Hopper | D4RL | 18.15 | 3 |
| Large Dark Key-to-Door | Large Grid World | 6.87 | 3 |
| Large Darkroom | Large Grid World | 3.52 | 3 |
| Large Darkroom Dynamic | Large Grid World | 2.71 | 3 |
| Large Darkroom Hard | Large Grid World | 4.26 | 3 |
| Walker2d | D4RL | 26.25 | 3 |