Mask-based Latent Reconstruction for Reinforcement Learning

About

For deep reinforcement learning (RL) from pixels, learning effective state representations is crucial for achieving high performance. However, in practice, limited experience and high-dimensional inputs prevent effective representation learning. To address this, motivated by the success of mask-based modeling in other research fields, we introduce mask-based reconstruction to promote state representation learning in RL. Specifically, we propose a simple yet effective self-supervised method, Mask-based Latent Reconstruction (MLR), to predict complete state representations in the latent space from the observations with spatially and temporally masked pixels. MLR enables better use of context information when learning state representations to make them more informative, which facilitates the training of RL agents. Extensive experiments show that our MLR significantly improves the sample efficiency in RL and outperforms the state-of-the-art sample-efficient RL methods on multiple continuous and discrete control benchmarks. Our code is available at https://github.com/microsoft/Mask-based-Latent-Reconstruction.
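The core idea above can be sketched in NumPy: mask a stack of frames both spatially and temporally, then score how closely the latents predicted from the masked input match the latents of the complete input. This is a minimal illustration under stated assumptions, not the authors' implementation. The helper names (`cube_mask`, `apply_mask`, `latent_reconstruction_loss`), the cube-shaped masking pattern, zero-filling of masked pixels, the cosine-similarity loss and all parameter values are choices made here for clarity. The encoder, predictor and target networks that MLR actually trains are omitted; the latents are simply passed in as arrays.

```python
import numpy as np


def cube_mask(num_frames, height, width, cube_size, cube_depth, mask_ratio, rng):
    """Return a boolean mask of shape (num_frames, height, width); True = masked.

    Assumption for illustration: the frame stack is split into non-overlapping
    space-time cubes of cube_depth frames x cube_size x cube_size pixels, and a
    random fraction mask_ratio of those cubes is masked entirely.
    """
    assert num_frames % cube_depth == 0
    assert height % cube_size == 0 and width % cube_size == 0
    grid = (num_frames // cube_depth, height // cube_size, width // cube_size)
    n_cubes = int(np.prod(grid))
    n_masked = int(round(mask_ratio * n_cubes))
    flat = np.zeros(n_cubes, dtype=bool)
    flat[rng.choice(n_cubes, size=n_masked, replace=False)] = True
    coarse = flat.reshape(grid)
    # Expand each per-cube decision to full pixel (and frame) resolution.
    return (coarse.repeat(cube_depth, axis=0)
                  .repeat(cube_size, axis=1)
                  .repeat(cube_size, axis=2))


def apply_mask(frames, mask):
    """Zero out masked pixels in every channel; frames has shape (T, C, H, W)."""
    return frames * (~mask)[:, None, :, :]


def latent_reconstruction_loss(predicted, target):
    """1 - mean cosine similarity between rows of predicted and target, shape (T, D).

    In practice the target latents would come from encoding the unmasked frames
    (e.g. with a separate target encoder); here they are given directly.
    """
    p = predicted / np.linalg.norm(predicted, axis=-1, keepdims=True)
    t = target / np.linalg.norm(target, axis=-1, keepdims=True)
    return float(1.0 - np.mean(np.sum(p * t, axis=-1)))
```

For example, `cube_mask(4, 84, 84, 12, 2, 0.5, np.random.default_rng(0))` masks half of the 98 cubes in a 4-frame, 84x84 stack, so exactly half of all pixels are hidden. Masking whole space-time cubes, rather than independent pixels, removes information that cannot be recovered by copying a neighbouring pixel or the same pixel from an adjacent frame, which forces the reconstruction to draw on wider context.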

Tao Yu, Zhizheng Zhang, Cuiling Lan, Yan Lu, Zhibo Chen (2022)

Related benchmarks

Task                             | Dataset                            | Metric            | Result | Rank
Continuous Control               | DMControl 500k                     | Spin Score        | 973    | 33
Continuous Control               | DMControl 100k                     | Finger Spin Score | 907    | 29
Reinforcement Learning           | Atari 100k                         | Alien Score       | 990.1  | 18
Visual Reinforcement Learning    | DMControl Reacher Easy             | Episode Return    | 866    | 16
Visual Reinforcement Learning    | DMControl Cheetah Run              | Episode Return    | 482    | 16
Visual Reinforcement Learning    | DMControl Ball in cup, Catch       | Episode Return    | 933    | 16
Visual Reinforcement Learning    | DMControl Finger, Spin             | Episode Return    | 907    | 16
Visual Reinforcement Learning    | DMControl Walker Walk              | Episode Return    | 643    | 16
Visual Reinforcement Learning    | DMControl Cartpole, Swingup        | Episode Return    | 806    | 16
Autonomous Driving               | CARLA (#HW)                        | Error Rate        | 106    | 15

(10 of 12 benchmark rows shown.)

Other info

Code: https://github.com/microsoft/Mask-based-Latent-Reconstruction
