0%

speculative decoding

简要介绍投机解码的原理,源自Google 2023的paper

经典decoding

最基本的decoding方式(自回归,无KV cache,仅作示意)

1
2
for i in ...:
x[i] = LLM(x[:i])[-1]

这里的问题是受限于LLM典型结构,一次forward只能推一个token。结构上完全可以一次推多个(从保留最后一个token变成保留最后几个),但这样输入部分最后几个token没法确定,因此无法计算。

如果可以从预训练阶段下手,我们或可考虑Multi-Token Prediction(这是另一个范畴了)。但如果我们只能从推理阶段下手(例如无法做训练,但有LLM推理出的logits信息;有时也可能只有LLM api渠道,只有decode结果),则投机解码(speculative decoding)就很有价值。

speculative decoding

核心原理是用一个“小”LLM(当然,“小”到什么程度是有一定界限的,选择不当有概率负优化)帮忙先前推n(通常3-5)个token,然后把原始输入和小LLM推的这几个token都作为input给LLM,LLM通过一定的机制一次性验证小LLM的多个token是否正确。

image-20260613150643263

验证方式:reject sampling

定义小LLM、大LLM推出的概率分布是q,p,小模型推理出的某个token是x,则paper里给出的最经典的拒绝方式是:

  • 如果p(x)q(x)p(x) \ge q(x),即大LLM对这个token的置信度比小LLM还大,就直接接受
  • 如果p(x)<q(x)p(x) < q(x),则仅以 p(x)/q(x)p(x)/q(x)的概率接受。也就是说,如果大LLM对这个token的置信度小,同时小LLM的置信度反常地大,就以更高的概率拒绝。
    • 拒绝后,x这个token以及小LLM在这个位置后给出的token都全部丢弃。
    • 但为了避免浪费大LLM这一轮的推理,我们需要利用大LLM现有的推理结果,把x这个位置正确的token推理出来
    • 需要注意的是,即使小LLM前面的decode结果是正确的,但logits其实是有偏差的(才导致在x这个位置的预测被拒绝),因此不能直接用大LLM在x这个位置推理得到的分布,我们需要修正这个分布。
    • 可以构造出一种分布修正形式(残差分布rr),使最终在x这个地方的总体概率分布是和原始大模型直接推理是一致的。

Google的paper是顺向证明,以下我从反向直接推导残差分布,还原研究时的状态

更严格地定义:X代表上面提到的这个token,p(x)=P(X=x)p(x)=P(X=x)YY代表用上面的方法得到的token。

根据上面的规则,P(Y=y)P(Y=y)有两条可达路径

  • 小LLM推出的结果采样得到yy,且最终接受
  • 小LLM推出的结果被拒绝,但最后通过残差分布rr再次采样得到yy

于是(提示:min(1,p(x)q(x))\min(1,\frac{p(x)}{q(x)})是已经得到小LLM特定输出后的总体接受概率;我们可以假设残差分布不依赖小LLM的特定输出,因为这样也能解出来具体形式)

P(Y=y,accepted)=q(y)min(1,p(y)q(y))=min(q(y),p(y))P(Y=y,accepted)=q(y)\min(1,\frac{p(y)}{q(y)})=\min(q(y),p(y))

P(Y=y,rejected)=xq(x)(1min(1,p(x)q(x)))r(y)=r(y)x(q(x)min(q(x),p(x)))P(Y=y,rejected)=\sum_x q(x)(1-\min(1,\frac{p(x)}{q(x)}))r(y)=r(y)\sum_x(q(x)-\min(q(x),p(x)))

由于我们希望P(Y=y)=P(Y=y,accepted)+P(Y=y,rejected)=P(X=y)=p(y)P(Y=y)=P(Y=y,accepted)+P(Y=y,rejected)=P(X=y)=p(y)

于是解出残差分布的具体形式

r(y)=p(y)min(q(y),p(y))xq(x)min(q(x),p(x))r(y)=\frac{p(y)-\min(q(y),p(y))}{\sum_x q(x)-\min(q(x),p(x))}

注意:分母和google paper里的有点差别,但你仔细推一下其实二者是相等的

算法

可能有些边界条件不太对

注意Bonus Token,其实是100%正提升

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
for i in ...:  # 主循环,i 是当前主序列的末尾指针
# 1. 小模型一口气盲猜 n-1 步 (这里假设草稿长度为 n-1)
for j in range(i, i + n - 1):
x_spec.sampled[j] = miniLLM(concat(x[:i], x_spec[i:j]))[-1].sample()

# 2. 大模型一次性验证,并行产出 n 个位置的概率分布
x_probs[i : i + n] = LLM(concat(x[:i], x_spec[i : i + n - 1]))[-n:]

# 3. 逐字审核机制
for k in range(i, i + n):
if k == i + n - 1:
# 走到最后一步了,说明前面的草稿全对!
# 顺手带走大模型在最后这个空位上独立预测的 Bonus Token
x[k].sampled = sample(x_probs[k])
i = k + 1
break

token_draft = x_spec.sampled[k] # 小模型给的草稿字
qk = x_spec_probs[k][token_draft] # 小模型对它的自信度
pk = x_probs[k][token_draft] # 大模型对它的认可度,特别注意这里用的是小模型预测结果的prob

if pk >= qk or rand(0, 1) < (pk / qk):
x[k].sampled = token_draft # 接受,正式录用
continue
else:
# 拒绝!启动全局残差补偿机制
res_dist = res_dist_fn(x_probs[k], x_spec_probs[k])
x[k].sampled = sample(res_dist)
# 此时 x 序列在 k 之后的草稿自动作废,下一轮主循环从 k+1 开始
i = k + 1
break

典型的改进

Greedy Sampling

和原版除了temperature=0外,其他都一样,最终等价于直接对比decode字符串是否一致,不一致就拒绝

注意前面的证明同样适用,因此也是无损的(概率分布未被影响)。

这种方法不需要prob,只要有LLM API就能用,目测是最佳实践

Medusa

不单独设置小模型,而是在LLM末端接一些layer,用LLM+layer自身充当自己的小LLM

缺点是layer需要自己训(如果官方没提供)

EAGLE

类似Medusa,加上了

  • 预测树
  • layer变成了特征层的自回归

这种额外权重通常参数量在1%以内,因此也不算难训

N-gram投机

直接把近期文本切成N-gram字典,然后匹配历史文本作为draft

既不需要训练,也不需要额外准备小模型,开箱即用

Suffix Decoding

可以看成高配版N-gram,只不过通过会话历史生成后缀树来生成draft

同样既不需要训练,也不需要额外准备小模型,开箱即用,而且通常可以用来利用CPU的算力。属于近期快速上线的最佳实践。

欢迎关注我的其它发布渠道