MemReward: Graph-Based Experience Memory for LLM Reward Prediction with Limited Labels
About
Recent advances in large language models (LLMs) have been driven by reinforcement-learning-based post-training, which requires multiple rollouts with rewards. However, obtaining ground-truth labels for reward computation at scale often requires expensive human labeling or time-consuming verification procedures. For instance, evaluating mathematical proofs demands expert review, and open-ended question answering lacks definitive ground truth. When ground-truth labels are scarce, the effectiveness of reinforcement learning fine-tuning can be constrained. We introduce MemReward, a graph-based experience memory framework: an initial LLM policy generates rollouts for each query, each comprising a thinking process and a final answer, and these rollouts are stored as experience memory. Queries, thinking processes, and answers form nodes in a heterogeneous graph with similarity and structural edges; a GNN trained on labeled rollouts propagates rewards to unlabeled rollouts during online optimization. Experiments on Qwen2.5-3B and 1.5B in mathematics, question answering, and code generation demonstrate that MemReward, with only 20% of labels, achieves 97.3% of Oracle performance on 3B and 96.6% on 1.5B, and surpasses the Oracle on out-of-domain tasks. Performance scales smoothly with label budget, reaching 99.4% of Oracle at 70% labels.
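The core idea of propagating rewards from labeled to unlabeled rollouts over a similarity graph can be illustrated with a minimal sketch. This is not the authors' implementation: it replaces the trained GNN with simple iterative label propagation over similarity edges, and uses a toy bag-of-words cosine similarity between thinking processes in place of learned embeddings. All function names below are hypothetical.

```python
# Hedged sketch of MemReward-style reward propagation (not the paper's code).
# Nodes: rollouts; edge weights: toy cosine similarity over bag-of-words of
# each rollout's thinking process. Labeled rollouts seed rewards, which are
# spread to unlabeled rollouts by iterative neighbor averaging -- a simple
# label-propagation stand-in for the paper's trained GNN.
from collections import Counter
import math

def cosine(a: str, b: str) -> float:
    ca, cb = Counter(a.split()), Counter(b.split())
    dot = sum(ca[w] * cb[w] for w in ca)
    na = math.sqrt(sum(v * v for v in ca.values()))
    nb = math.sqrt(sum(v * v for v in cb.values()))
    return dot / (na * nb) if na and nb else 0.0

def propagate_rewards(rollouts, labeled, iters=20):
    """rollouts: list of thinking-process strings; labeled: {index: reward}."""
    n = len(rollouts)
    w = [[cosine(rollouts[i], rollouts[j]) if i != j else 0.0
          for j in range(n)] for i in range(n)]
    r = [labeled.get(i, 0.5) for i in range(n)]  # unlabeled start at 0.5
    for _ in range(iters):
        for i in range(n):
            if i in labeled:
                continue  # labeled rewards stay clamped
            z = sum(w[i])
            if z > 0:  # weighted average of neighbor rewards
                r[i] = sum(w[i][j] * r[j] for j in range(n)) / z
    return r

rollouts = [
    "factor the quadratic then solve for x",   # labeled correct
    "factor quadratic solve for x carefully",  # unlabeled
    "guess random numbers until one works",    # labeled incorrect
]
rewards = propagate_rewards(rollouts, labeled={0: 1.0, 2: 0.0})
print(round(rewards[1], 2))  # → 1.0 (pulled toward its only similar neighbor)
```

The unlabeled rollout inherits a high reward because its only non-zero similarity edge connects it to the correct labeled rollout; the paper's heterogeneous graph additionally links rollouts through shared query and answer nodes, which this sketch omits.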
Related benchmarks
| Task | Dataset | Metric | Score | Rank |
|---|---|---|---|---|
| Question Answering | ARC Challenge | Accuracy | 80.44 | 906 |
| Question Answering | GPQA | Accuracy | 30 | 33 |
| Question Answering | OBQA | Accuracy | 81.78 | 14 |
| Question Answering | MMLU | Accuracy | 0.72 | 8 |
| Mathematical Reasoning | MATH | Exact Match Accuracy | 61.11 | 6 |
| Reward Prediction | NuminaMath (out-of-domain) | Accuracy | 42.22 | 6 |
| Reward Prediction | SIQA (out-of-domain) | Accuracy | 76.89 | 6 |
| Reward Prediction | Out-of-Domain Task Suite (NuminaMath, SIQA, PIQA) | Average Score | 66.96 | 6 |
| Mathematical Reasoning | GSM8K | Exact Match Accuracy | 92.89 | 6 |
| Mathematical Reasoning | GSM-sym | Exact Match | 86.44 | 6 |