
Towards mental time travel: a hierarchical memory for reinforcement learning agents

About

Reinforcement learning agents often forget details of the past, especially after delays or distractor tasks. Agents with common memory architectures struggle to recall and integrate across multiple timesteps of a past event, or even to recall the details of a single timestep that is followed by distractor tasks. To address these limitations, we propose a Hierarchical Chunk Attention Memory (HCAM), which helps agents to remember the past in detail. HCAM stores memories by dividing the past into chunks, and recalls by first performing high-level attention over coarse summaries of the chunks, and then performing detailed attention within only the most relevant chunks. An agent with HCAM can therefore "mentally time-travel": it can remember past events in detail without attending to all intervening events. We show that agents with HCAM substantially outperform agents with other memory architectures at tasks requiring long-term recall, retention, or reasoning over memory. These include recalling where an object is hidden in a 3D environment, rapidly learning to navigate efficiently in a new neighborhood, and rapidly learning and retaining new object names. Agents with HCAM can extrapolate to task sequences much longer than they were trained on, and can even generalize zero-shot from a meta-learning setting to maintaining knowledge across episodes. HCAM improves agent sample efficiency, generalization, and generality (by solving tasks that previously required specialized architectures). Our work is a step towards agents that can learn, interact, and adapt in complex and temporally-extended environments.
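The two-level recall described above can be sketched in a few lines. The following is a minimal NumPy illustration, not the paper's implementation: the function name `hcam_attention`, the use of mean-pooled keys as chunk summaries, and the weighting of fine-grained outputs by coarse relevance are simplifying assumptions (the actual HCAM uses learned summaries and transformer-style multi-head attention within a trained agent).

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def hcam_attention(query, keys, values, chunk_size, top_k):
    """Hypothetical sketch of hierarchical chunk attention.

    query: (d,) current query vector.
    keys, values: (T, d) stored memories for T past timesteps.
    Coarse attention over per-chunk summaries selects the top_k most
    relevant chunks; detailed attention then runs only within those
    chunks, so irrelevant intervening events are never attended to.
    """
    T, d = keys.shape
    n_chunks = T // chunk_size
    keys_c = keys[: n_chunks * chunk_size].reshape(n_chunks, chunk_size, d)
    vals_c = values[: n_chunks * chunk_size].reshape(n_chunks, chunk_size, d)

    # Coarse level: summarize each chunk (here: mean of its keys).
    summaries = keys_c.mean(axis=1)                    # (n_chunks, d)
    coarse = softmax(summaries @ query / np.sqrt(d))   # relevance per chunk
    top = np.argsort(coarse)[-top_k:]                  # most relevant chunks

    # Fine level: detailed attention inside the selected chunks only,
    # with each chunk's contribution weighted by its coarse relevance.
    out = np.zeros(d)
    for c in top:
        w = softmax(keys_c[c] @ query / np.sqrt(d))    # (chunk_size,)
        out += coarse[c] * (w @ vals_c[c])
    return out
```

Because the fine-grained step touches only `top_k * chunk_size` memories instead of all `T`, the cost of detailed recall stays roughly constant as the episode grows, which is what lets the agent "time-travel" to a distant event without attending over everything in between.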

Andrew Kyle Lampinen, Stephanie C.Y. Chan, Andrea Banino, Felix Hill • 2021

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Ballet | Ballet (2 dances, delay 16) | Accuracy | 99.8 | 3 |
| Ballet | Ballet (8 dances, delay 16) | Accuracy | 98.1 | 3 |
| Ballet | Ballet (8 dances, delay 48) | Accuracy | 97.2 | 3 |
| Object Permanence | Object Permanence (no delay, varying train) | Accuracy | 96.7 | 3 |
| Object Permanence | Object Permanence (long delay, varying train) | Accuracy | 91.7 | 3 |
| Rapid Word Learning | Rapid Word Learning (10 distractors) | Accuracy | 93.0 | 3 |
| Rapid Word Learning | Rapid Word Learning (4 episodes, 0 distractors each) | Accuracy | 82.8 | 3 |
| Rapid Word Learning | Rapid Word Learning (2 episodes, 5 distractors each) | Accuracy | 71.1 | 3 |
| Object Permanence | Object Permanence (long delay, long-only train) | Accuracy | 82.9 | 2 |
| One-Shot StreetLearn | StreetLearn One-Shot | Average Reward | 26.8 | 2 |

Showing 10 of 12 rows.
