πŸ’­ Speculative Decoding#

πŸ“ Overview#

A major challenge of LLM inference is latency. Because LLMs generate output autoregressively, token by token, the decoding process is largely bottlenecked by memory bandwidth: the inference engine has to load the entire model weights from memory for every generated token. Speculative decoding stems from the idea that a small model can cheaply predict the next few tokens in advance, and the main model can then verify these tokens in parallel. Since decoding is memory-bound, the time taken to verify multiple tokens is comparable to the time taken to generate a single token, so speculating a few tokens ahead can speed up decoding significantly.
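
To see why verifying several tokens costs about as much as generating one, consider a rough back-of-the-envelope estimate. The numbers below (a 7B-parameter fp16 model and roughly 1 TB/s of GPU memory bandwidth) are illustrative assumptions, not figures from any particular system:

```python
# Rough, illustrative estimate of why decoding is memory-bandwidth-bound.
# The 7B fp16 model size and ~1 TB/s bandwidth are assumptions for the
# sake of the example, not measurements from a specific system.

weights_bytes = 7e9 * 2          # 7B parameters in fp16 -> ~14 GB of weights
bandwidth_bytes_per_s = 1e12     # assumed ~1 TB/s effective memory bandwidth

# Every decode step must stream the full weights through memory at least once,
# so this is a lower bound on per-step latency regardless of how many tokens
# are verified in that step.
step_time_s = weights_bytes / bandwidth_bytes_per_s
print(f"~{step_time_s * 1e3:.1f} ms per decode step")   # ~14.0 ms

# Verifying, say, 5 draft tokens in one step reuses the same weight traffic and
# only adds a little compute, so its latency stays close to a single-token step.
```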

πŸ”§ How it works#

In speculative decoding, we have two models:

  1. Target Model: the large model that is intended to serve users, e.g. the model you want to deploy in production.

  2. Draft Model: a small model that is used to predict the next few tokens in advance. This can take various forms, e.g. an n-gram model, a small pretrained language model (often from the same model family), or a draft model trained specifically for the target model, such as EAGLE (see the sketch after this list for one way to pair such models in practice).
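
As one concrete example of pairing a target model with a draft model, the Hugging Face transformers library implements this idea as assisted generation. The sketch below is illustrative only: the model names are assumptions chosen because the two models share a tokenizer (which assisted generation requires), not a recommendation from this document.

```python
# Minimal sketch of pairing a target and a draft model with Hugging Face
# transformers' assisted generation. Model names are illustrative assumptions;
# the target and draft models must share a tokenizer/vocabulary.
from transformers import AutoModelForCausalLM, AutoTokenizer

target_name = "meta-llama/Llama-2-7b-hf"            # large target model (assumption)
draft_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"   # small draft model (assumption)

tokenizer = AutoTokenizer.from_pretrained(target_name)
target = AutoModelForCausalLM.from_pretrained(target_name, device_map="auto")
draft = AutoModelForCausalLM.from_pretrained(draft_name, device_map="auto")

inputs = tokenizer("Speculative decoding works by", return_tensors="pt").to(target.device)

# The draft model proposes candidate tokens; the target model verifies them.
outputs = target.generate(**inputs, assistant_model=draft, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```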

Animation: the drafting stage (source: NVIDIA blog)

The role of the draft model is to predict the next few tokens in advance, and the role of the target model is to verify the tokens predicted by the draft model. As shown in the animation above, the workflow of speculative decoding can be decomposed into 3 stages:

  • prefill: the target model first takes the prompt as input and runs the prefill stage.

  • drafting: Afterwards, we let the draft model iteratively predict the next N candidate tokens. Since the draft model is much smaller than the target model, drafting adds relatively little latency.

  • verification: We then pass the N candidate tokens to the target model and verify them in parallel. Since this stage is memory-bound, verifying several tokens adds little latency compared with generating a single token. Each candidate token accepted by the target model is appended to the output sequence; the first rejected token and everything after it are discarded. The draft model then continues predicting from the last accepted token, and this process repeats until the end of the sequence is reached (a simplified sketch of the loop follows this list).
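
Below is a simplified, greedy sketch of this draft-and-verify loop. The callables `draft_next_token` and `target_predictions` are hypothetical stand-ins for real model calls, and exact-match acceptance is used for simplicity; real implementations run verification as a single batched forward pass and use the probabilistic acceptance rule described in the next section.

```python
# Simplified greedy sketch of the draft-and-verify loop described above.
# `draft_next_token` and `target_predictions` are hypothetical callables that
# stand in for real model calls.

def speculative_generate(prompt, draft_next_token, target_predictions,
                         num_draft_tokens=4, max_new_tokens=64, eos_id=2):
    tokens = list(prompt)
    while len(tokens) - len(prompt) < max_new_tokens:
        # Drafting: the small model cheaply proposes N candidate tokens.
        candidates, context = [], list(tokens)
        for _ in range(num_draft_tokens):
            nxt = draft_next_token(context)
            candidates.append(nxt)
            context.append(nxt)

        # Verification: one pass of the target model yields its own next-token
        # prediction at the current position and after each candidate prefix
        # (a list of num_draft_tokens + 1 token ids).
        preds = target_predictions(tokens, candidates)

        # Accept candidates while they agree with the target model, then take
        # one "bonus" token from the target so every iteration makes progress.
        accepted = 0
        while accepted < len(candidates) and candidates[accepted] == preds[accepted]:
            tokens.append(candidates[accepted])
            accepted += 1
        tokens.append(preds[accepted])

        if tokens[-1] == eos_id:
            break
    return tokens
```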

One key property of speculative decoding is that it guarantees the output distribution is identical to that of running the target model alone. This is because the target model decides whether to accept each candidate token using rejection sampling; the original speculative decoding paper provides a mathematical proof of this correctness in its appendix. Let \(p(x)\) denote the probability the target model assigns to a candidate token \(x\) and \(q(x)\) the probability the draft model assigns to it. If \(q(x) \le p(x)\), the token is accepted. If \(q(x) > p(x)\), the target model rejects the token with probability \(1 - p(x)/q(x)\) and samples a replacement token from the adjusted distribution \(p'(x) = \text{norm}(\max(0, p(x) - q(x)))\). Intuitively, candidates that the draft model over-weights relative to the target model are sometimes rejected, and the adjusted distribution corrects exactly for that bias. The animation below illustrates the verification process.

Animation: the verification process (source: NVIDIA blog)
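
The acceptance rule above can be sketched numerically as follows. This is a minimal illustration, assuming `p` and `q` are the target and draft models' probability distributions over the vocabulary at the position being verified (as numpy arrays) and `x` is the drafted token id:

```python
import numpy as np

def verify_token(x, p, q, rng=None):
    """Accept or reject a drafted token x using the speculative sampling rule.

    p and q are the target and draft probability distributions over the
    vocabulary; returns (accepted, token), where token is either x or a
    replacement sampled from norm(max(0, p - q)).
    """
    rng = rng or np.random.default_rng()
    # Accept with probability min(1, p(x)/q(x)): if q(x) <= p(x) this is 1,
    # otherwise the token is rejected with probability 1 - p(x)/q(x).
    if rng.random() < min(1.0, p[x] / q[x]):
        return True, x
    # On rejection, resample from the adjusted distribution p'(x).
    residual = np.maximum(p - q, 0.0)
    residual = residual / residual.sum()
    return False, int(rng.choice(len(p), p=residual))
```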