AutoJudge: Judge Decoding Without Manual Annotation
About
We introduce AutoJudge, a method that accelerates large language model (LLM) inference with task-specific lossy speculative decoding. Instead of matching the original model output distribution token-by-token, we identify which of the generated tokens affect the downstream quality of the response, relaxing the distribution match guarantee so that the "unimportant" tokens can be generated faster. Our approach relies on a semi-greedy search algorithm to test which of the mismatches between target and draft models should be corrected to preserve quality and which ones may be skipped. We then train a lightweight classifier based on existing LLM embeddings to predict, at inference time, which mismatching tokens can be safely accepted without compromising the final answer quality. We evaluate the effectiveness of AutoJudge with multiple draft/target model pairs on mathematical reasoning and programming benchmarks, achieving significant speedups at the cost of a minor accuracy reduction. Notably, on GSM8K with the Llama 3.1 70B target model, our approach achieves up to an $\approx2\times$ speedup over speculative decoding at the cost of a $\le 1\%$ drop in accuracy. When applied to the LiveCodeBench benchmark, AutoJudge automatically detects programming-specific important tokens, accepting $\ge 25$ tokens per speculation cycle at the cost of a $2\%$ drop in Pass@1. Our approach requires no human annotation and is easy to integrate with modern LLM inference frameworks.
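To make the inference-time acceptance rule concrete, here is a minimal sketch of one verification step of greedy speculative decoding in which a judge decides whether a draft/target mismatch can be accepted. It assumes the target model has already scored all draft positions in one parallel forward pass, producing its greedy token and a feature vector (embedding) for each position. The judge shown here is a hypothetical logistic-regression classifier; `make_linear_judge`, `verify_with_judge`, the weights, and the threshold are illustrative names and values, not the authors' implementation.

```python
import math
from typing import Callable, List, Sequence, Tuple


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def make_linear_judge(
    weights: Sequence[float], bias: float, threshold: float
) -> Callable[[Sequence[float]], bool]:
    """Hypothetical lightweight judge: logistic regression over an embedding.

    Returns True when a mismatch is predicted to be unimportant (safe to accept).
    """
    def judge(embedding: Sequence[float]) -> bool:
        score = sum(w * x for w, x in zip(weights, embedding)) + bias
        return sigmoid(score) >= threshold
    return judge


def verify_with_judge(
    draft_tokens: Sequence[int],
    target_tokens: Sequence[int],
    embeddings: Sequence[Sequence[float]],
    judge: Callable[[Sequence[float]], bool],
) -> Tuple[List[int], int]:
    """One lossy verification step (sketch).

    draft_tokens[i]:  token proposed by the draft model at position i.
    target_tokens[i]: target model's greedy token at position i, conditioned on
                      the prefix plus draft_tokens[:i]; one extra "bonus" entry
                      may follow for the position after the last draft token.
    embeddings[i]:    assumed features for position i from the same pass.
    Returns (tokens to append, number of draft tokens accepted).
    """
    output: List[int] = []
    for i, (d, t) in enumerate(zip(draft_tokens, target_tokens)):
        if d == t:
            output.append(d)  # exact match: accepted, as in standard decoding
        elif judge(embeddings[i]):
            output.append(d)  # mismatch judged unimportant: keep the draft token
        else:
            output.append(t)  # important mismatch: use the target token and stop
            return output, i
    # Every draft token was accepted; append the target's bonus token if present.
    if len(target_tokens) > len(draft_tokens):
        output.append(target_tokens[len(draft_tokens)])
    return output, len(draft_tokens)
```

Usage: with `judge = make_linear_judge([1.0, -1.0], 0.0, 0.5)`, draft tokens `[5, 7, 9, 11]`, target tokens `[5, 8, 9, 12, 13]`, and embeddings `[[0, 0], [2, 0], [0, 0], [0, 2]]`, the mismatch at position 1 is accepted and the one at position 3 is corrected. The step then returns `([5, 7, 9, 12], 3)`, whereas a judge that always rejects reduces to standard greedy speculative decoding and returns `([5, 8], 1)`. Accepting a mismatched draft token stays consistent with later positions because each target prediction was already conditioned on the draft prefix.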
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Mathematical Reasoning | GSM8K | Accuracy | 80.1 | 337 |
| Mathematical Reasoning | GSM8K v1 (test) | Accuracy | 92.77 | 118 |
| Mathematical Reasoning | GSM8K 8-shot | Accuracy | 95.1 | 89 |
| Mathematical Reasoning | MATH 500 | Accuracy | 41 | 79 |
| General Knowledge | MMLU | -- | -- | 25 |
| Code Generation | LiveCodeBench (train) | Metric m Score | 19.63 | 6 |
| Summarization | CNN/DM | ROUGE (m) | 4.74 | 6 |
| Summarization | CNN/DailyMail | Metric m | 4.74 | 6 |
| Multi-task Performance Change Analysis | GSM8K, MATH-500, MMLU, CNN/DM | Mean Delta | 1.7 | 5 |
| Data Generation | GSM8K, LiveCodeBench, and Dolly15k (53K labels) | Generation Time (hours) | 120 | 2 |