JiuZhang3.0: Efficiently Improving Mathematical Reasoning by Training Small Data Synthesis Models
About
Mathematical reasoning is an important capability of large language models (LLMs) for real-world applications. To enhance this capability, existing work either collects large-scale math-related texts for pre-training or relies on stronger LLMs (e.g., GPT-4) to synthesize massive numbers of math problems. Both approaches generally incur large costs in training or synthesis. To reduce the cost, we propose an efficient approach, based on openly available texts, that trains a small LLM for math problem synthesis so that it can efficiently generate sufficient high-quality pre-training data. To achieve this, we create a dataset using GPT-4 to distill its data synthesis capability into the small LLM. Concretely, we craft a set of prompts based on human education stages to guide GPT-4 to synthesize problems covering diverse math knowledge and difficulty levels. In addition, we adopt a gradient-based influence estimation method to select the most valuable math-related texts. Both the prompts and the selected texts are fed into GPT-4 to create the knowledge distillation dataset used to train the small LLM. We then use the small LLM to synthesize 6 million math problems for pre-training our JiuZhang3.0 model, which requires only 9.3k GPT-4 API calls and pre-training on 4.6B data. Experimental results show that JiuZhang3.0 achieves state-of-the-art performance on several mathematical reasoning datasets, under both natural language reasoning and tool manipulation settings. Our code and data will be publicly released at https://github.com/RUCAIBox/JiuZhang3.0.
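The abstract mentions gradient-based influence estimation for selecting valuable math-related texts but gives no details. As a rough illustration only, one common form of this idea scores each candidate text by how well its gradient aligns with the average gradient of a small set of target examples, then keeps the top-scoring texts. The sketch below assumes the gradient feature vectors are already computed; the function names, the cosine-similarity scoring, and the toy vectors are all hypothetical and are not taken from the JiuZhang3.0 code.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def select_top_k(candidate_grads, target_grads, k):
    """Rank candidates by gradient alignment with the mean target gradient.

    candidate_grads: dict mapping a text id to its (hypothetical) gradient features.
    target_grads: list of gradient feature vectors for target examples.
    Returns the k best text ids and the full score dict.
    """
    dim = len(target_grads[0])
    mean_target = [
        sum(g[i] for g in target_grads) / len(target_grads) for i in range(dim)
    ]
    scores = {name: cosine(g, mean_target) for name, g in candidate_grads.items()}
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked[:k], scores


# Toy, made-up gradient features purely for illustration.
candidate_grads = {
    "text_a": [0.9, 0.1, 0.0],
    "text_b": [-0.5, 0.2, 0.1],
    "text_c": [0.4, 0.5, 0.1],
}
target_grads = [[1.0, 0.0, 0.0], [0.8, 0.2, 0.0]]

selected, scores = select_top_k(candidate_grads, target_grads, k=2)
print(selected)
```

In this toy setup, a text whose gradient points away from the target direction (a negative score) is treated as unhelpful and dropped. A real pipeline would derive the vectors from model gradients, typically after a dimensionality-reducing projection.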
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Mathematical Reasoning | GSM8K | Accuracy | 89.8 | 983 |
| Mathematical Reasoning | MATH | Accuracy | 53.8 | 643 |
| Mathematical Reasoning | SVAMP | Accuracy | 90.4 | 368 |
| Mathematical Reasoning | ASDIV | Accuracy | 0.931 | 221 |
| Mathematical Reasoning | MAWPS | Accuracy | 97.3 | 219 |
| Mathematical Reasoning | TabMWP | Accuracy | 84.7 | 157 |
| Mathematical Reasoning | AQUA | Accuracy | 65.4 | 132 |
| Mathematical Reasoning | SAT Math | SAT Math Accuracy | 84.4 | 44 |
| Natural Language Reasoning | GSM8K, MATH, SVAMP, ASDiv, MAWPS, CARP | Average Score | 79.3 | 29 |
| Mathematical Reasoning | CARP | Accuracy | 52.3 | 29 |