SpecTr: Fast Speculative Decoding via Optimal Transport
About
Autoregressive sampling from large language models has led to state-of-the-art results in several natural language tasks. However, autoregressive sampling generates tokens one at a time, making it slow and even prohibitive in certain tasks. One way to speed up sampling is $\textit{speculative decoding}$: use a small model to sample a $\textit{draft}$ (a block or sequence of tokens), and then score all tokens in the draft with the large language model in parallel. A subset of the tokens in the draft is accepted (and the rest rejected) based on a statistical method that guarantees the final output follows the distribution of the large model. In this work, we provide a principled understanding of speculative decoding through the lens of optimal transport (OT) with $\textit{membership cost}$. This framework can be viewed as an extension of the well-known $\textit{maximal-coupling}$ problem. The new formulation enables us to generalize the speculative decoding method to allow for a set of $k$ candidates at the token level, which leads to an improved optimal membership cost. We show that the optimal draft selection algorithm (transport plan) can be computed via linear programming, whose best-known runtime is exponential in $k$. We then propose a valid draft selection algorithm whose acceptance probability is $(1-1/e)$-optimal multiplicatively. Moreover, it can be computed in time almost linear in the size of the domain of a single token. Using this new draft selection algorithm, we develop a new autoregressive sampling algorithm called $\textit{SpecTr}$, which provides a speedup in decoding while ensuring that there is no quality degradation in the decoded output. We experimentally demonstrate that for state-of-the-art large language models, the proposed approach achieves a wall clock speedup of 2.13X, a further 1.37X speedup over speculative decoding on standard benchmarks.
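To make the accept/reject step concrete, here is a minimal Python sketch of the standard single-draft ($k = 1$) token-level rule, which corresponds to the maximal-coupling case mentioned above. It is not SpecTr's $k$-candidate selection algorithm. The function names and the toy distributions `p` (large model) and `q` (small draft model) are assumptions made for illustration only. A draft token $x \sim q$ is accepted with probability $\min(1, p(x)/q(x))$; on rejection, a token is resampled from the normalized residual $\max(p - q, 0)$. The output then follows $p$ exactly, and the acceptance probability is $\sum_x \min(p(x), q(x))$, which equals one minus the total variation distance between $p$ and $q$.

```python
import random


def acceptance_probability(p, q):
    """Probability that a single draft token is accepted: sum_x min(p(x), q(x))."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))


def residual_distribution(p, q):
    """Normalized max(p - q, 0), used to resample after a rejection."""
    residual = [max(pi - qi, 0.0) for pi, qi in zip(p, q)]
    total = sum(residual)
    if total == 0.0:
        # p and q coincide, so a draft is never rejected; return p as a safe fallback.
        return list(p)
    return [r / total for r in residual]


def speculative_token(p, q, rng):
    """Draw a draft from q, then accept it or resample so the result follows p.

    Returns (token, accepted). `p` and `q` are probability lists over the same
    toy vocabulary {0, ..., len(p) - 1}.
    """
    vocab = list(range(len(p)))
    draft = rng.choices(vocab, weights=q)[0]
    if rng.random() < min(1.0, p[draft] / q[draft]):
        return draft, True
    corrected = rng.choices(vocab, weights=residual_distribution(p, q))[0]
    return corrected, False


# Toy example (assumed values, not taken from the paper).
p = [0.5, 0.3, 0.2]  # large-model distribution
q = [0.2, 0.3, 0.5]  # draft-model distribution
rng = random.Random(0)
token, accepted = speculative_token(p, q, rng)
print(token, accepted, round(acceptance_probability(p, q), 3))
```

Sampling many tokens with `speculative_token` and comparing the empirical frequencies against `p` is a quick way to check that the rule preserves the large model's distribution, while the fraction of accepted drafts should approach `acceptance_probability(p, q)`.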
Related benchmarks
| Task | Dataset | Result | Rank |
|---|---|---|---|
| Multilingual Mathematical Reasoning | MGSM (test) | -- | 109 |
| Instruction Following | Alpaca (test) | SR Score 1.52 | 21 |
| Language Modeling | LM1B (test) | Block Efficiency 8.93 | 15 |
| Python Programming | HumanEval (test) | BE Score 10.25 | 10 |
| Speculative Decoding | HumanEval (test) | BE 7.99 | 10 |
| Speculative Decoding | GSM8K (test) | BE Score 7.89 | 10 |
| Speculative Decoding | MGSM (test) | BE 7.36 | 10 |
| Speculative Decoding | LM1B (test) | BE 7.81 | 10 |
| Speculative Decoding | Alpaca (test) | BE Score 7.56 | 10 |
| grade-school math | GSM8K (test) | BE Score 6.55 | 10 |