投机采样(Speculative Decoding):LLM推理加速的新范式
大语言模型的自回归生成范式要求逐个生成 token,每一步都需要一次完整的前向传播。这个”串行瓶颈”使得推理延迟难以突破。投机采样(Speculative Decoding) 提供了一种全新的加速思路——让一个小模型先去”猜测”多个 token,再由大模型一次性验证。
一、核心原理
1.1 为什么自回归生成很慢?
传统自回归生成过程:
Step 1: 输入 → [LLM] → token A
Step 2: [A] → [LLM] → token B
Step 3: [A,B] → [LLM] → token C
Step 4: [A,B,C] → [LLM] → token D
...
每个步骤:调用一次完整 LLM,生成 1 个 token
=============================================
N 个 token 需要 N 次前向传播
1.2 投机采样的直觉
"""
核心洞察:
- LLM 的 Forward 计算是"批次友好"的
- 一次处理 1 个 token 和一次处理 5 个 token,耗时几乎相同
- 如果能让"一次生成 5 个 token"也保证质量,就能提速 5 倍
"""
# 传统方式:串行生成 5 个 token
# time = 5 × latency_per_token
#
# 投机采样:猜 5 个 → 并行验证 5 个
# time = latency(draft_5) + latency(verify_5)
# ≈ 0.1 × T + T ≈ 1.1T (理想情况提速 4.5 倍)
1.3 算法流程
┌──────────────────────────────────────────┐
│ Speculative Decoding 流程 │
├──────────────────────────────────────────┤
│ │
│ 输入: "人工智能的核" │
│ │ │
│ ┌──────────┴──────────┐ │
│ ▼ ▼ │
│ Draft Model (小) Target LLM (大) │
│ │ │ │
│ 猜测 5 个 token │ │
│ "心技术是深度" │ │
│ │ │ │
│ └──────────┬──────────┘ │
│ ▼ │
│ 并行验证所有猜测位置 │
│ │ │
│ ▼ │
│ 接受匹配的 prefix: "心技术是" │
│ 拒绝不匹配的: "深度" │
│ │ │
│ ▼ │
│ 输出: "心技术是"... │
│ 这次迭代生成了 4 个正确 token! │
│ │
└──────────────────────────────────────────┘
1.4 拒绝采样机制
import torch
def speculative_decoding_step(
draft_model,
target_model,
input_ids,
num_speculate=5,
temperature=1.0
):
"""
一步投机采样:
1. Draft model 猜测 num_speculate 个 token
2. Target model 并行验证
3. Rejection sampling 决定接受哪些
"""
# Step 1: Draft model 快速猜测
draft_tokens = []
draft_probs = []
current = input_ids.clone()
for _ in range(num_speculate):
with torch.no_grad():
logits = draft_model(current)
probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
draft_tokens.append(next_token)
draft_probs.append(probs)
current = torch.cat([current, next_token], dim=-1)
# Step 2: Target model 并行验证
# 将 input_ids + 所有 draft 猜测一起输入
all_input = torch.cat([input_ids] + draft_tokens, dim=-1)
target_logits = target_model(all_input)
# Step 3: Rejection sampling
accepted = []
for i in range(num_speculate):
q = draft_probs[i] # draft 模型分布
p = torch.softmax(
target_logits[:, input_ids.shape[-1] + i, :] / temperature,
dim=-1
) # 目标模型分布
# Rejection sampling 条件
x = draft_tokens[i]
if torch.rand(1) < (p[0, x].item() / q[0, x].item()).clip(0, 1):
accepted.append(x)
else:
# 从调整后的分布重新采样
adjusted = (p - q).clip(min=0)
adjusted /= adjusted.sum()
replacement = torch.multinomial(adjusted, num_samples=1)
accepted.append(replacement)
break # 一旦拒绝,停止接受后续 token
return torch.cat(accepted, dim=-1)
二、Draft Model 选择策略
2.1 三种主流方案
| 方案 | 描述 | 加速比 | 额外成本 | 适用场景 |
|---|---|---|---|---|
| 独立小模型 | 使用独立的小模型(如 125M 参数) | 2-3x | 模型加载 | 通用场景 |
| 模型剪枝版 | 使用目标模型的剪枝/蒸馏版本 | 2-2.5x | 训练成本 | 首次部署 |
| 模型自身(SSD) | 使用目标模型的浅层 + 额外 heads | 2.2-2.8x | 训练 | 生产环境 |
2.2 Draft 模型选型对比
| Draft 模型 | 参数 | 延迟 | 接受率 | 相对加速 | 部署复杂度 |
|---|---|---|---|---|---|
| TinyLlama | 1.1B | 2ms | 65% | 2.1x | 低 |
| Medusa | - | 3ms | 75% | 2.8x | 中(需训练) |
| Self-Speculative | - | 1ms | 60% | 2.5x | 高 |
| Eagle/2 | 1.3B | 2.5ms | 70% | 2.4x | 中 |
2.3 接受率(Acceptance Rate)
接受率是投机采样的核心指标,决定了实际加速效果:
理论加速比 ≈ num_speculate × acceptance_rate
实际加速比 ≈ (num_speculate × acc_rate) / (1 + draft_overhead_ratio)
例子:
- num_speculate = 5
- acc_rate = 0.7
- draft_overhead = 0.15 (draft 耗时是 target 的 15%)
- 理论:5 × 0.7 = 3.5x
- 实际:3.5 / 1.15 ≈ 3.0x
三、主流实现
3.1 vLLM 中的投机采样
# vLLM 投机采样启动
vllm serve meta-llama/Llama-2-7b-hf \
--speculative-model "JackFram/llama-68m" \
--num-speculative-tokens 5 \
--ngram-prompt-lookup-max 4 \
--speculative-draft-tensor-parallel-size 1
关键参数:
| 参数 | 说明 | 建议值 |
|---|---|---|
--num-speculative-tokens | 每次猜测的 token 数 | 3-8 |
--speculative-model | draft 模型路径 | 匹配词表的轻量模型 |
--ngram-prompt-lookup-max | N-gram 匹配长度(替代 draft 模型) | 1-4 |
3.2 Medusa 实现
Medusa 通过在目标模型最后添加多个”预测头”来实现投机采样:
# Medusa 架构示意
class MedusaHead(torch.nn.Module):
"""多候选预测头"""
def __init__(self, hidden_size, vocab_size, num_heads=5):
super().__init__()
self.num_heads = num_heads
# 为每个预测位置训练独立的 head
self.heads = torch.nn.ModuleList([
torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, vocab_size)
)
for _ in range(num_heads)
])
def forward(self, hidden_states):
"""并行预测多个候选 token"""
predictions = []
for head in self.heads:
logits = head(hidden_states[:, -1:, :]) # 只取最后一个位置
predictions.append(logits)
return torch.cat(predictions, dim=1)
3.3 基于 N-gram 的轻量投机
无需单独的 draft 模型,利用 prompt 中的重复模式:
def ngram_speculate(input_ids, n=3, num_predict=5):
"""
N-gram based speculative decoding
利用输入中的 n-gram 匹配来生成候选 token
"""
candidates = []
seq = input_ids.tolist()
for _ in range(num_predict):
# 查找当前 n-gram 的历史匹配
pattern = seq[-(n-1):]
matched = False
for i in range(len(seq) - n):
if seq[i:i+n-1] == pattern:
# 找到匹配,取下一个 token
candidates.append(seq[i+n-1])
seq.append(seq[i+n-1])
matched = True
break
if not matched:
break
return torch.tensor([candidates])
四、性能评估
4.1 不同场景的加速效果
| 场景 | 模型 | Draft | 猜5个 | 猜3个 | 猜8个 |
|---|---|---|---|---|---|
| 代码生成 | CodeLlama-7B | 68M | 2.8x | 2.1x | 3.0x |
| 文本摘要 | Llama-2-13B | 125M | 2.5x | 2.0x | 2.6x |
| 对话 | Mistral-7B | TinyLlama | 2.3x | 1.9x | 2.4x |
| 翻译 | NLLB-3B | NLLB-600M | 2.0x | 1.7x | 2.1x |
4.2 与基线技术对比
| 技术 | 延迟降低 | 吞吐量提升 | 需要训练 | 质量损失 |
|---|---|---|---|---|
| Speculative Decoding | 50-67% | 2-3x | 可选(draft) | 无(数学保证) |
| KV Cache Quantization | 10-20% | 1.1-1.2x | 否 | 轻微 |
| FlashAttention | 20-30% | 1.2-1.3x | 否 | 无 |
| TensorRT Compilation | 30-50% | 1.5-2x | 否 | 无 |
| 模型量化 (INT8) | 40-60% | 1.5-2x | 否 | 轻微 |
五、生产部署建议
5.1 适用场景
✅ 最适合:
- 流式文本生成(chat, 写作辅助)
- 代码补全
- 需要低 P50 延迟的场景
❌ 不太适合:
- 仅 1-2 个 token 的简单分类任务(开销不值得)
- 计算极度受限的部署(draft 模型也占用资源)
- 长 prompt 但短输出的场景
5.2 调优指南
# 根据场景选择猜 token 数量
def find_optimal_spec_len(draft_model, target_model,
acceptance_rate_fn, target_delay_ms=50):
"""
通过实验找到最优猜测数量
"""
results = []
for k in range(1, 11):
delay_draft = measure_draft_latency(draft_model, k)
delay_target = measure_target_verify_latency(target_model, k)
acc_rate = acceptance_rate_fn(k)
total_delay = delay_draft + delay_target
tokens_per_step = k * acc_rate
results.append({
'k': k,
'delay': total_delay,
'tokens_per_step': tokens_per_step,
'effective_speedup': tokens_per_step / total_delay
})
# 选择延迟阈值内,tokens_per_step 最大的 k
feasible = [r for r in results if r['delay'] <= target_delay_ms]
return max(feasible, key=lambda r: r['tokens_per_step'])
5.3 配置 Checklist
- 确定 draft model 是否必要(N-gram 替代方案)
- 测试不同 num_speculative_tokens 的加速比
- 验证输出质量与 baseline 一致(数学保证 100% 一致)
- 监控 acceptance rate 指标
- 配置 fallback 机制(draft 失败时回退普通解码)
- 考虑 prefix caching 与投机采样的组合优化
六、常见挑战与解决方案
| 挑战 | 原因 | 解决方案 |
|---|---|---|
| 接受率偏低 | Draft 模型能力不足 | 替换更大的 draft 模型;减少猜测数 |
| 显存翻倍 | 同时加载两个模型 | 使用 N-gram 投机;共享模型主干 |
| 延迟不稳定 | Draft 被拒绝时回退 | 设置最大回退限制;使用动态猜测数 |
| 上下文敏感 | 复杂推理场景接受率骤降 | 降级到普通解码;使用 Medusa |
总结
投机采样打破了”自回归生成必须串行”的物理限制,通过”小模型推测 + 大模型验证”的范式,在不牺牲输出质量的前提下实现 2-3 倍的推理加速。随着 draft 模型方案的成熟(Medusa、Eagle、Self-Speculative)和主流推理框架(vLLM、TGI、TensorRT-LLM)的原生支持,投机采样已从学术研究走向生产实践,成为 LLM 推理加速的重要武器。