
ODICE: Revealing the Mystery of Distribution Correction Estimation via Orthogonal-gradient Update

About

In this study, we investigate DIstribution Correction Estimation (DICE) methods, an important line of work in offline reinforcement learning (RL) and imitation learning (IL). DICE-based methods impose a state-action-level behavior constraint, which is an ideal choice for offline learning. However, they typically perform much worse than current state-of-the-art (SOTA) methods that use only an action-level behavior constraint. After revisiting DICE-based methods, we find that two gradient terms arise when learning the value function with a true-gradient update: the forward gradient (taken on the current state) and the backward gradient (taken on the next state). Using the forward gradient bears a large similarity to many offline RL methods, and thus can be regarded as applying an action-level constraint. However, directly adding the backward gradient may degrade or cancel out this effect when the two gradients have conflicting directions. To resolve this issue, we propose a simple yet effective modification that projects the backward gradient onto the normal plane of the forward gradient, resulting in an orthogonal-gradient update, a new learning rule for DICE-based methods. We conduct thorough theoretical analyses and find that the projected backward gradient brings state-level behavior regularization, which reveals the mystery of DICE-based methods: the value learning objective does try to impose a state-action-level constraint, but needs to be used in a corrected way. Through toy examples and extensive experiments on complex offline RL and IL tasks, we demonstrate that DICE-based methods using orthogonal-gradient updates (O-DICE) achieve SOTA performance and strong robustness.
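The core modification described above, projecting the backward gradient onto the normal plane of the forward gradient before combining the two, can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the names `g_forward`, `g_backward`, and the weight `eta` are assumptions for exposition, and in practice these would be gradients of the value objective with respect to the network parameters.

```python
import numpy as np

def orthogonal_gradient_update(g_forward, g_backward, eta=1.0, eps=1e-12):
    """Sketch of an orthogonal-gradient update (O-DICE-style).

    g_forward : gradient term taken on the current state
    g_backward: gradient term taken on the next state
    eta       : weight on the projected backward component (illustrative knob)
    """
    g_forward = np.asarray(g_forward, dtype=float)
    g_backward = np.asarray(g_backward, dtype=float)
    # Remove the component of g_backward that lies along g_forward,
    # so the backward term can no longer cancel the forward term.
    coeff = (g_forward @ g_backward) / (g_forward @ g_forward + eps)
    g_backward_perp = g_backward - coeff * g_forward
    # The projected backward gradient is orthogonal to the forward one;
    # the combined direction keeps the full forward (action-level) signal.
    return g_forward + eta * g_backward_perp
```

For example, with `g_forward = [1, 0]` and `g_backward = [1, 1]`, the two gradients partially conflict along the first axis; the projection keeps only `[0, 1]` from the backward term, so the combined update is `[1, 1]` rather than the naive sum `[2, 1]`.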

Liyuan Mao, Haoran Xu, Weinan Zhang, Xianyuan Zhan • 2024

Related benchmarks

| Task | Dataset | Result | Rank |
|---|---|---|---|
| Offline Reinforcement Learning | hopper medium | Normalized Score: 86.1 | 52 |
| Offline Reinforcement Learning | walker2d medium | Normalized Score: 84.9 | 51 |
| Offline Reinforcement Learning | walker2d medium-replay | Normalized Score: 83.6 | 50 |
| Offline Reinforcement Learning | hopper medium-replay | Normalized Score: 99.9 | 44 |
| Offline Reinforcement Learning | halfcheetah medium | Normalized Score: 47.4 | 43 |
| Offline Reinforcement Learning | halfcheetah medium-replay | Normalized Score: 44.0 | 43 |
| Offline Reinforcement Learning | walker2d medium-expert | Normalized Score: 110.8 | 31 |
| Offline Reinforcement Learning | hopper medium-expert | Normalized Score: 110.8 | 24 |
| Offline Reinforcement Learning | halfcheetah medium-expert | Normalized Score: 93.2 | 15 |
| Offline Reinforcement Learning | D4RL v2 | Score (HalfCheetah-M): 47.4 | 9 |
