SpecTr: Fast Speculative Decoding via Optimal Transport

About

Autoregressive sampling from large language models has led to state-of-the-art results in several natural language tasks. However, autoregressive sampling generates tokens one at a time, making it slow and even prohibitive in certain tasks. One way to speed up sampling is $\textit{speculative decoding}$: use a small model to sample a $\textit{draft}$ (a block or sequence of tokens), and then score all tokens in the draft with the large language model in parallel. A subset of the tokens in the draft is accepted (and the rest rejected) based on a statistical method that guarantees the final output follows the distribution of the large model. In this work, we provide a principled understanding of speculative decoding through the lens of optimal transport (OT) with $\textit{membership cost}$. This framework can be viewed as an extension of the well-known $\textit{maximal-coupling}$ problem. This new formulation enables us to generalize the speculative decoding method to allow for a set of $k$ candidates at the token level, which leads to an improved optimal membership cost. We show that the optimal draft selection algorithm (transport plan) can be computed via linear programming, whose best-known runtime is exponential in $k$. We then propose a valid draft selection algorithm whose acceptance probability is $(1-1/e)$-optimal multiplicatively. Moreover, it can be computed in time almost linear in the size of the domain of a single token. Using this new draft selection algorithm, we develop a new autoregressive sampling algorithm called $\textit{SpecTr}$, which provides a speedup in decoding while ensuring that there is no quality degradation in the decoded output. We experimentally demonstrate that for state-of-the-art large language models, the proposed approach achieves a wall-clock speedup of 2.13X, a further 1.37X speedup over speculative decoding on standard benchmarks.
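
To make the accept/reject step concrete, here is a minimal sketch of the standard single-draft, token-level rule from speculative decoding, which corresponds to the maximal-coupling view mentioned above (the $k = 1$ case). It is not SpecTr's $k$-candidate selection algorithm, which the abstract does not spell out. The function names and the toy distributions `p` (large model) and `q` (small model) are assumptions made for illustration.

```python
import random

def acceptance_probability(p, q):
    """Chance that a draft token sampled from q is accepted: sum_x min(p(x), q(x))."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))

def speculative_accept(p, q, draft_token, rng):
    """Accept or replace one draft token so that the result is distributed as p.

    p: next-token distribution of the large model (list of probabilities).
    q: next-token distribution of the small draft model.
    draft_token: index sampled from q.
    Returns (token, accepted).
    """
    # Accept the draft token with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, p[draft_token] / q[draft_token]):
        return draft_token, True
    # On rejection, resample from the residual max(p - q, 0); rng.choices
    # normalizes the weights. A rejection implies p != q, so the weights
    # are not all zero.
    residual = [max(pi - qi, 0.0) for pi, qi in zip(p, q)]
    token = rng.choices(range(len(p)), weights=residual, k=1)[0]
    return token, False

def sample_token(p, q, rng):
    """Draw a draft token from q, then apply the accept/reject rule."""
    draft = rng.choices(range(len(q)), weights=q, k=1)[0]
    return speculative_accept(p, q, draft, rng)

# Toy distributions over a three-token vocabulary (illustrative values only).
p = [0.5, 0.3, 0.2]
q = [0.2, 0.5, 0.3]
rng = random.Random(0)
samples = [sample_token(p, q, rng) for _ in range(20000)]
```

Under this rule the accepted-or-resampled token follows `p` exactly, and the acceptance probability equals one minus the total variation distance between `p` and `q`. That is the quantity the abstract's $k$-candidate formulation aims to improve.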

Ziteng Sun, Ananda Theertha Suresh, Jae Hun Ro, Ahmad Beirami, Himanshu Jain, Felix Yu • 2023

Related benchmarks

| Task | Dataset | Result | Rank |
| --- | --- | --- | --- |
| Multilingual Mathematical Reasoning | MGSM (test) | -- | 109 |
| Instruction Following | Alpaca (test) | SR Score 1.52 | 21 |
| Language Modeling | LM1B (test) | Block Efficiency 8.93 | 15 |
| Python Programming | HumanEval (test) | BE Score 10.25 | 10 |
| Speculative Decoding | HumanEval (test) | BE 7.99 | 10 |
| Speculative Decoding | GSM8K (test) | BE Score 7.89 | 10 |
| Speculative Decoding | MGSM (test) | BE 7.36 | 10 |
| Speculative Decoding | LM1B (test) | BE 7.81 | 10 |
| Speculative Decoding | Alpaca (test) | BE Score 7.56 | 10 |
| Grade-School Math | GSM8K (test) | BE Score 6.55 | 10 |
