
AdaPonderLM: Gated Pondering Language Models with Token-Wise Adaptive Depth

About

Test-time scaling via recurrent/iterative Transformers enables large language models to spend more computation at inference, but most pretrained recurrent LMs run a fixed number of iterations, wasting compute on easy tokens and lacking token-wise adaptivity. Following the core ideas of Adaptive Computation Time (ACT) and Early Exit (EE), we propose AdaPonderLM, a self-supervised recurrent language model that learns token-wise early exiting during pretraining, without manually tuned per-token or per-layer pruning ratios. AdaPonderLM uses iteration-specific MLP gates with a monotonic halting mask to decide when each token stops recurring, and introduces a KV reuse mechanism that reuses cached key/value states for halted tokens, ensuring train–test consistency and practical acceleration. Across Pythia backbones from 70M to 410M (pretraining) and up to 2.8B (continued pretraining), AdaPonderLM reduces inference compute by about 10% while maintaining comparable language modeling perplexity and competitive downstream accuracy. Our analysis shows that the learned gates allocate more computation to high-NLL (hard) tokens, exhibiting adaptive-computation-time behavior in a fully self-supervised setting. Moreover, under iso-FLOPs, the learned halting policy consistently outperforms fixed pruning, showing that AdaPonderLM allocates compute to the right tokens rather than merely reducing average depth.
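The mechanism described above — per-iteration gates that halt tokens under a monotonic mask, with halted tokens reusing cached state — can be illustrated with a minimal NumPy sketch. All names, the toy update rule, and the gate parameterization here are illustrative assumptions, not the paper's actual implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def ponder_forward(x, num_iters=4, threshold=0.5):
    """Toy token-wise adaptive depth with a monotonic halting mask.

    x: (seq_len, d) token states. The update and gates are random stand-ins
    for the recurrent block and the paper's iteration-specific MLP gates.
    """
    seq_len, d = x.shape
    halted = np.zeros(seq_len, dtype=bool)  # monotonic: once True, stays True
    depth = np.full(seq_len, num_iters)     # iteration at which each token stopped
    cached = x.copy()                       # stands in for reused KV/hidden state

    # one (toy) gate per iteration -- "iteration-specific" gates
    gates = [rng.normal(size=(d,)) for _ in range(num_iters)]

    for t in range(num_iters):
        active = ~halted
        if not active.any():
            break
        # recurrent update only for still-active tokens; halted tokens keep
        # their cached state (schematic analogue of the KV reuse mechanism)
        x[active] = np.tanh(x[active] @ rng.normal(size=(d, d)) * 0.1 + x[active])
        cached[active] = x[active]
        # gate decides which active tokens halt after this iteration
        halt_prob = 1.0 / (1.0 + np.exp(-(x @ gates[t])))
        newly_halted = active & (halt_prob > threshold)
        depth[newly_halted] = t + 1
        halted |= newly_halted  # monotonic mask: a token never un-halts

    return cached, depth

states, depth = ponder_forward(rng.normal(size=(8, 16)))
print(depth)  # per-token number of iterations actually executed
```

Easy tokens whose gate fires early exit after one or two iterations, while harder tokens run the full depth; averaging `depth` over a corpus gives the effective compute saving the abstract reports.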

Shixiang Song, He Li, Zitong Wang, Boyi Zeng, Feichen Song, Yixuan Wang, Zhiqin John Xu, Ziwei He, Zhouhan Lin • 2026

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Commonsense Reasoning | WinoGrande | Accuracy | 63.6 | 1085 |
| Question Answering | ARC-E | Accuracy | 70.9 | 416 |
| Question Answering | PIQA | Accuracy | 75.9 | 374 |
| Question Answering | SciQ | -- | -- | 283 |
| Sentence Completion | HellaSwag | Accuracy | 48.9 | 276 |
| Language Modeling | Lambada OpenAI | Accuracy | 68.3 | 127 |
| Reading Comprehension | RACE | Accuracy | 38.5 | 70 |
| Question Answering | ARC-C | Accuracy | 35.2 | 46 |
| Language Modeling | Lambada Standard | Accuracy | 59.8 | 36 |
| Mean Performance Evaluation | Downstream Tasks Summary | Average Accuracy | 61.1 | 36 |
