
In-context Reinforcement Learning with Algorithm Distillation

About

We propose Algorithm Distillation (AD), a method for distilling reinforcement learning (RL) algorithms into neural networks by modeling their training histories with a causal sequence model. Algorithm Distillation treats learning to reinforcement learn as an across-episode sequential prediction problem. A dataset of learning histories is generated by a source RL algorithm, and then a causal transformer is trained by autoregressively predicting actions given their preceding learning histories as context. Unlike sequential policy prediction architectures that distill post-learning or expert sequences, AD is able to improve its policy entirely in-context without updating its network parameters. We demonstrate that AD can reinforcement learn in-context in a variety of environments with sparse rewards, combinatorial task structure, and pixel-based observations, and find that AD learns a more data-efficient RL algorithm than the one that generated the source data.
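To make the data pipeline described above concrete, here is a minimal sketch in standard-library Python. It assumes a Bernoulli multi-armed bandit and an epsilon-greedy learner as the source RL algorithm; both are hypothetical stand-ins chosen for illustration, not the paper's setup. The helper names `epsilon_greedy_bandit_history` and `autoregressive_examples` are also hypothetical. The sketch records the source algorithm's entire learning history, then turns it into (context, current observation, target action) training examples. Each target action sees only the transitions that precede it, which is what a causal transformer's attention mask enforces. The transformer itself is omitted.

```python
import random
from typing import List, Tuple

# One step of a learning history: (observation, action, reward).
Transition = Tuple[int, int, float]


def epsilon_greedy_bandit_history(
    arm_probs: List[float], steps: int, epsilon: float, seed: int
) -> List[Transition]:
    """Hypothetical source RL algorithm: an epsilon-greedy learner on a
    Bernoulli bandit. The full learning history is recorded, including the
    early, poorly performing steps, because AD distills how the learner
    improves and not only its final policy."""
    rng = random.Random(seed)
    n_arms = len(arm_probs)
    counts = [0] * n_arms
    values = [0.0] * n_arms  # running mean reward per arm
    history: List[Transition] = []
    for _ in range(steps):
        if rng.random() < epsilon:
            action = rng.randrange(n_arms)  # explore
        else:
            action = max(range(n_arms), key=lambda a: values[a])  # exploit
        reward = 1.0 if rng.random() < arm_probs[action] else 0.0
        counts[action] += 1
        # Incremental update of the sample-mean estimate for this arm.
        values[action] += (reward - values[action]) / counts[action]
        # A bandit has no real state, so a single dummy observation (0) is used.
        history.append((0, action, reward))
    return history


def autoregressive_examples(
    history: List[Transition], context_len: int
) -> List[Tuple[List[Transition], int, int]]:
    """Turn one learning history into (context, observation, target action)
    examples. The context holds only transitions strictly before step t, so
    the target action is predicted causally. For AD the window should span
    several episodes; in a bandit, every step is its own episode."""
    examples = []
    for t, (obs_t, action_t, _) in enumerate(history):
        context = history[max(0, t - context_len):t]
        examples.append((context, obs_t, action_t))
    return examples
```

For example, `autoregressive_examples(epsilon_greedy_bandit_history([0.1, 0.9], 500, 0.1, seed=0), context_len=64)` yields one example per step. A sequence model trained on many such histories, with the loss taken only on the target actions, would be asked to imitate the learner's improvement across the context window. That improvement is the in-context behavior the abstract describes.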

Michael Laskin, Luyu Wang, Junhyuk Oh, Emilio Parisotto, Stephen Spencer, Richie Steigerwald, DJ Strouse, Steven Hansen, Angelos Filos, Ethan Brooks, Maxime Gazeau, Himanshu Sahni, Satinder Singh, Volodymyr Mnih (2022)

Related benchmarks

Task | Dataset | Metric | Result | Rank
Multi-Armed Bandit | Multi-Armed Bandit (MAB) Horizon Generalization T=100 | Average Regret | 32.97 | 7
Darkroom | Grid World | Offline Training Time (hours) | 0.21 | 6
Continuous Control | HPP-25 complete 1 | NAUC | 1.14 | 4
Continuous Control | WLP-50-1 (complete) | NAUC | 1.15 | 4
Continuous Control | HPP-25-1 complete (test) | Return | 233.3 | 4
Continuous Control | WLP-50-1 complete (test) | Return | 251.5 | 4
Continuous Control | WLP-100-1 complete (test) | Return | 248.4 | 4
In-Context Reinforcement Learning | MW 20-1 DR9 | NAUC | 32 | 4
Offline In-Context Reinforcement Learning | DR9-20-1-early (test) | NAUC Score | 18 | 4
Offline In-Context Reinforcement Learning | DR19-75-1 early (test) | NAUC | 12 | 4

Showing 10 of 135 rows.
