
Learning Goal-Conditioned Representations for Language Reward Models

About

Techniques that learn improved representations via offline data or self-supervised objectives have shown impressive results in traditional reinforcement learning (RL). Nevertheless, it is unclear how improved representation learning can benefit reinforcement learning from human feedback (RLHF) on language models (LMs). In this work, we propose training reward models (RMs) in a contrastive, $\textit{goal-conditioned}$ fashion by increasing the representation similarity of future states along sampled preferred trajectories and decreasing the similarity along randomly sampled dispreferred trajectories. This objective significantly improves RM performance by up to 0.09 AUROC across challenging benchmarks such as MATH and GSM8k. These findings extend to general alignment as well -- on the Helpful-Harmless dataset, we observe a $2.3\%$ increase in accuracy. Beyond improving reward model performance, we show this way of training RM representations enables improved $\textit{steerability}$ because it allows us to evaluate the likelihood of an action achieving a particular goal-state (e.g., whether a solution is correct or helpful). Leveraging this insight, we find that we can filter up to $55\%$ of generated tokens during majority voting by discarding trajectories likely to end up in an "incorrect" state, which leads to significant cost savings. We additionally find that these representations can perform fine-grained control by conditioning on desired future goal-states. For example, we show that steering a Llama 3 model towards helpful generations with our approach improves helpfulness by $9.6\%$ over a supervised fine-tuned baseline. Similarly, steering the model towards complex generations improves complexity by $21.6\%$ over the baseline. Overall, we find that training RMs in this contrastive, goal-conditioned fashion significantly improves performance and enables model steerability.
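
The core technique described above lends itself to a short sketch: pull a reward model's hidden representation of an early state toward a future "goal" state from the same preferred trajectory, while pushing it away from states drawn from dispreferred trajectories. The PyTorch snippet below is a minimal illustration of one InfoNCE-style way to set this up; it is not the authors' implementation, and every name (`goal_conditioned_loss`, `h_pref`, `h_dispref`, the timestep-sampling scheme) is an assumption made for illustration.

```python
# Minimal sketch (not the paper's code) of a contrastive, goal-conditioned
# objective over reward-model hidden states, assuming a PyTorch setup.
import torch
import torch.nn.functional as F


def goal_conditioned_loss(h_pref: torch.Tensor,
                          h_dispref: torch.Tensor,
                          temperature: float = 0.1) -> torch.Tensor:
    """Contrastive loss over per-token hidden states of a reward model.

    h_pref:    (B, T, D) hidden states along preferred trajectories
    h_dispref: (B, T, D) hidden states along dispreferred trajectories

    For each preferred trajectory, a randomly chosen future state acts as the
    goal: earlier states on the same trajectory should be similar to it, while
    states from dispreferred trajectories should not.
    """
    B, T, D = h_pref.shape

    # Sample an anchor timestep t and a future goal timestep t' > t per example.
    t = torch.randint(0, T - 1, (B,))
    t_goal = torch.stack(
        [torch.randint(int(i) + 1, T, (1,)).squeeze(0) for i in t]
    )

    idx = torch.arange(B)
    anchor = F.normalize(h_pref[idx, t], dim=-1)        # earlier state, preferred traj
    goal = F.normalize(h_pref[idx, t_goal], dim=-1)     # future goal state, same traj
    negative = F.normalize(h_dispref[idx, t_goal], dim=-1)  # states from dispreferred trajs

    # InfoNCE-style logits: positive = own future goal state,
    # negatives = states sampled from dispreferred trajectories.
    pos = (anchor * goal).sum(-1, keepdim=True) / temperature   # (B, 1)
    neg = anchor @ negative.T / temperature                      # (B, B)
    logits = torch.cat([pos, neg], dim=-1)
    labels = torch.zeros(B, dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, labels)
```

In practice this representation-learning term would presumably be added to the standard preference-ranking loss used to train the reward model; the sketch only shows the contrastive piece, and the choice of which hidden layer supplies the states and how negatives are sampled would follow the paper rather than the simplifications above.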

Vaskar Nath, Dylan Slack, Jeff Da, Yuntao Ma, Hugh Zhang, Spencer Whitehead, Sean Hendryx • 2024

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
| --- | --- | --- | --- | --- |
| Mathematical Reasoning | SVAMP Out-of-Domain (test) | Accuracy | 92.9 | 50 |
| Mathematical Reasoning | ASDiv Out-of-Distribution | Top-1 Accuracy (maj@1) | 88.4 | 35 |
| Mathematical Reasoning | GSM8K In-Distribution (test) | Accuracy | 89.4 | 5 |
| Mathematical Reasoning | MATH In-Distribution (test) | Final Answer Accuracy | 55.6 | 5 |
| Mathematical Reasoning | Algebra222 Out-of-Distribution (test) | Final Answer Accuracy | 95.1 | 5 |
| Mathematical Reasoning | GSM-Hard Out-of-Distribution (test) | Final Answer Accuracy | 70.1 | 5 |
| Mathematical Reasoning | MAWPS Out-of-Distribution (test) | Accuracy | 95.9 | 5 |
| Mathematical Reasoning | Mathematical Reasoning (in-distribution) | GSM8K Score | 80.3 | 4 |
| Open-Ended Dialogue | Open-Ended Dialogue (in-distribution) | Helpful Score | 67.5 | 4 |
| Mathematical Reasoning | Mathematical Reasoning (OOD) | Algebra222 Accuracy | 81.6 | 4 |
Showing 10 of 11 benchmark rows.
