
Regularizing Hidden States Enables Learning Generalizable Reward Model for LLMs

About

Reward models trained on human preference data have been proven to effectively align Large Language Models (LLMs) with human intent within the framework of reinforcement learning from human feedback (RLHF). However, current reward models have limited generalization capabilities to unseen prompts and responses, which can lead to an unexpected phenomenon known as reward over-optimization, where excessive optimization of the reward causes a decline in actual performance. While previous research has advocated for constraining policy optimization, our study introduces a novel approach to enhance the reward model's generalization ability against distribution shifts by regularizing the hidden states. Specifically, we retain the base model's language model head and incorporate a suite of text-generation losses to preserve the hidden states' text-generation capabilities, while concurrently learning a reward head on top of the same hidden states. Our experimental results demonstrate that the introduced regularization technique markedly improves the accuracy of learned reward models across a variety of out-of-distribution (OOD) tasks and effectively alleviates the over-optimization issue in RLHF, offering a more reliable and robust preference learning paradigm.
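The architecture described above can be sketched in a few lines of PyTorch. This is a hypothetical minimal illustration, not the authors' code: the backbone is a stand-in for a transformer, and all names and hyperparameters (e.g. `lm_coef`) are illustrative assumptions. The key idea it demonstrates is that one set of hidden states feeds both a retained (frozen) language-model head, regularized by a next-token loss, and a newly trained scalar reward head, trained with a Bradley-Terry preference loss.

```python
# Hypothetical sketch of hidden-state regularization for reward modeling.
# Names, sizes, and the lm_coef weight are illustrative, not from the paper.
import torch
import torch.nn as nn
import torch.nn.functional as F

class RegularizedRewardModel(nn.Module):
    def __init__(self, vocab_size=32, hidden_size=16):
        super().__init__()
        # Stand-in for a transformer backbone producing hidden states.
        self.backbone = nn.Sequential(
            nn.Embedding(vocab_size, hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )
        # Language-model head retained from the base model; frozen so the
        # text-generation loss regularizes the hidden states, not the head.
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.lm_head.weight.requires_grad = False
        # Scalar reward head trained on the same hidden states.
        self.reward_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids):
        h = self.backbone(input_ids)                      # (batch, seq, hidden)
        reward = self.reward_head(h[:, -1]).squeeze(-1)   # last-token reward
        lm_logits = self.lm_head(h)                       # text-generation logits
        return reward, lm_logits

def combined_loss(model, chosen_ids, rejected_ids, lm_coef=0.1):
    r_chosen, logits_chosen = model(chosen_ids)
    r_rejected, _ = model(rejected_ids)
    # Bradley-Terry preference loss on the reward head.
    pref_loss = -F.logsigmoid(r_chosen - r_rejected).mean()
    # Next-token prediction loss on the chosen response keeps the hidden
    # states compatible with the retained language-model head.
    lm_loss = F.cross_entropy(
        logits_chosen[:, :-1].reshape(-1, logits_chosen.size(-1)),
        chosen_ids[:, 1:].reshape(-1),
    )
    return pref_loss + lm_coef * lm_loss
```

In this sketch, minimizing `combined_loss` fits preferences while the auxiliary language-modeling term discourages the hidden states from drifting into regions the base model never produces, which is the intuition behind the improved OOD robustness.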

Rui Yang, Ruomeng Ding, Yong Lin, Huan Zhang, Tong Zhang • 2024

Related benchmarks

Task | Dataset | Metric | Result | Rank
Reward Modeling | RewardBench | Avg Score | 87 | 118
Reward Modeling | Unified Feedback (UF) | Accuracy | 78.9 | 40
Role-playing Reward Modeling | RoleRM-Bench | Average Score | 56.5 | 22
Reward Modeling | RewardBench unified-feedback (test) | Average Score | 79.5 | 20
Reward Modeling | JudgeBench | Knowledge | 57.1 | 13
Reward Modeling | PMDC Maximum Discrepancy samples | Rank | 3 | 10
Reward Modeling | SafeRLHF Standard | Accuracy | 89.8 | 9
Reward Modeling | HHH-Alignment Standard | Accuracy | 88.2 | 9
Reward Modeling | HHH-Alignment Reversed | Accuracy | 11.8 | 9
Reward Modeling | SafeRLHF Reversed | Accuracy | 10.2 | 9

(Showing 10 of 16 rows)

Other info

Code
