In-context Reinforcement Learning with Algorithm Distillation
About
We propose Algorithm Distillation (AD), a method for distilling reinforcement learning (RL) algorithms into neural networks by modeling their training histories with a causal sequence model. Algorithm Distillation treats learning to reinforcement learn as an across-episode sequential prediction problem. A dataset of learning histories is generated by a source RL algorithm, and then a causal transformer is trained by autoregressively predicting actions given their preceding learning histories as context. Unlike sequential policy prediction architectures that distill post-learning or expert sequences, AD is able to improve its policy entirely in-context without updating its network parameters. We demonstrate that AD can reinforcement learn in-context in a variety of environments with sparse rewards, combinatorial task structure, and pixel-based observations, and find that AD learns a more data-efficient RL algorithm than the one that generated the source data.
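To make the data side of this concrete, here is a minimal sketch of how learning histories could be turned into autoregressive training examples. It is not the authors' implementation. It makes three assumptions: the source algorithm is a simple epsilon-greedy learner on a Bernoulli bandit, each arm pull counts as a one-step episode (so any multi-step context spans episodes), and observations are omitted because a bandit has none. The names `run_source_algorithm`, `make_training_examples`, and `context_len` are hypothetical. The transformer itself is left out; the sketch only builds the `(context, target_action)` pairs such a model would be trained on.

```python
import random


def run_source_algorithm(arm_probs, n_steps, epsilon=0.1, seed=0):
    """Hypothetical source RL algorithm: epsilon-greedy on a Bernoulli bandit.

    Returns the full learning history as a list of (action, reward) tuples,
    ordered from the first (uninformed) pull to the last (mostly greedy) pull.
    """
    rng = random.Random(seed)
    n_arms = len(arm_probs)
    counts = [0] * n_arms      # number of pulls per arm
    values = [0.0] * n_arms    # running mean reward per arm
    history = []
    for _ in range(n_steps):
        if rng.random() < epsilon:
            action = rng.randrange(n_arms)                      # explore
        else:
            action = max(range(n_arms), key=lambda a: values[a])  # exploit
        reward = 1.0 if rng.random() < arm_probs[action] else 0.0
        counts[action] += 1
        values[action] += (reward - values[action]) / counts[action]
        history.append((action, reward))
    return history


def make_training_examples(history, context_len):
    """Turn one learning history into (context, target_action) pairs.

    The context is the window of up to `context_len` transitions that
    precede step t; the target is the action the source algorithm took
    at step t. Because the window covers earlier, worse behaviour as well
    as later, better behaviour, a model fit to these pairs has to predict
    how the policy improves, not just what a fixed policy does.
    """
    examples = []
    for t in range(1, len(history)):
        start = max(0, t - context_len)
        context = history[start:t]
        target_action = history[t][0]
        examples.append((context, target_action))
    return examples


history = run_source_algorithm(arm_probs=[0.1, 0.9], n_steps=50)
examples = make_training_examples(history, context_len=8)
```

In a full pipeline, many such histories (one per task and seed) would be pooled, and a causal sequence model would be trained with a cross-entropy loss on `target_action` given `context`. At evaluation time the model's own actions and the resulting rewards are appended to the context, with no parameter updates.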
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Multi-Armed Bandit | Multi-Armed Bandit (MAB) Horizon Generalization T=100 | Average Regret | 32.97 | 7 |
| Darkroom | Grid World | Offline Training Time (hour) | 0.21 | 6 |
| Continuous Control | HPP-25 complete 1 | NAUC | 1.14 | 4 |
| Continuous Control | WLP-50-1 (complete) | NAUC | 1.15 | 4 |
| Continuous Control | HPP-25-1 complete (test) | Return | 233.3 | 4 |
| Continuous Control | WLP-50-1 complete (test) | Return | 251.5 | 4 |
| Continuous Control | WLP-100-1 complete (test) | Return | 248.4 | 4 |
| In-Context Reinforcement Learning | MW 20-1 DR9 | NAUC | 32 | 4 |
| Offline In-Context Reinforcement Learning | DR9-20-1-early (test) | NAUC Score | 18 | 4 |
| Offline In-Context Reinforcement Learning | DR19-75-1 early (test) | NAUC | 12 | 4 |