简要介绍投机解码的原理,源自Google 2023的paper
经典decoding
最基本的decoding方式(自回归,无KV cache,仅作示意)
1 | for i in ...: |
这里的问题是受限于LLM典型结构,一次forward只能推一个token。结构上完全可以一次推多个(从保留最后一个token变成保留最后几个),但这样输入部分最后几个token没法确定,因此无法计算。
如果可以从预训练阶段下手,我们或可考虑Multi-Token Prediction(这是另一个范畴了)。但如果我们只能从推理阶段下手(例如无法做训练,但有LLM推理出的logits信息;有时也可能只有LLM api渠道,只有decode结果),则投机解码(speculative decoding)就很有价值。
speculative decoding
核心原理是用一个“小”LLM(当然,“小”到什么程度是有一定界限的,选择不当有概率负优化)帮忙先前推n(通常3-5)个token,然后把原始输入和小LLM推的这几个token都作为input给LLM,LLM通过一定的机制一次性验证小LLM的多个token是否正确。

验证方式:reject sampling
定义小LLM、大LLM推出的概率分布是q,p,小模型推理出的某个token是x,则paper里给出的最经典的拒绝方式是:
- 如果,即大LLM对这个token的置信度比小LLM还大,就直接接受
- 如果,则仅以 的概率接受。也就是说,如果大LLM对这个token的置信度小,同时小LLM的置信度反常地大,就以更高的概率拒绝。
- 拒绝后,x这个token以及小LLM在这个位置后给出的token都全部丢弃。
- 但为了避免浪费大LLM这一轮的推理,我们需要利用大LLM现有的推理结果,把x这个位置正确的token推理出来
- 需要注意的是,即使小LLM前面的decode结果是正确的,但logits其实是有偏差的(才导致在x这个位置的预测被拒绝),因此不能直接用大LLM在x这个位置推理得到的分布,我们需要修正这个分布。
- 可以构造出一种分布修正形式(残差分布),使最终在x这个地方的总体概率分布是和原始大模型直接推理是一致的。
Google的paper是顺向证明,以下我从反向直接推导残差分布,还原研究时的状态
更严格地定义:X代表上面提到的这个token,;代表用上面的方法得到的token。
根据上面的规则,有两条可达路径
- 小LLM推出的结果采样得到,且最终接受
- 小LLM推出的结果被拒绝,但最后通过残差分布再次采样得到
于是(提示:是已经得到小LLM特定输出后的总体接受概率;我们可以假设残差分布不依赖小LLM的特定输出,因为这样也能解出来具体形式)
由于我们希望
于是解出残差分布的具体形式
注意:分母和google paper里的有点差别,但你仔细推一下其实二者是相等的
算法
可能有些边界条件不太对
注意Bonus Token,其实是100%正提升
1 | for i in ...: # 主循环,i 是当前主序列的末尾指针 |
典型的改进
Greedy Sampling
和原版除了temperature=0外,其他都一样,最终等价于直接对比decode字符串是否一致,不一致就拒绝
注意前面的证明同样适用,因此也是无损的(概率分布未被影响)。
这种方法不需要prob,只要有LLM API就能用,目测是最佳实践
Medusa
不单独设置小模型,而是在LLM末端接一些layer,用LLM+layer自身充当自己的小LLM
缺点是layer需要自己训(如果官方没提供)
EAGLE
类似Medusa,加上了
- 预测树
- layer变成了特征层的自回归
这种额外权重通常参数量在1%以内,因此也不算难训
N-gram投机
直接把近期文本切成N-gram字典,然后匹配历史文本作为draft
既不需要训练,也不需要额外准备小模型,开箱即用
Suffix Decoding
可以看成高配版N-gram,只不过通过会话历史生成后缀树来生成draft
同样既不需要训练,也不需要额外准备小模型,开箱即用,而且通常可以用来利用CPU的算力。属于近期快速上线的最佳实践。