
ReST-MCTS*: LLM Self-Training via Process Reward Guided Tree Search

About

Recent LLM self-training methods mostly rely on the LLM generating responses and keeping those with correct final answers as training data. This approach often yields a low-quality fine-tuning set, because a trace with a correct answer can still contain incorrect plans or intermediate reasoning. In this paper, we develop a reinforced self-training approach, called ReST-MCTS*, that integrates process reward guidance with the tree search MCTS* to collect higher-quality reasoning traces, together with per-step values, for training policy and reward models. ReST-MCTS* circumvents the per-step manual annotation typically used to train process reward models by using tree-search-based reinforcement learning: given oracle final correct answers, ReST-MCTS* infers the process reward of each step by estimating the probability that the step helps lead to the correct answer. These inferred rewards serve two purposes: they act as value targets for further refining the process reward model, and they guide the selection of high-quality traces for policy model self-training. We first show that the tree-search policy in ReST-MCTS* achieves higher accuracy than prior LLM reasoning baselines such as Best-of-N and Tree-of-Thought within the same search budget. We then show that by using traces found by this tree-search policy as training data, we can continuously enhance the three language models over multiple iterations and outperform other self-training algorithms such as ReST$^\text{EM}$ and Self-Rewarding LM. We release all code at https://github.com/THUDM/ReST-MCTS.
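The idea of inferring per-step rewards from oracle final answers can be illustrated with a small sketch. It is not the estimator used in ReST-MCTS*: it assumes, for illustration only, that the value of a step is the fraction of completed traces passing through it that reach the oracle answer. The `Node` class and the functions `infer_values` and `collect_correct_traces` are hypothetical names introduced here, and the tiny tree is made-up toy data.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """One reasoning step in a hypothetical search tree (illustrative only)."""
    step: str
    children: list[Node] = field(default_factory=list)
    final_answer: Optional[str] = None  # set only on leaves (completed traces)
    value: float = 0.0                  # filled in by infer_values


def infer_values(node: Node, oracle_answer: str) -> tuple[int, int]:
    """Set node.value to the fraction of leaves below it that match the oracle.

    Returns (correct_leaves, total_leaves) for the subtree rooted at node.
    Assumption: this simple Monte Carlo ratio stands in for the paper's
    actual per-step value estimate.
    """
    if not node.children:
        correct, total = int(node.final_answer == oracle_answer), 1
    else:
        correct = total = 0
        for child in node.children:
            c, t = infer_values(child, oracle_answer)
            correct += c
            total += t
    node.value = correct / total
    return correct, total


def collect_correct_traces(node: Node, oracle_answer: str, prefix=()):
    """Return every root-to-leaf trace ending in the oracle answer.

    Each trace is a list of (step, value) pairs: the steps could serve as
    policy training data and the values as reward-model targets.
    """
    path = prefix + ((node.step, node.value),)
    if not node.children:
        return [list(path)] if node.final_answer == oracle_answer else []
    traces = []
    for child in node.children:
        traces.extend(collect_correct_traces(child, oracle_answer, path))
    return traces


# Toy tree: two first steps, three completed traces, oracle answer "4".
a = Node("step A", [Node("step A1", final_answer="4"),
                    Node("step A2", final_answer="5")])
b = Node("step B", [Node("step B1", final_answer="5")])
root = Node("problem", [a, b])

infer_values(root, oracle_answer="4")
traces = collect_correct_traces(root, oracle_answer="4")
```

In this toy tree, "step A" receives value 0.5 because one of its two completions is correct, while "step B" receives 0.0. Only the trace through "step A1" is kept for policy training, yet every node still carries a value that could serve as a reward-model target.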

Dan Zhang, Sining Zhoubian, Ziniu Hu, Yisong Yue, Yuxiao Dong, Jie Tang • 2024

Related benchmarks

Task | Dataset | Result | Rank
Mathematical Reasoning | AIME 25 | Accuracy 91.3 | 201
Coding | HumanEval | Pass@1 70.4 | 168
Multimodal Reasoning | MMMU-Pro | Accuracy 81.3 | 146
Reasoning | ARC Challenge | Accuracy 84.2 | 100
Math Reasoning | AMC | Accuracy 67.4 | 95
Math Reasoning | MATH500 | Accuracy 93.2 | 83
Math Reasoning | JEEBench | Accuracy 70.3 | 82
Coding | MBPP | Pass@1 Accuracy 76.8 | 78
Mathematical Reasoning | MATH500 | Accuracy 93.2 | 76
Math Reasoning | OlympiadBench | Accuracy 84.7 | 76
Showing 10 of 30 rows
