Latent Reasoning with Supervised Thinking States
About
Reasoning with a chain-of-thought (CoT) enables Large Language Models (LLMs) to solve complex tasks, but incurs significant inference costs due to the generation of long rationales. We propose Thinking States, a method that performs reasoning *while* the input is being processed. Specifically, Thinking States generates sequences of thinking tokens every few input tokens, transforms the thoughts back into embedding space, and adds them to the following input tokens. This has two key advantages. First, it captures the recurrent nature of CoT, but with the thought tokens generated as the input is being processed. Second, since the thoughts are represented as tokens, they can be learned from natural-language supervision using teacher forcing, which is parallelizable. Empirically, Thinking States outperforms other latent reasoning methods on multiple reasoning tasks, narrowing the gap to CoT on math problems and matching its performance on 2-Hop QA with improved latency. On state-tracking tasks, we show that Thinking States leads to stronger reasoning behavior than CoT, successfully extrapolating to longer sequences than seen during training.
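To make the mechanism concrete, below is a minimal PyTorch sketch of the loop described above: every `block_size` input tokens, a few thinking tokens are produced, re-embedded, pooled, and added to the embeddings of the next input block. All names (`ThinkingStatesSketch`, `think_head`, `block_size`, `num_think`) and design choices (mean-pooled carry, per-block backbone) are illustrative assumptions, not the paper's implementation.

```python
import torch
import torch.nn as nn

class ThinkingStatesSketch(nn.Module):
    """Illustrative sketch of the Thinking States idea (not the authors' code):
    every `block_size` input tokens, a short sequence of `num_think` thinking
    tokens is produced, mapped back to embedding space, pooled, and added to
    the embeddings of the next input block before the backbone processes it."""

    def __init__(self, vocab_size=1000, d_model=64, block_size=4, num_think=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
        # For brevity each block is processed independently; a real model
        # would attend over the full prefix.
        self.backbone = nn.TransformerEncoder(layer, num_layers=2)
        self.think_head = nn.Linear(d_model, vocab_size)  # predicts thinking tokens
        self.block_size, self.num_think = block_size, num_think

    def forward(self, input_ids, think_ids=None):
        """input_ids: (B, T). think_ids: (B, num_blocks, num_think) gold
        thinking tokens for teacher forcing at training time; decoded
        greedily when absent (inference)."""
        B, T = input_ids.shape
        x = self.embed(input_ids)                               # (B, T, D)
        carry = torch.zeros(B, x.size(-1), device=x.device)    # pooled thought state
        outs, think_logits = [], []
        for b, start in enumerate(range(0, T, self.block_size)):
            # Add the pooled thoughts to the next block of input embeddings.
            blk = x[:, start:start + self.block_size] + carry.unsqueeze(1)
            h = self.backbone(blk)
            outs.append(h)
            # Emit thinking-token logits from the block's last hidden state.
            # A real implementation would decode the thoughts autoregressively;
            # this collapses generation to one step for brevity.
            logits = self.think_head(h[:, -1:].expand(-1, self.num_think, -1))
            think_logits.append(logits)
            if think_ids is not None:   # training: teacher-force gold thoughts
                thoughts = think_ids[:, b]
            else:                        # inference: greedy decode
                thoughts = logits.argmax(-1)
            # Map thoughts back to embedding space and pool into the carry.
            carry = self.embed(thoughts).mean(dim=1)
        return torch.cat(outs, 1), torch.cat(think_logits, 1)

# Hypothetical usage: with gold thoughts supplied, every block's carry is known
# in advance, so training can be parallelized across blocks (teacher forcing);
# the loop form above shows the sequential inference path.
model = ThinkingStatesSketch()
ids = torch.randint(0, 1000, (2, 12))       # batch of 2, 12 input tokens
gold = torch.randint(0, 1000, (2, 3, 2))    # 3 blocks x 2 thinking tokens
hidden, logits = model(ids, think_ids=gold)
```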
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Mathematical Reasoning | GSM8K | Speed Up (×) | 2.66 | 177 |
| Question Answering | 2-Hop QA Full Knowledge | Accuracy | 54.91 | 5 |
| Question Answering | 2-Hop QA Parametric Knowledge | Accuracy | 43.05 | 5 |