Efficient Joint Prediction of Multiple Future Tokens
About
In this short report, we introduce joint multi-token prediction (JTP), a lightweight modification of standard next-token prediction designed to enrich hidden state representations by jointly predicting multiple future tokens. Unlike previous multi-token prediction approaches, JTP strategically employs teacher forcing of future tokens through a carefully designed representation bottleneck, allowing the model to encode rich predictive information with minimal computational overhead during training. We show that the JTP approach achieves a short-horizon belief state representation, while popular alternatives for multi-token prediction fail to do so. We demonstrate the effectiveness of our method on the synthetic star graph navigation task from Bachmann and Nagarajan [2024], highlighting a significant performance improvement over existing methods. This manuscript presents promising preliminary results intended to stimulate further research.
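To make the idea of jointly predicting future tokens with teacher forcing more concrete, here is a minimal plain-Python sketch of a chain-rule joint loss, -log p(x_{t+1}, ..., x_{t+k} | h_t). The only grounded part is the general idea: each future token is predicted from the hidden state h_t while the ground-truth earlier future tokens are fed in (teacher forcing). The additive bottleneck, the single linear head, and the names `jtp_loss`, `embed`, and `W_head` are hypothetical illustrations, not the authors' implementation.

```python
import math
import random

def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(W, v):
    # Multiply a (vocab x dim) matrix, stored as a list of rows, by a dim-vector.
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def jtp_loss(h_t, future_tokens, embed, W_head):
    """Joint loss -log p(x_{t+1}, ..., x_{t+k} | h_t) via the chain rule.

    HYPOTHETICAL bottleneck: the head sees only the hidden state h_t plus the
    summed embeddings of the ground-truth (teacher-forced) future tokens seen
    so far, so everything the head knows about position t flows through h_t.
    """
    loss = 0.0
    context = list(h_t)
    for tok in future_tokens:
        probs = softmax(matvec(W_head, context))
        loss -= math.log(probs[tok])
        # Teacher forcing: condition the next prediction on the true token,
        # not on a token sampled from the model.
        context = [c + e for c, e in zip(context, embed[tok])]
    return loss

if __name__ == "__main__":
    random.seed(0)
    vocab, dim = 5, 8
    h_t = [random.gauss(0, 1) for _ in range(dim)]
    embed = [[random.gauss(0, 0.1) for _ in range(dim)] for _ in range(vocab)]
    W_head = [[random.gauss(0, 0.1) for _ in range(dim)] for _ in range(vocab)]
    print(jtp_loss(h_t, [2, 0, 4], embed, W_head))
```

With an all-zero head, every prediction is uniform, so the joint loss reduces to k·log V for k future tokens over a vocabulary of size V. This is a quick sanity check that the chain-rule sum is assembled correctly.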
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Language Modeling | LAMBADA | Accuracy | 41.37 | 412 |
| Multiple-choice Question Answering | ARC Easy | Accuracy | 68.86 | 257 |
| Multiple-choice Question Answering | HellaSwag | Accuracy | 57.43 | 196 |
| Social Interaction Question Answering | SIQA | Accuracy | 43.35 | 157 |
| Language Modeling | FineWeb-Edu | Perplexity (PPL) | 11.08 | 141 |
| Multiple-choice Question Answering | ARC Challenge | Accuracy | 39.25 | 133 |
| Multiple-choice Question Answering | SciQ | Accuracy | 87.3 | 91 |
| Language Modeling | WikiText | Perplexity (PPL) | 19.28 | 87 |
| Multiple-choice Question Answering | PIQA | Accuracy | 74.92 | 63 |
| Multiple-choice Question Answering | WinoGrande | Accuracy | 59.98 | 48 |