Distributionally Robust Token Optimization in RLHF
About
Large Language Models (LLMs) tend to respond correctly to prompts that align with the data they were trained and fine-tuned on. Yet small shifts in wording, format, or language can trigger surprisingly large failures, especially on multi-step reasoning problems. To address this, we propose Distributionally Robust Token Optimization (DRTO), which combines token-level Reinforcement Learning from Human Feedback (RLHF) with Distributionally Robust Optimization (DRO). DRTO bounds the worst-case token-wise rewards by constructing an f-divergence ambiguity set over the minibatch loss distribution, yielding theoretical robustness guarantees. Empirically, DRTO improves consistency under distribution shift on mathematical reasoning benchmarks, achieving a 9.17% improvement on GSM8K and a 2.49% improvement on MathQA.
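To make the DRO component concrete, the sketch below shows the standard dual form of a KL-divergence (an f-divergence) ambiguity set applied to a minibatch of per-token losses. The radius `rho`, the grid search over the dual variable, and the function name are illustrative assumptions, not the paper's implementation; the sketch only demonstrates how an f-divergence ball over the empirical minibatch yields a worst-case (upper-bounded) loss and an adversarial reweighting of tokens.

```python
import numpy as np

def kl_dro_loss(losses, rho=0.1):
    """Worst-case minibatch loss over a KL-divergence ambiguity set.

    Uses the dual form:
        sup_{Q: KL(Q||P) <= rho} E_Q[loss]
          = min_{lam > 0} lam*rho + lam*log E_P[exp(loss/lam)],
    where P is the empirical minibatch distribution. `rho` is a
    hypothetical ambiguity-set radius hyperparameter.
    """
    losses = np.asarray(losses, dtype=float)

    def dual(lam):
        # log-mean-exp computed stably
        z = losses / lam
        m = z.max()
        lme = m + np.log(np.mean(np.exp(z - m)))
        return lam * rho + lam * lme

    # simple 1-D grid search over the dual variable (illustrative only)
    lams = np.geomspace(1e-3, 1e3, 2000)
    vals = np.array([dual(lam) for lam in lams])
    i = int(vals.argmin())

    # worst-case distribution: exponential tilting of the minibatch,
    # so higher-loss tokens receive more weight
    w = np.exp(losses / lams[i])
    w /= w.sum()
    return vals[i], w
```

Because the empirical distribution itself lies inside the ambiguity set, the robust loss is always at least the plain minibatch mean, and the returned weights upweight the hardest tokens, which is the mechanism DRTO uses to guard against shifted prompts.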
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Mathematical Reasoning | MathQA | Accuracy | 45.2 | 305 |
| Math Reasoning | GSM8K | Accuracy | 79.9 | 187 |
| Mathematical Reasoning | GSM-PLUS | Accuracy | 57.2 | 66 |
| Math Reasoning | GSM CoT | Accuracy (GSM CoT) | 83.2 | 7 |
| Math Reasoning | GSM DE | Accuracy | 66 | 7 |
| Mathematical Reasoning | GSM8K ZH (test) | Accuracy (ZH) | 58 | 7 |
| Mathematical Reasoning | GSM8K DE (test) | Accuracy | 66 | 7 |
| Mathematical Reasoning | GSM8K ES (test) | Accuracy | 72 | 7 |
| Mathematical Reasoning | GSM8K FR (test) | Accuracy | 64 | 7 |