Efficient Joint Prediction of Multiple Future Tokens

About

In this short report, we introduce joint multi-token prediction (JTP), a lightweight modification of standard next-token prediction designed to enrich hidden-state representations by jointly predicting multiple future tokens. Unlike previous multi-token prediction approaches, JTP strategically employs teacher forcing of future tokens through a carefully designed representation bottleneck, allowing the model to encode rich predictive information with minimal computational overhead during training. We show that JTP achieves a short-horizon belief-state representation, while popular alternatives for multi-token prediction fail to do so. We demonstrate the effectiveness of our method on the synthetic star-graph navigation task from Bachmann and Nagarajan [2024], highlighting a significant performance improvement over existing methods. This manuscript presents promising preliminary results intended to stimulate further research.
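The abstract does not spell out JTP's architecture, so the following is only a rough illustration of the general idea, not the authors' implementation. It assumes one reading of "jointly predicting future tokens with teacher forcing through a bottleneck": the prefix reaches the prediction head only through the hidden state h_t, and the head factorizes p(x_{t+1}, ..., x_{t+k} | h_t) as a product of p(x_{t+j} | h_t, x_{t+1}, ..., x_{t+j-1}), feeding in the ground-truth future tokens. The small tanh recurrence, the weights `W_s`, `W_out`, `W_outs`, and the function names `joint_nll` and `independent_nll` are all hypothetical. For contrast, `independent_nll` mimics the simpler style of multi-token prediction with one separate head per offset.

```python
import numpy as np


def log_softmax(z):
    # Numerically stable log-softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def next_token_nll(h, next_tok, W_out):
    # Standard next-token prediction: -log p(x_{t+1} | h_t).
    return -log_softmax(h @ W_out)[next_tok]


def joint_nll(h, future, emb, W_s, W_out):
    """Hypothetical joint multi-token loss at a single position t.

    h      : (d,) hidden state h_t, the only channel carrying prefix
             information (a stand-in for a representation bottleneck).
    future : ground-truth token ids x_{t+1}, ..., x_{t+k}.
    emb    : (V, d) embeddings used to teacher-force the future tokens.
    W_s    : (d, d) recurrence weights of a small hypothetical head.
    W_out  : (d, V) output projection, shared with next-token prediction.

    Returns -sum_j log p(x_{t+j} | h_t, x_{t+1}, ..., x_{t+j-1}).
    """
    s = h
    nll = 0.0
    for tok in future:
        nll -= log_softmax(s @ W_out)[tok]
        # Teacher forcing: advance the head with the TRUE token, not a sample.
        s = np.tanh(s @ W_s + emb[tok])
    return nll


def independent_nll(h, future, W_outs):
    # Contrast: one output head per offset j, each seeing only h_t, so the
    # heads fit marginals p(x_{t+j} | h_t) rather than a joint distribution.
    return -sum(log_softmax(h @ W)[tok] for tok, W in zip(future, W_outs))


# Toy usage with random, untrained weights (illustrates shapes only).
rng = np.random.default_rng(0)
d, V = 8, 16
h = rng.normal(size=d)
emb = rng.normal(size=(V, d))
W_s = rng.normal(size=(d, d)) / np.sqrt(d)
W_out = rng.normal(size=(d, V))
W_outs = [W_out] + [rng.normal(size=(d, V)) for _ in range(2)]
future = [3, 7, 1]

joint = joint_nll(h, future, emb, W_s, W_out)
indep = independent_nll(h, future, W_outs)
print(f"joint NLL: {joint:.3f}, independent-heads NLL: {indep:.3f}")
```

With a single future token both losses reduce to ordinary next-token prediction. The difference appears for later offsets: in `joint_nll` each term also conditions on the earlier ground-truth future tokens, which is what makes the product a joint distribution, whereas the independent heads have no such path.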

Kwangjun Ahn, Alex Lamb, John Langford • 2025

Related benchmarks

| Task | Dataset | Metric | Value | Rank |
|---|---|---|---|---|
| Language Modeling | LAMBADA | Accuracy | 41.37 | 412 |
| Multiple-choice Question Answering | ARC Easy | Accuracy | 68.86 | 257 |
| Multiple-choice Question Answering | HellaSwag | Accuracy | 57.43 | 196 |
| Social Interaction Question Answering | SIQA | Accuracy | 43.35 | 157 |
| Language Modeling | FineWeb-Edu | Perplexity (PPL) | 11.08 | 141 |
| Multiple-choice Question Answering | ARC Challenge | Accuracy | 39.25 | 133 |
| Multiple-choice Question Answering | SciQ | Accuracy | 87.3 | 91 |
| Language Modeling | WikiText | Perplexity (PPL) | 19.28 | 87 |
| Multiple-choice Question Answering | PIQA | Accuracy | 74.92 | 63 |
| Multiple-choice Question Answering | WinoG | Accuracy | 59.98 | 48 |

Showing 10 of 19 rows.