LLM工程化 高级 KV Cache PagedAttention vLLM 推理优化

KV Cache 优化:LLM 推理性能提升的关键

AIEng Hub
阅读约 25 分钟

什么是 KV Cache?

KV Cache(Key-Value Cache)是 Transformer 模型推理中的关键优化技术,用于缓存之前计算的 Key 和 Value 矩阵,避免重复计算。

┌─────────────────────────────────────────────────────────────┐
│                    KV Cache 原理                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   没有 KV Cache:                                            │
│   Token 1 → [计算] → Token 2 → [重新计算1+2] → Token 3      │
│   时间复杂度: O(n³)                                         │
│                                                             │
│   使用 KV Cache:                                            │
│   Token 1 → [计算] → 缓存(K1,V1)                            │
│   Token 2 → [使用缓存+计算2] → 缓存(K1,V1,K2,V2)            │
│   Token 3 → [使用缓存+计算3]                                 │
│   时间复杂度: O(n²)                                         │
│                                                             │
│   内存占用:                                                 │
│   - 每层: 2 × num_heads × head_dim × seq_len × batch_size   │
│   - FP16: 每个token约 2 × hidden_size bytes                 │
│   - 7B模型: 每个token约 1-2 MB                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

KV Cache 的内存计算公式:

内存(GB) = 2 × num_layers × num_heads × head_dim × seq_len × batch_size × sizeof(dtype)

例如:
- 模型: Llama-2-7B
- 层数: 32
- 头数: 32
- 头维度: 128
- 序列长度: 4096
- 批大小: 1
- 精度: FP16 (2 bytes)

内存 = 2 × 32 × 32 × 128 × 4096 × 1 × 2 / 1e9 = 2.1 GB

KV Cache 的挑战

1. 内存碎片化

传统分配方式:
┌─────────────────────────────────────────────────┐
│ 请求A(512) │ 请求B(1024) │ 请求C(2048) │ 空闲  │
└─────────────────────────────────────────────────┘

请求B完成后:
┌─────────────────────────────────────────────────┐
│ 请求A(512) │  空闲(1024)  │ 请求C(2048) │ 空闲  │
└─────────────────────────────────────────────────┘
              ↑ 无法被大请求利用

2. 内存浪费

  • 预分配最大长度,实际使用较少
  • 批处理时不同请求长度不一致
  • 动态增长的序列需要频繁重分配

PagedAttention:解决内存碎片化

vLLM 提出的 PagedAttention 将 KV Cache 分成固定大小的块(block),类似操作系统的虚拟内存管理。

┌─────────────────────────────────────────────────────────────┐
│                  PagedAttention 原理                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   物理块 (固定大小,如16 tokens):                            │
│   ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐        │
│   │Block│ │Block│ │Block│ │Block│ │Block│ │Block│        │
│   │  0  │ │  1  │ │  2  │ │  3  │ │  4  │ │  5  │        │
│   └─────┘ └─────┘ └─────┘ └─────┘ └─────┘ └─────┘        │
│                                                             │
│   逻辑到物理映射:                                           │
│   ┌────────────────────────────────────────┐               │
│   │ 请求A: Block 0 → Block 2 → Block 5     │               │
│   │ 请求B: Block 1 → Block 3               │               │
│   │ 请求C: Block 4                         │               │
│   └────────────────────────────────────────┘               │
│                                                             │
│   优势:                                                     │
│   1. 消除外部碎片                                           │
│   2. 支持共享(Copy-on-Write)                               │
│   3. 动态分配,按需使用                                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

PagedAttention 实现

# paged_attention_demo.py
class PagedAttentionKVCache:
    """PagedAttention KV Cache 实现"""
    
    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype = torch.float16
    ):
        self.block_size = block_size
        self.num_blocks = num_blocks
        
        # 预分配所有块
        self.k_cache = torch.zeros(
            (num_blocks, block_size, num_heads, head_size),
            dtype=dtype, device="cuda"
        )
        self.v_cache = torch.zeros(
            (num_blocks, block_size, num_heads, head_size),
            dtype=dtype, device="cuda"
        )
        
        # 块分配器
        self.block_allocator = BlockAllocator(num_blocks)
        
        # 序列到块的映射
        self.seq_to_blocks: Dict[int, List[int]] = {}
    
    def allocate(self, seq_id: int, num_tokens: int) -> List[int]:
        """为序列分配块"""
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
        
        blocks = self.block_allocator.allocate(num_blocks_needed)
        self.seq_to_blocks[seq_id] = blocks
        
        return blocks
    
    def get_kv_cache(self, seq_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """获取序列的 KV Cache"""
        blocks = self.seq_to_blocks.get(seq_id, [])
        
        if not blocks:
            return None, None
        
        # 收集所有块
        k_blocks = [self.k_cache[block_id] for block_id in blocks]
        v_blocks = [self.v_cache[block_id] for block_id in blocks]
        
        # 拼接
        k_cache = torch.cat(k_blocks, dim=0)
        v_cache = torch.cat(v_blocks, dim=0)
        
        return k_cache, v_cache
    
    def free(self, seq_id: int):
        """释放序列的块"""
        blocks = self.seq_to_blocks.pop(seq_id, [])
        self.block_allocator.free(blocks)


class BlockAllocator:
    """块分配器"""
    
    def __init__(self, num_blocks: int):
        self.num_blocks = num_blocks
        self.free_blocks = list(range(num_blocks))
        self.used_blocks = set()
    
    def allocate(self, num_blocks: int) -> List[int]:
        """分配块"""
        if len(self.free_blocks) < num_blocks:
            raise MemoryError("没有足够的空闲块")
        
        allocated = self.free_blocks[:num_blocks]
        self.free_blocks = self.free_blocks[num_blocks:]
        self.used_blocks.update(allocated)
        
        return allocated
    
    def free(self, blocks: List[int]):
        """释放块"""
        for block in blocks:
            if block in self.used_blocks:
                self.used_blocks.remove(block)
                self.free_blocks.append(block)

KV Cache 压缩技术

1. 滑动窗口压缩

# sliding_window_compression.py
class SlidingWindowKVCache:
    """滑动窗口 KV Cache"""
    
    def __init__(self, window_size: int = 4096):
        self.window_size = window_size
        self.k_cache = []
        self.v_cache = []
    
    def update(self, new_k: torch.Tensor, new_v: torch.Tensor):
        """更新 Cache,保持窗口大小"""
        self.k_cache.append(new_k)
        self.v_cache.append(new_v)
        
        # 只保留窗口内的 tokens
        if len(self.k_cache) > self.window_size:
            self.k_cache = self.k_cache[-self.window_size:]
            self.v_cache = self.v_cache[-self.window_size:]
    
    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """获取当前 Cache"""
        if not self.k_cache:
            return None, None
        
        k = torch.cat(self.k_cache, dim=-2)
        v = torch.cat(self.v_cache, dim=-2)
        
        return k, v

2. H2O(Heavy Hitter Oracle)

保留最重要的 tokens(Heavy Hitters),丢弃不重要的。

# h2o_compression.py
class H2OKVCache:
    """H2O KV Cache 压缩"""
    
    def __init__(
        self,
        heavy_budget: int = 256,    # 保留的 heavy hitters 数量
        recent_budget: int = 128,   # 保留的最近 tokens 数量
    ):
        self.heavy_budget = heavy_budget
        self.recent_budget = recent_budget
        
        self.k_cache = None
        self.v_cache = None
        self.attention_scores = []
    
    def update(
        self,
        new_k: torch.Tensor,
        new_v: torch.Tensor,
        attention_scores: torch.Tensor
    ):
        """更新 Cache 并压缩"""
        
        if self.k_cache is None:
            self.k_cache = new_k
            self.v_cache = new_v
            self.attention_scores = [attention_scores.mean(dim=1)]
            return
        
        # 追加新 tokens
        self.k_cache = torch.cat([self.k_cache, new_k], dim=-2)
        self.v_cache = torch.cat([self.v_cache, new_v], dim=-2)
        self.attention_scores.append(attention_scores.mean(dim=1))
        
        # 压缩
        seq_len = self.k_cache.size(-2)
        if seq_len > self.heavy_budget + self.recent_budget:
            self._compress()
    
    def _compress(self):
        """执行压缩"""
        seq_len = self.k_cache.size(-2)
        
        # 计算累计注意力分数
        cumulative_scores = torch.stack(self.attention_scores, dim=-1).sum(dim=-1)
        
        # 保留最近的 tokens
        recent_k = self.k_cache[..., -self.recent_budget:, :]
        recent_v = self.v_cache[..., -self.recent_budget:, :]
        recent_scores = cumulative_scores[..., -self.recent_budget:]
        
        # 在非最近 tokens 中选择 heavy hitters
        if seq_len > self.recent_budget:
            old_k = self.k_cache[..., :-self.recent_budget, :]
            old_v = self.v_cache[..., :-self.recent_budget, :]
            old_scores = cumulative_scores[..., :-self.recent_budget]
            
            # 选择 top-k
            _, top_indices = torch.topk(
                old_scores,
                min(self.heavy_budget, old_scores.size(-1)),
                dim=-1
            )
            
            # 收集 heavy hitters
            heavy_k = torch.gather(old_k, -2, top_indices.unsqueeze(-1).expand(-1, -1, -1, old_k.size(-1)))
            heavy_v = torch.gather(old_v, -2, top_indices.unsqueeze(-1).expand(-1, -1, -1, old_v.size(-1)))
            
            # 合并
            self.k_cache = torch.cat([heavy_k, recent_k], dim=-2)
            self.v_cache = torch.cat([heavy_v, recent_v], dim=-2)

3. 量化压缩

# kv_cache_quantization.py
class KVCacheQuantizer:
    """KV Cache 量化"""
    
    def __init__(self, bits: int = 8):
        self.bits = bits
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1
    
    def quantize(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """量化张量"""
        # 计算缩放因子和零点
        min_val = tensor.min(dim=-1, keepdim=True)[0]
        max_val = tensor.max(dim=-1, keepdim=True)[0]
        
        scale = (max_val - min_val) / (self.qmax - self.qmin)
        zero_point = self.qmin - min_val / scale
        
        # 量化
        quantized = torch.clamp(
            torch.round(tensor / scale + zero_point),
            self.qmin, self.qmax
        ).to(torch.int8)
        
        return quantized, scale, zero_point
    
    def dequantize(
        self,
        quantized: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor
    ) -> torch.Tensor:
        """反量化"""
        return (quantized.float() - zero_point) * scale


class QuantizedKVCache:
    """量化 KV Cache"""
    
    def __init__(self, bits: int = 8):
        self.quantizer = KVCacheQuantizer(bits)
        self.quantized_k = None
        self.quantized_v = None
        self.k_scale = None
        self.k_zero_point = None
        self.v_scale = None
        self.v_zero_point = None
    
    def update(self, k: torch.Tensor, v: torch.Tensor):
        """更新并量化"""
        self.quantized_k, self.k_scale, self.k_zero_point = self.quantizer.quantize(k)
        self.quantized_v, self.v_scale, self.v_zero_point = self.quantizer.quantize(v)
    
    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """获取反量化的 KV"""
        k = self.quantizer.dequantize(
            self.quantized_k, self.k_scale, self.k_zero_point
        )
        v = self.quantizer.dequantize(
            self.quantized_v, self.v_scale, self.v_zero_point
        )
        return k, v

Continuous Batching

Continuous Batching(连续批处理)是 vLLM 的另一个核心优化,允许在批次处理过程中动态添加新请求。

# continuous_batching.py
class ContinuousBatcher:
    """连续批处理器"""
    
    def __init__(
        self,
        max_batch_size: int = 256,
        max_seq_len: int = 4096,
    ):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        
        self.waiting_queue = []  # 等待队列
        self.running_batch = []  # 运行中的批次
    
    def add_request(self, request: dict):
        """添加新请求"""
        self.waiting_queue.append(request)
    
    def schedule(self) -> List[dict]:
        """调度请求"""
        # 尝试将等待队列的请求加入运行批次
        while (
            self.waiting_queue and
            len(self.running_batch) < self.max_batch_size
        ):
            request = self.waiting_queue.pop(0)
            self.running_batch.append(request)
        
        return self.running_batch
    
    def step(self, model, tokenizer):
        """执行一步推理"""
        batch = self.schedule()
        
        if not batch:
            return []
        
        # 准备输入
        inputs = self._prepare_inputs(batch)
        
        # 推理
        with torch.no_grad():
            outputs = model(**inputs)
        
        # 更新每个请求的状态
        completed = []
        for i, request in enumerate(batch):
            request["tokens"].append(outputs.logits[i].argmax(dim=-1).item())
            
            # 检查是否完成
            if self._is_complete(request):
                completed.append(request)
                self.running_batch.remove(request)
        
        return completed
    
    def _prepare_inputs(self, batch: List[dict]) -> dict:
        """准备批次输入"""
        # 使用 PagedAttention 的块管理
        # 处理不同长度的序列
        pass
    
    def _is_complete(self, request: dict) -> bool:
        """检查请求是否完成"""
        # 检查是否生成了结束符或达到最大长度
        pass

vLLM 中的 KV Cache 优化

1. 配置优化

# vllm_kv_cache_config.py
from vllm import LLM, SamplingParams

# 优化 KV Cache 配置
llm = LLM(
    model="meta-llama/Llama-2-7b",
    
    # KV Cache 相关配置
    gpu_memory_utilization=0.9,  # GPU 内存利用率
    max_model_len=4096,          # 最大序列长度
    max_num_seqs=256,            # 最大并发序列数
    max_num_batched_tokens=4096, # 最大批处理 token 数
    
    # 块配置
    block_size=16,               # 块大小(默认16)
    
    # 交换空间(用于 CPU offload)
    swap_space=4,                # GB
    
    # 注意力后端
    attention_backend="FLASH_ATTN",  # 或 "XFORMERS"
)

# 采样参数
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=512,
)

2. 内存分析

# analyze_kv_cache_memory.py
def calculate_kv_cache_memory(
    model_config: dict,
    batch_size: int,
    seq_len: int,
    dtype_size: int = 2  # FP16
):
    """计算 KV Cache 内存占用"""
    
    num_layers = model_config["num_hidden_layers"]
    num_heads = model_config["num_attention_heads"]
    head_dim = model_config["hidden_size"] // num_heads
    
    # 每个 token 的 KV Cache 大小
    bytes_per_token = 2 * num_layers * num_heads * head_dim * dtype_size
    
    # 总内存
    total_memory = bytes_per_token * seq_len * batch_size
    
    print(f"模型层数: {num_layers}")
    print(f"注意力头数: {num_heads}")
    print(f"头维度: {head_dim}")
    print(f"每个token占用: {bytes_per_token / 1024:.2f} KB")
    print(f"序列长度: {seq_len}")
    print(f"批大小: {batch_size}")
    print(f"总内存: {total_memory / 1024**3:.2f} GB")
    
    return total_memory

# 示例
model_config = {
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
    "hidden_size": 4096,
}

calculate_kv_cache_memory(model_config, batch_size=16, seq_len=4096)

性能对比

优化技术内存节省吞吐量提升延迟降低实现复杂度
PagedAttention20-40%2-4x30%
Continuous Batching-5-10x-
KV Cache 量化 (INT8)50%-5%
H2O 压缩50-80%-10%
滑动窗口70-90%-20%

最佳实践

1. 选择合适的块大小

# block_size_selection.py
def select_block_size(seq_len_distribution: List[int]) -> int:
    """根据序列长度分布选择最优块大小"""
    
    # 常见块大小
    candidates = [8, 16, 32]
    
    best_size = 16
    min_waste = float('inf')
    
    for block_size in candidates:
        waste = sum(
            (block_size - (seq_len % block_size)) % block_size
            for seq_len in seq_len_distribution
        )
        
        if waste < min_waste:
            min_waste = waste
            best_size = block_size
    
    return best_size

2. 动态内存管理

# dynamic_memory_management.py
class DynamicMemoryManager:
    """动态内存管理器"""
    
    def __init__(self, total_gpu_memory: int):
        self.total_memory = total_gpu_memory
        self.used_memory = 0
        self.kv_cache_memory = 0
    
    def can_allocate(self, num_tokens: int, bytes_per_token: int) -> bool:
        """检查是否可以分配"""
        required = num_tokens * bytes_per_token
        return (self.used_memory + required) < self.total_memory * 0.9
    
    def allocate(self, num_tokens: int, bytes_per_token: int):
        """分配内存"""
        self.kv_cache_memory += num_tokens * bytes_per_token
        self.used_memory += num_tokens * bytes_per_token
    
    def get_memory_stats(self) -> dict:
        """获取内存统计"""
        return {
            "total": self.total_memory / 1024**3,
            "used": self.used_memory / 1024**3,
            "kv_cache": self.kv_cache_memory / 1024**3,
            "utilization": self.used_memory / self.total_memory,
        }

3. 监控和调优

# kv_cache_monitoring.py
class KVCacheMonitor:
    """KV Cache 监控器"""
    
    def __init__(self):
        self.metrics = {
            "cache_hit_rate": [],
            "memory_utilization": [],
            "avg_seq_len": [],
            "batch_size": [],
        }
    
    def record_step(
        self,
        cache_hit_rate: float,
        memory_util: float,
        avg_seq_len: float,
        batch_size: int
    ):
        """记录一步的指标"""
        self.metrics["cache_hit_rate"].append(cache_hit_rate)
        self.metrics["memory_utilization"].append(memory_util)
        self.metrics["avg_seq_len"].append(avg_seq_len)
        self.metrics["batch_size"].append(batch_size)
    
    def get_report(self) -> dict:
        """生成报告"""
        import numpy as np
        
        return {
            "avg_cache_hit_rate": np.mean(self.metrics["cache_hit_rate"]),
            "avg_memory_util": np.mean(self.metrics["memory_utilization"]),
            "avg_seq_len": np.mean(self.metrics["avg_seq_len"]),
            "avg_batch_size": np.mean(self.metrics["batch_size"]),
        }

总结

KV Cache 优化是提升 LLM 推理性能的关键:

  1. PagedAttention:解决内存碎片化,支持内存共享
  2. Continuous Batching:提升吞吐量,支持动态批处理
  3. KV Cache 压缩:降低内存占用,支持更长序列
  4. 量化:减少内存和带宽压力

优化建议:

  • 使用 vLLM 的 PagedAttention 和 Continuous Batching
  • 根据序列长度分布选择合适的块大小
  • 内存紧张时考虑 KV Cache 量化或压缩
  • 监控内存使用情况,及时调整配置

相关资源: