LLM工程化 高级 投机采样 Speculative Decoding 推理加速 Draft Model

投机采样(Speculative Decoding):LLM推理加速的新范式

AIEng Hub
阅读约 20 分钟

投机采样(Speculative Decoding):LLM推理加速的新范式

大语言模型的自回归生成范式要求逐个生成 token,每一步都需要一次完整的前向传播。这个”串行瓶颈”使得推理延迟难以突破。投机采样(Speculative Decoding) 提供了一种全新的加速思路——让一个小模型先去”猜测”多个 token,再由大模型一次性验证。

一、核心原理

1.1 为什么自回归生成很慢?

传统自回归生成过程:
Step 1: 输入 → [LLM] → token A
Step 2: [A] → [LLM] → token B
Step 3: [A,B] → [LLM] → token C
Step 4: [A,B,C] → [LLM] → token D
...

每个步骤:调用一次完整 LLM,生成 1 个 token
=============================================
N 个 token 需要 N 次前向传播

1.2 投机采样的直觉

"""
核心洞察:
- LLM 的 Forward 计算是"批次友好"的
- 一次处理 1 个 token 和一次处理 5 个 token,耗时几乎相同
- 如果能让"一次生成 5 个 token"也保证质量,就能提速 5 倍
"""

# 传统方式:串行生成 5 个 token
# time = 5 × latency_per_token
# 
# 投机采样:猜 5 个 → 并行验证 5 个
# time = latency(draft_5) + latency(verify_5)
#      ≈ 0.1 × T + T ≈ 1.1T (理想情况提速 4.5 倍)

1.3 算法流程

┌──────────────────────────────────────────┐
│        Speculative Decoding 流程          │
├──────────────────────────────────────────┤
│                                          │
│  输入: "人工智能的核"                      │
│                    │                      │
│         ┌──────────┴──────────┐           │
│         ▼                     ▼           │
│   Draft Model (小)      Target LLM (大)   │
│         │                     │           │
│   猜测 5 个 token            │           │
│   "心技术是深度"             │           │
│         │                     │           │
│         └──────────┬──────────┘           │
│                    ▼                      │
│  并行验证所有猜测位置                      │
│         │                                │
│         ▼                                │
│  接受匹配的 prefix: "心技术是"            │
│  拒绝不匹配的: "深度"                     │
│         │                                │
│         ▼                                │
│  输出: "心技术是"...                      │
│  这次迭代生成了 4 个正确 token!           │
│                                          │
└──────────────────────────────────────────┘

1.4 拒绝采样机制

import torch

def speculative_decoding_step(
    draft_model,
    target_model,
    input_ids,
    num_speculate=5,
    temperature=1.0
):
    """
    一步投机采样:
    1. Draft model 猜测 num_speculate 个 token
    2. Target model 并行验证
    3. Rejection sampling 决定接受哪些
    """
    # Step 1: Draft model 快速猜测
    draft_tokens = []
    draft_probs = []
    current = input_ids.clone()
    
    for _ in range(num_speculate):
        with torch.no_grad():
            logits = draft_model(current)
            probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
        
        next_token = torch.multinomial(probs, num_samples=1)
        draft_tokens.append(next_token)
        draft_probs.append(probs)
        current = torch.cat([current, next_token], dim=-1)
    
    # Step 2: Target model 并行验证
    # 将 input_ids + 所有 draft 猜测一起输入
    all_input = torch.cat([input_ids] + draft_tokens, dim=-1)
    target_logits = target_model(all_input)
    
    # Step 3: Rejection sampling
    accepted = []
    for i in range(num_speculate):
        q = draft_probs[i]  # draft 模型分布
        p = torch.softmax(
            target_logits[:, input_ids.shape[-1] + i, :] / temperature, 
            dim=-1
        )  # 目标模型分布
        
        # Rejection sampling 条件
        x = draft_tokens[i]
        if torch.rand(1) < (p[0, x].item() / q[0, x].item()).clip(0, 1):
            accepted.append(x)
        else:
            # 从调整后的分布重新采样
            adjusted = (p - q).clip(min=0)
            adjusted /= adjusted.sum()
            replacement = torch.multinomial(adjusted, num_samples=1)
            accepted.append(replacement)
            break  # 一旦拒绝,停止接受后续 token
    
    return torch.cat(accepted, dim=-1)

二、Draft Model 选择策略

2.1 三种主流方案

方案描述加速比额外成本适用场景
独立小模型使用独立的小模型(如 125M 参数)2-3x模型加载通用场景
模型剪枝版使用目标模型的剪枝/蒸馏版本2-2.5x训练成本首次部署
模型自身(SSD)使用目标模型的浅层 + 额外 heads2.2-2.8x训练生产环境

2.2 Draft 模型选型对比

Draft 模型参数延迟接受率相对加速部署复杂度
TinyLlama1.1B2ms65%2.1x
Medusa-3ms75%2.8x中(需训练)
Self-Speculative-1ms60%2.5x
Eagle/21.3B2.5ms70%2.4x

2.3 接受率(Acceptance Rate)

接受率是投机采样的核心指标,决定了实际加速效果:

理论加速比 ≈ num_speculate × acceptance_rate
实际加速比 ≈ (num_speculate × acc_rate) / (1 + draft_overhead_ratio)

例子:
- num_speculate = 5
- acc_rate = 0.7
- draft_overhead = 0.15 (draft 耗时是 target 的 15%)
- 理论:5 × 0.7 = 3.5x
- 实际:3.5 / 1.15 ≈ 3.0x

三、主流实现

3.1 vLLM 中的投机采样

# vLLM 投机采样启动
vllm serve meta-llama/Llama-2-7b-hf \
    --speculative-model "JackFram/llama-68m" \
    --num-speculative-tokens 5 \
    --ngram-prompt-lookup-max 4 \
    --speculative-draft-tensor-parallel-size 1

关键参数:

参数说明建议值
--num-speculative-tokens每次猜测的 token 数3-8
--speculative-modeldraft 模型路径匹配词表的轻量模型
--ngram-prompt-lookup-maxN-gram 匹配长度(替代 draft 模型)1-4

3.2 Medusa 实现

Medusa 通过在目标模型最后添加多个”预测头”来实现投机采样:

# Medusa 架构示意
class MedusaHead(torch.nn.Module):
    """多候选预测头"""
    def __init__(self, hidden_size, vocab_size, num_heads=5):
        super().__init__()
        self.num_heads = num_heads
        
        # 为每个预测位置训练独立的 head
        self.heads = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(hidden_size, hidden_size),
                torch.nn.SiLU(),
                torch.nn.Linear(hidden_size, vocab_size)
            )
            for _ in range(num_heads)
        ])
    
    def forward(self, hidden_states):
        """并行预测多个候选 token"""
        predictions = []
        for head in self.heads:
            logits = head(hidden_states[:, -1:, :])  # 只取最后一个位置
            predictions.append(logits)
        return torch.cat(predictions, dim=1)

3.3 基于 N-gram 的轻量投机

无需单独的 draft 模型,利用 prompt 中的重复模式:

def ngram_speculate(input_ids, n=3, num_predict=5):
    """
    N-gram based speculative decoding
    利用输入中的 n-gram 匹配来生成候选 token
    """
    candidates = []
    seq = input_ids.tolist()
    
    for _ in range(num_predict):
        # 查找当前 n-gram 的历史匹配
        pattern = seq[-(n-1):]
        matched = False
        
        for i in range(len(seq) - n):
            if seq[i:i+n-1] == pattern:
                # 找到匹配,取下一个 token
                candidates.append(seq[i+n-1])
                seq.append(seq[i+n-1])
                matched = True
                break
        
        if not matched:
            break
    
    return torch.tensor([candidates])

四、性能评估

4.1 不同场景的加速效果

场景模型Draft猜5个猜3个猜8个
代码生成CodeLlama-7B68M2.8x2.1x3.0x
文本摘要Llama-2-13B125M2.5x2.0x2.6x
对话Mistral-7BTinyLlama2.3x1.9x2.4x
翻译NLLB-3BNLLB-600M2.0x1.7x2.1x

4.2 与基线技术对比

技术延迟降低吞吐量提升需要训练质量损失
Speculative Decoding50-67%2-3x可选(draft)无(数学保证)
KV Cache Quantization10-20%1.1-1.2x轻微
FlashAttention20-30%1.2-1.3x
TensorRT Compilation30-50%1.5-2x
模型量化 (INT8)40-60%1.5-2x轻微

五、生产部署建议

5.1 适用场景

✅ 最适合:
- 流式文本生成(chat, 写作辅助)
- 代码补全
- 需要低 P50 延迟的场景

❌ 不太适合:
- 仅 1-2 个 token 的简单分类任务(开销不值得)
- 计算极度受限的部署(draft 模型也占用资源)
- 长 prompt 但短输出的场景

5.2 调优指南

# 根据场景选择猜 token 数量
def find_optimal_spec_len(draft_model, target_model, 
                          acceptance_rate_fn, target_delay_ms=50):
    """
    通过实验找到最优猜测数量
    """
    results = []
    
    for k in range(1, 11):
        delay_draft = measure_draft_latency(draft_model, k)
        delay_target = measure_target_verify_latency(target_model, k)
        acc_rate = acceptance_rate_fn(k)
        
        total_delay = delay_draft + delay_target
        tokens_per_step = k * acc_rate
        
        results.append({
            'k': k,
            'delay': total_delay,
            'tokens_per_step': tokens_per_step,
            'effective_speedup': tokens_per_step / total_delay
        })
    
    # 选择延迟阈值内,tokens_per_step 最大的 k
    feasible = [r for r in results if r['delay'] <= target_delay_ms]
    return max(feasible, key=lambda r: r['tokens_per_step'])

5.3 配置 Checklist

  • 确定 draft model 是否必要(N-gram 替代方案)
  • 测试不同 num_speculative_tokens 的加速比
  • 验证输出质量与 baseline 一致(数学保证 100% 一致)
  • 监控 acceptance rate 指标
  • 配置 fallback 机制(draft 失败时回退普通解码)
  • 考虑 prefix caching 与投机采样的组合优化

六、常见挑战与解决方案

挑战原因解决方案
接受率偏低Draft 模型能力不足替换更大的 draft 模型;减少猜测数
显存翻倍同时加载两个模型使用 N-gram 投机;共享模型主干
延迟不稳定Draft 被拒绝时回退设置最大回退限制;使用动态猜测数
上下文敏感复杂推理场景接受率骤降降级到普通解码;使用 Medusa

总结

投机采样打破了”自回归生成必须串行”的物理限制,通过”小模型推测 + 大模型验证”的范式,在不牺牲输出质量的前提下实现 2-3 倍的推理加速。随着 draft 模型方案的成熟(Medusa、Eagle、Self-Speculative)和主流推理框架(vLLM、TGI、TensorRT-LLM)的原生支持,投机采样已从学术研究走向生产实践,成为 LLM 推理加速的重要武器。