Robust Task Representations for Offline Meta-Reinforcement Learning via Contrastive Learning
About
We study offline meta-reinforcement learning, a practical reinforcement learning paradigm that learns from offline data to adapt to new tasks. The distribution of offline data is determined jointly by the behavior policy and the task. Existing offline meta-reinforcement learning algorithms cannot distinguish between these two factors, which makes their task representations unstable when the behavior policy changes. To address this problem, we propose a contrastive learning framework for task representations that are robust to the mismatch between the behavior policies used in training and testing. We design a bi-level encoder structure, use mutual information maximization to formalize task representation learning, derive a contrastive learning objective, and introduce several approaches to approximate the true distribution of negative pairs. Experiments on a variety of offline meta-reinforcement learning benchmarks demonstrate the advantages of our method over prior methods, especially in generalizing to out-of-distribution behavior policies. The code is available at https://github.com/PKU-AI-Edge/CORRO.
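The abstract names three ingredients: a bi-level encoder, a mutual-information objective, and a contrastive loss over positive and negative pairs. The Python sketch below illustrates one way these pieces could fit together. It assumes a single tanh linear layer as the transition-level encoder, mean pooling as the task-level aggregator, and the standard InfoNCE loss as the contrastive objective. The function names, the toy weights, and the reward-flipped negative transition are hypothetical illustrations, not code from the CORRO repository.

```python
import math
from typing import List, Sequence

Vector = List[float]


def encode_transition(transition: Sequence[float],
                      weights: Sequence[Sequence[float]]) -> Vector:
    """First level (hypothetical): embed one flattened transition (s, a, r, s')
    with a single linear layer followed by tanh."""
    return [math.tanh(sum(w * x for w, x in zip(row, transition))) for row in weights]


def aggregate(embeddings: Sequence[Vector]) -> Vector:
    """Second level (hypothetical): mean-pool transition embeddings into one
    task representation."""
    n = len(embeddings)
    return [sum(e[i] for e in embeddings) / n for i in range(len(embeddings[0]))]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def info_nce(anchor: Vector, positive: Vector, negatives: Sequence[Vector],
             temperature: float = 1.0) -> float:
    """InfoNCE loss for one anchor: the negative log softmax score of the
    positive pair against K negative pairs. Minimizing it maximizes a lower
    bound on the mutual information between the paired representations."""
    logits = [dot(anchor, positive) / temperature]
    logits += [dot(anchor, neg) / temperature for neg in negatives]
    # Log-sum-exp with max subtraction for numerical stability.
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_denominator - logits[0]


if __name__ == "__main__":
    weights = [[0.5, -0.2, 1.0, 0.1], [0.3, 0.4, -0.6, 0.2]]  # toy 2x4 layer
    # Two transitions from the same task form the positive pair; a transition
    # with the same (s, a) but a different reward stands in for a negative.
    anchor = encode_transition([0.1, 0.2, 1.0, 0.3], weights)
    positive = encode_transition([0.1, 0.2, 1.0, 0.35], weights)
    negative = encode_transition([0.1, 0.2, -1.0, 0.3], weights)
    print(info_nce(anchor, positive, [negative]))
    print(aggregate([anchor, positive]))
```

Mean pooling is used here only because it is simple and invariant to the order of transitions. The released implementation may aggregate differently and may construct negatives in other ways.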
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Offline Meta-Reinforcement Learning | Point-Robot sampled 10 unseen (test) | Average Return | -7.8 | 10 |
| Offline Meta-Reinforcement Learning | Walker-Rand-Params sampled 10 unseen (test) | Average Return | 312.5 | 10 |
| Offline Meta-Reinforcement Learning | Half-Cheetah-Vel sampled 10 unseen (test) | Average Return | -65.6 | 10 |
| Reinforcement Learning | Ant-Dir Random OOD | Average Return | 0 | 8 |
| Reinforcement Learning | Ant-Dir Random IID | Average Return | 1 | 8 |
| Reinforcement Learning | Ant-Dir Medium IID | Average Return | 8 | 8 |
| Reinforcement Learning | Ant-Dir Medium OOD | Average Return | -7 | 8 |
| Reinforcement Learning | Ant-Dir Expert IID | Average Return | -4 | 8 |
| Reinforcement Learning | Ant-Dir Expert OOD | Average Return | -14 | 8 |