
Next-Latent Prediction Transformers Learn Compact World Models

About

Transformers replace recurrence with a memory that grows with sequence length and self-attention that enables ad-hoc lookups over past tokens. Consequently, they lack an inherent incentive to compress history into compact latent states with consistent transition rules. This often leads to learning solutions that generalize poorly. We introduce Next-Latent Prediction (NextLat), which extends standard next-token training with self-supervised predictions in the latent space. Specifically, NextLat trains a transformer to learn latent representations that are predictive of its next latent state given the next token. Theoretically, we show that these latents provably converge towards belief states, compressed information about the history necessary to predict the future. This simple auxiliary objective injects a recurrent inductive bias into transformers while leaving their architecture, parallel training efficiency, and inference unchanged. NextLat effectively encourages transformers to form compact internal world models with coherent belief states and transition dynamics -- crucial properties not guaranteed by standard next-token prediction alone. Empirically, across benchmarks in world modeling, reasoning, planning, and language modeling, NextLat demonstrates significant gains over standard next-token prediction and other baselines in downstream accuracy, representation compression, and lookahead planning. Furthermore, NextLat enables variable-length self-speculative decoding, accelerating inference by up to 3.3x in language modeling. NextLat offers a simple yet effective paradigm for learning compact, predictive representations in transformers that generalize better. Our code is available at https://github.com/microsoft/NextLat.
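The auxiliary objective described above can be illustrated with a small sketch. It is not the authors' implementation: the abstract does not specify the distance measure, the form of the latent predictor, or how the two losses are weighted, so the squared-error distance, the `predict` callable, and the `weight` parameter below are all assumptions made for illustration. The sketch only computes loss values on plain Python lists; a real training setup would use a tensor library and backpropagate through the model.

```python
from typing import Callable, Sequence

Vector = Sequence[float]


def next_latent_loss(
    latents: Sequence[Vector],
    tokens: Sequence[int],
    predict: Callable[[Vector, int], Vector],
) -> float:
    """Mean squared error between predicted and actual next latents.

    latents[t] is the latent state after reading tokens[0..t].
    predict(h_t, x_next) is a hypothetical latent-dynamics model that
    guesses latents[t + 1] from latents[t] and tokens[t + 1].
    """
    if len(latents) != len(tokens):
        raise ValueError("need exactly one latent per token")
    total, count = 0.0, 0
    for t in range(len(latents) - 1):
        predicted = predict(latents[t], tokens[t + 1])
        target = latents[t + 1]
        total += sum((p - q) ** 2 for p, q in zip(predicted, target))
        count += len(target)
    return total / count if count else 0.0


def combined_loss(next_token_loss: float, latent_loss: float, weight: float = 1.0) -> float:
    """Next-token loss plus a weighted latent-prediction term (weight is assumed)."""
    return next_token_loss + weight * latent_loss


# Toy sequence whose latent is a running sum of the tokens, so a
# consistent transition rule exists: h_next = h + x_next.
tokens = [1, 2, 3]
latents = [[1.0], [3.0], [6.0]]

consistent = next_latent_loss(latents, tokens, lambda h, x: [h[0] + x])
ignores_token = next_latent_loss(latents, tokens, lambda h, x: [h[0]])
print(consistent, ignores_token)
```

In this toy case the predictor that applies the transition rule reaches zero latent loss, while the predictor that ignores the next token does not. That is the sense in which the auxiliary term rewards latents with consistent transition dynamics.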

Jayden Teoh, Manan Tomar, Kwangjun Ahn, Edward S. Hu, Tim Pearce, Pratyusha Sharma, Akshay Krishnamurthy, Riashat Islam, Alex Lamb, John Langford • 2025

Related benchmarks

Task                                    Dataset          Result               Rank
Language Modeling                       LAMBADA          Accuracy 43.86       412
Multiple-choice Question Answering      ARC Easy         Accuracy 69.74       257
Multiple-choice Question Answering      HellaSwag        Accuracy 58.35       196
Social Interaction Question Answering   SIQA             Accuracy 43.24       157
Language Modeling                       FineWeb-Edu      PPL 10.83            141
Multiple-choice Question Answering      ARC Challenge    Acc 40.1             133
Multiple-choice Question Answering      SciQ             Accuracy 87.5        91
Language Modeling                       WikiText         Wikitext PPL 18.39   87
Multiple-choice Question Answering      PIQA             Accuracy 73.61       63
Multiple-choice Question Answering      WinoG            Accuracy 59.27       48
Showing 10 of 19 rows
