Factored Causal Representation Learning for Robust Reward Modeling in RLHF

About

A reliable reward model is essential for aligning large language models with human preferences through reinforcement learning from human feedback (RLHF). However, standard reward models are susceptible to spurious features that are not causally related to human labels. This can lead to reward hacking, where high predicted reward does not translate into better behavior. In this work, we address this problem from a causal perspective by proposing a factored representation learning framework that decomposes the model's contextual embedding into (1) causal factors that are sufficient for reward prediction and (2) non-causal factors that capture reward-irrelevant attributes such as length or sycophantic bias. The reward head is then constrained to depend only on the causal component. In addition, we introduce an adversarial head trained to predict reward from the non-causal factors, while applying gradient reversal to discourage them from encoding reward-relevant information. Experiments on both mathematical and dialogue tasks demonstrate that our method learns more robust reward models and consistently improves downstream RLHF performance over state-of-the-art baselines. Analyses on length and sycophantic bias further validate the effectiveness of our method in mitigating reward hacking behaviors.
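The abstract combines two mechanisms: a reward head that reads only the causal factors, and an adversarial head on the non-causal factors that is connected through gradient reversal. The sketch below illustrates both on a toy linear model in plain Python with hand-written gradients. It is not the authors' implementation: the linear projections `W_c` and `W_n`, the Bradley-Terry pairwise loss, plain SGD, the reversal strength `lam`, and every name in it are hypothetical choices made for illustration, and the paper's full objective may include terms not shown here.

```python
import math
from dataclasses import dataclass

# Hypothetical toy model: linear encoders, Bradley-Terry loss, plain SGD (not the paper's code).
Matrix = list[list[float]]
Vector = list[float]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def matvec(W: Matrix, h: Vector) -> Vector:
    return [dot(row, h) for row in W]


def grad_reverse(grad: Matrix, lam: float) -> Matrix:
    """Backward rule of a gradient reversal layer: the forward pass is the
    identity, the backward pass multiplies the incoming gradient by -lam."""
    return [[-lam * g for g in row] for row in grad]


def sgd(params: Matrix, grads: Matrix, lr: float) -> Matrix:
    return [[p - lr * g for p, g in zip(pr, gr)] for pr, gr in zip(params, grads)]


@dataclass
class FactoredRewardModel:
    W_c: Matrix  # embedding -> causal factors z_c
    W_n: Matrix  # embedding -> non-causal factors z_n
    w_r: Vector  # reward head: reads z_c only
    w_a: Vector  # adversarial head: reads z_n only

    def reward(self, h: Vector) -> float:
        return dot(self.w_r, matvec(self.W_c, h))


def train_step(m: FactoredRewardModel, h_chosen: Vector, h_rejected: Vector,
               lr: float = 0.1, lam: float = 1.0) -> tuple[float, float]:
    """One SGD step on a chosen/rejected pair; returns (reward_loss, adv_loss)
    measured before the update."""
    # Both branches are linear, so the factor difference is W @ (h_chosen - h_rejected).
    dh = [a - b for a, b in zip(h_chosen, h_rejected)]

    # Bradley-Terry loss L = -log sigmoid(margin); dL/dmargin = sigmoid(margin) - 1.
    zc = matvec(m.W_c, dh)
    margin = dot(m.w_r, zc)
    g = sigmoid(margin) - 1.0
    grad_w_r = [g * z for z in zc]
    grad_W_c = [[g * wr * x for x in dh] for wr in m.w_r]

    # Adversary: same loss form, but predicted from the non-causal factors z_n.
    zn = matvec(m.W_n, dh)
    adv_margin = dot(m.w_a, zn)
    g_a = sigmoid(adv_margin) - 1.0
    grad_w_a = [g_a * z for z in zn]
    # The gradient reaching W_n passes through the reversal layer first.
    grad_W_n = grad_reverse([[g_a * wa * x for x in dh] for wa in m.w_a], lam)

    m.w_r = sgd([m.w_r], [grad_w_r], lr)[0]
    m.W_c = sgd(m.W_c, grad_W_c, lr)
    m.w_a = sgd([m.w_a], [grad_w_a], lr)[0]  # adversary still minimizes its loss
    m.W_n = sgd(m.W_n, grad_W_n, lr)  # non-causal encoder ascends the adversary loss
    return -math.log(sigmoid(margin)), -math.log(sigmoid(adv_margin))


# Usage on a single toy preference pair.
model = FactoredRewardModel(W_c=[[0.5, 0.0]], W_n=[[0.5, 0.5]], w_r=[0.5], w_a=[0.5])
chosen, rejected = [1.0, 0.0], [0.0, 0.0]
before = model.reward(chosen) - model.reward(rejected)
loss, adv_loss = train_step(model, chosen, rejected)
after = model.reward(chosen) - model.reward(rejected)
```

After one step the reward margin grows and the adversarial head moves to fit the pair, while the non-causal projection `W_n` moves the opposite way so that its features become less useful for predicting the preference. In a framework with automatic differentiation, the reversal layer is usually written as a custom operation that returns its input unchanged in the forward pass and `-lam * grad` in the backward pass; the hand-written gradients here only make that sign flip explicit.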

Yupei Yang, Lin Yang, Wanxi Deng, Lin Qu, Fan Feng, Biwei Huang, Shikui Tu, Lei Xu · 2026

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Mathematical Reasoning | SVAMP out-of-domain (test) | Accuracy | 93.9 | 50 |
| Mathematical Reasoning | ASDiv Out of Distribution | Top-1 Accuracy (maj@1) | 89.1 | 35 |
| Mathematical Reasoning | GSM8K In-Distribution (test) | Accuracy | 91.8 | 5 |
| Mathematical Reasoning | MATH In-Distribution (test) | Final Answer Accuracy | 56.1 | 5 |
| Mathematical Reasoning | Algebra222 Out-of-Distribution (test) | Final Answer Accuracy | 97.3 | 5 |
| Mathematical Reasoning | GSM-Hard Out-of-Distribution (test) | Final Answer Accuracy | 71 | 5 |
| Mathematical Reasoning | MAWPS Out-of-Distribution (test) | Accuracy | 96.5 | 5 |
| Mathematical Reasoning | Mathematical Reasoning (in-distribution) | GSM8K Score | 81.7 | 4 |
| Mathematical Reasoning | Mathematical Reasoning (OOD) | Algebra222 Accuracy | 89.9 | 4 |
| Open-ended Dialogue | Anthropic-Helpful (ID) | Win Rate | 0.762 | 4 |

(10 of 19 rows shown.)
