什么是 KV Cache?
KV Cache(Key-Value Cache)是 Transformer 模型推理中的关键优化技术,用于缓存之前计算的 Key 和 Value 矩阵,避免重复计算。
┌─────────────────────────────────────────────────────────────┐
│ KV Cache 原理 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 没有 KV Cache: │
│ Token 1 → [计算] → Token 2 → [重新计算1+2] → Token 3 │
│ 时间复杂度: O(n³)(生成 n 个 token 的总注意力计算量:第 t 步需重新计算 O(t²)) │
│ │
│ 使用 KV Cache: │
│ Token 1 → [计算] → 缓存(K1,V1) │
│ Token 2 → [使用缓存+计算2] → 缓存(K1,V1,K2,V2) │
│ Token 3 → [使用缓存+计算3] │
│ 时间复杂度: O(n²)(生成 n 个 token 的总注意力计算量:第 t 步只需 O(t)) │
│ │
│ 内存占用: │
│ - 每层元素数: 2 (K+V) × num_heads × head_dim × seq_len × batch_size │
│ - FP16: 每个token每层约 4 × hidden_size bytes (K、V 各 2 × hidden_size) │
│ - 7B模型 (32层, hidden_size=4096): 每个token约 0.5 MB │
│ │
└─────────────────────────────────────────────────────────────┘
KV Cache 的内存计算公式:
内存(字节) = 2 × num_layers × num_heads × head_dim × seq_len × batch_size × sizeof(dtype)(开头的 2 表示 K 和 V 两个张量;结果除以 1e9 换算为 GB)
例如:
- 模型: Llama-2-7B
- 层数: 32
- 头数: 32
- 头维度: 128
- 序列长度: 4096
- 批大小: 1
- 精度: FP16 (2 bytes)
内存 = 2 × 32 × 32 × 128 × 4096 × 1 × 2 / 1e9 = 2.1 GB
KV Cache 的挑战
1. 内存碎片化
传统分配方式:
┌─────────────────────────────────────────────────┐
│ 请求A(512) │ 请求B(1024) │ 请求C(2048) │ 空闲 │
└─────────────────────────────────────────────────┘
↓
请求B完成后:
┌─────────────────────────────────────────────────┐
│ 请求A(512) │ 空闲(1024) │ 请求C(2048) │ 空闲 │
└─────────────────────────────────────────────────┘
↑ 无法被大请求利用
2. 内存浪费
- 预分配最大长度,实际使用较少
- 批处理时不同请求长度不一致
- 动态增长的序列需要频繁重分配
PagedAttention:解决内存碎片化
vLLM 提出的 PagedAttention 将 KV Cache 分成固定大小的块(block),类似操作系统的虚拟内存管理。
┌─────────────────────────────────────────────────────────────┐
│ PagedAttention 原理 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 物理块 (固定大小,如16 tokens): │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │Block│ │Block│ │Block│ │Block│ │Block│ │Block│ │
│ │ 0 │ │ 1 │ │ 2 │ │ 3 │ │ 4 │ │ 5 │ │
│ └─────┘ └─────┘ └─────┘ └─────┘ └─────┘ └─────┘ │
│ │
│ 逻辑到物理映射: │
│ ┌────────────────────────────────────────┐ │
│ │ 请求A: Block 0 → Block 2 → Block 5 │ │
│ │ 请求B: Block 1 → Block 3 │ │
│ │ 请求C: Block 4 │ │
│ └────────────────────────────────────────┘ │
│ │
│ 优势: │
│ 1. 消除外部碎片 │
│ 2. 支持共享(Copy-on-Write) │
│ 3. 动态分配,按需使用 │
│ │
└─────────────────────────────────────────────────────────────┘
PagedAttention 实现
# paged_attention_demo.py
class PagedAttentionKVCache:
    """Block-based ("paged") KV cache backed by two pre-allocated tensors.

    Keys and values live in ``k_cache`` / ``v_cache`` of shape
    ``(num_blocks, block_size, num_heads, head_size)``. Each sequence owns a
    list of block ids handed out by a ``BlockAllocator``; the blocks of one
    sequence do not have to be contiguous.
    """

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype = torch.float16,
        device: str = "cuda",
    ):
        """Pre-allocate the block pool.

        Args:
            num_blocks: Number of physical blocks in the pool.
            block_size: Number of token slots per block; must be positive.
            num_heads: Number of attention heads stored per token.
            head_size: Size of each head vector.
            dtype: Element type of the cache tensors.
            device: Device the cache tensors are created on. Defaults to
                ``"cuda"`` (the previously hard-coded value).

        Raises:
            ValueError: If ``block_size`` is not positive.
        """
        if block_size <= 0:
            # allocate() divides by block_size, so reject it up front.
            raise ValueError(f"block_size must be positive, got {block_size}")
        self.block_size = block_size
        self.num_blocks = num_blocks
        # Pre-allocate every block once.
        cache_shape = (num_blocks, block_size, num_heads, head_size)
        self.k_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.v_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
        # Hands out / reclaims physical block ids.
        self.block_allocator = BlockAllocator(num_blocks)
        # Sequence id -> ordered list of physical block ids.
        self.seq_to_blocks: Dict[int, List[int]] = {}

    def allocate(self, seq_id: int, num_tokens: int) -> List[int]:
        """Reserve enough blocks to hold ``num_tokens`` tokens for ``seq_id``.

        If the sequence already owns blocks, they are replaced by the new
        allocation and returned to the pool (previously they were overwritten
        in the mapping and never freed).

        Returns:
            The block ids now assigned to the sequence.

        Raises:
            MemoryError: Propagated from the allocator when the pool is too
                small; the existing mapping is left untouched in that case.
        """
        # Ceiling division: a partially filled block still needs a whole block.
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
        # Allocate first so a failure cannot drop the sequence's current blocks.
        new_blocks = self.block_allocator.allocate(num_blocks_needed)
        previous_blocks = self.seq_to_blocks.get(seq_id)
        self.seq_to_blocks[seq_id] = new_blocks
        if previous_blocks:
            self.block_allocator.free(previous_blocks)
        return new_blocks

    def get_kv_cache(self, seq_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the sequence's keys and values as two contiguous tensors.

        The blocks are concatenated along the first axis, so each result has
        shape ``(len(blocks) * block_size, num_heads, head_size)`` and includes
        any unused slots of the last block. Returns ``(None, None)`` when the
        sequence is unknown or owns no blocks.
        """
        blocks = self.seq_to_blocks.get(seq_id, [])
        if not blocks:
            return None, None
        # Gather the sequence's blocks in logical order and stitch them together.
        k_cache = torch.cat([self.k_cache[block_id] for block_id in blocks], dim=0)
        v_cache = torch.cat([self.v_cache[block_id] for block_id in blocks], dim=0)
        return k_cache, v_cache

    def free(self, seq_id: int):
        """Drop the sequence's mapping and return its blocks to the pool.

        Unknown sequence ids are ignored.
        """
        blocks = self.seq_to_blocks.pop(seq_id, [])
        self.block_allocator.free(blocks)
class BlockAllocator:
    """Free-list allocator for a fixed pool of integer block ids."""

    def __init__(self, num_blocks: int):
        """Create a pool containing the ids ``0 .. num_blocks - 1``, all free."""
        self.num_blocks = num_blocks
        self.free_blocks = list(range(num_blocks))
        self.used_blocks = set()

    def allocate(self, num_blocks: int) -> List[int]:
        """Take ``num_blocks`` ids from the front of the free list.

        Raises:
            ValueError: If ``num_blocks`` is negative. Without this check a
                negative count would slice ``free_blocks[:-k]`` and hand out
                almost the whole pool.
            MemoryError: If fewer than ``num_blocks`` ids are free.
        """
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be non-negative, got {num_blocks}")
        if len(self.free_blocks) < num_blocks:
            raise MemoryError("没有足够的空闲块")
        allocated = self.free_blocks[:num_blocks]
        self.free_blocks = self.free_blocks[num_blocks:]
        self.used_blocks.update(allocated)
        return allocated

    def free(self, blocks: List[int]):
        """Return ``blocks`` to the free list.

        Ids that are not currently allocated (unknown ids, double frees,
        duplicates within ``blocks``) are skipped so the free list never
        contains the same id twice.
        """
        for block in blocks:
            if block in self.used_blocks:
                self.used_blocks.remove(block)
                self.free_blocks.append(block)
KV Cache 压缩技术
1. 滑动窗口压缩
# sliding_window_compression.py
class SlidingWindowKVCache:
    """KV cache that keeps only the most recent ``window_size`` tokens.

    Chunks passed to ``update`` are stored as-is in ``k_cache`` / ``v_cache``
    and concatenated along axis ``-2`` (the token axis) by ``get``.
    """

    def __init__(self, window_size: int = 4096):
        """Create an empty cache holding at most ``window_size`` tokens."""
        self.window_size = window_size
        self.k_cache = []
        self.v_cache = []
        # Tokens currently stored; tracked so eviction does not rescan chunks.
        self._num_tokens = 0

    def update(self, new_k: torch.Tensor, new_v: torch.Tensor):
        """Append a chunk of keys/values and evict tokens outside the window.

        The window is measured in tokens (size of axis ``-2``), not in number
        of ``update`` calls, so multi-token chunks such as a prefill are
        trimmed correctly. ``new_k`` and ``new_v`` are expected to have the
        same length along axis ``-2``.
        """
        self.k_cache.append(new_k)
        self.v_cache.append(new_v)
        self._num_tokens += new_k.size(-2)
        self._evict_overflow()

    def _evict_overflow(self):
        """Drop the oldest tokens until at most ``window_size`` remain."""
        overflow = self._num_tokens - self.window_size
        while overflow > 0 and self.k_cache:
            oldest_len = self.k_cache[0].size(-2)
            if oldest_len <= overflow:
                # The whole oldest chunk is outside the window.
                del self.k_cache[0]
                del self.v_cache[0]
                dropped = oldest_len
            else:
                # Only the leading part of the oldest chunk is outside.
                self.k_cache[0] = self.k_cache[0][..., overflow:, :]
                self.v_cache[0] = self.v_cache[0][..., overflow:, :]
                dropped = overflow
            overflow -= dropped
            self._num_tokens -= dropped

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the cached keys and values concatenated along axis ``-2``.

        Returns ``(None, None)`` when the cache is empty.
        """
        if not self.k_cache:
            return None, None
        k = torch.cat(self.k_cache, dim=-2)
        v = torch.cat(self.v_cache, dim=-2)
        return k, v
2. H2O(Heavy Hitter Oracle)
保留最重要的 tokens(Heavy Hitters),丢弃不重要的。
# h2o_compression.py
class H2OKVCache:
    """KV cache that keeps high-scoring ("heavy hitter") tokens plus recent tokens.

    Keys and values are concatenated along axis -2; once that axis grows past
    ``heavy_budget + recent_budget`` the cache is compressed.
    """

    def __init__(
        self,
        heavy_budget: int = 256,  # number of heavy-hitter tokens to keep
        recent_budget: int = 128,  # number of most recent tokens to keep
    ):
        self.heavy_budget = heavy_budget
        self.recent_budget = recent_budget
        self.k_cache = None
        self.v_cache = None
        # One reduced score tensor is appended per update() call.
        self.attention_scores = []

    def update(
        self,
        new_k: torch.Tensor,
        new_v: torch.Tensor,
        attention_scores: torch.Tensor
    ):
        """Append new keys/values, record their attention scores, then compress if over budget.

        NOTE(review): the scores are averaged over dimension 1; whether that
        axis is heads or queries depends on the caller's layout — TODO confirm.
        """
        if self.k_cache is None:
            # First call: adopt the tensors as the cache. Returns before the
            # budget check, so an oversized first chunk is not compressed here.
            self.k_cache = new_k
            self.v_cache = new_v
            self.attention_scores = [attention_scores.mean(dim=1)]
            return
        # Append the new tokens along the token axis (-2).
        self.k_cache = torch.cat([self.k_cache, new_k], dim=-2)
        self.v_cache = torch.cat([self.v_cache, new_v], dim=-2)
        self.attention_scores.append(attention_scores.mean(dim=1))
        # Compress once the cache exceeds the combined budget.
        seq_len = self.k_cache.size(-2)
        if seq_len > self.heavy_budget + self.recent_budget:
            self._compress()

    def _compress(self):
        """Shrink the cache to the top-scoring old tokens followed by the most recent tokens.

        NOTE(review): self.attention_scores is not trimmed or reordered here,
        so after a compression the recorded scores may no longer line up with
        cache positions — confirm intended behaviour.
        """
        seq_len = self.k_cache.size(-2)
        # Sum the recorded per-step scores into one cumulative score tensor.
        # NOTE(review): torch.stack requires every recorded tensor to have the
        # same shape; if each step's scores span a growing key length this
        # raises — confirm what callers pass.
        cumulative_scores = torch.stack(self.attention_scores, dim=-1).sum(dim=-1)
        # Keep the most recent tokens.
        # NOTE(review): with recent_budget == 0 the "-0:" slices select
        # everything and the ":-0" slices below select nothing — confirm that
        # a zero budget is never used.
        recent_k = self.k_cache[..., -self.recent_budget:, :]
        recent_v = self.v_cache[..., -self.recent_budget:, :]
        recent_scores = cumulative_scores[..., -self.recent_budget:]  # computed but not used below
        # Choose heavy hitters among the tokens that are not recent.
        if seq_len > self.recent_budget:
            old_k = self.k_cache[..., :-self.recent_budget, :]
            old_v = self.v_cache[..., :-self.recent_budget, :]
            old_scores = cumulative_scores[..., :-self.recent_budget]
            # Select the top-k scores; indices come back in score order, not
            # positional order.
            _, top_indices = torch.topk(
                old_scores,
                min(self.heavy_budget, old_scores.size(-1)),
                dim=-1
            )
            # Gather the heavy hitters. The 4-argument expand implies
            # top_indices is 3-D here.
            # NOTE(review): presumably (batch, heads, k) to match the cache's
            # leading axes — verify against the score layout.
            heavy_k = torch.gather(old_k, -2, top_indices.unsqueeze(-1).expand(-1, -1, -1, old_k.size(-1)))
            heavy_v = torch.gather(old_v, -2, top_indices.unsqueeze(-1).expand(-1, -1, -1, old_v.size(-1)))
            # Merge: heavy hitters first, then the recent window.
            self.k_cache = torch.cat([heavy_k, recent_k], dim=-2)
            self.v_cache = torch.cat([heavy_v, recent_v], dim=-2)
3. 量化压缩
# kv_cache_quantization.py
class KVCacheQuantizer:
    """Asymmetric min/max integer quantizer applied along the last axis."""

    def __init__(self, bits: int = 8):
        """Configure the signed integer range for ``bits``-bit quantization.

        Raises:
            ValueError: If ``bits`` is outside ``1..32``.
        """
        if not 1 <= bits <= 32:
            raise ValueError(f"bits must be between 1 and 32, got {bits}")
        self.bits = bits
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1
        # Smallest signed integer dtype that can hold [qmin, qmax]. Casting to
        # int8 unconditionally would overflow for bits > 8.
        if bits <= 8:
            self._storage_dtype = torch.int8
        elif bits <= 16:
            self._storage_dtype = torch.int16
        else:
            self._storage_dtype = torch.int32

    def quantize(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize ``tensor`` with one scale and zero point per last-axis row.

        Returns:
            ``(quantized, scale, zero_point)``; ``scale`` and ``zero_point``
            keep the reduced axis (size 1) so they broadcast against
            ``quantized``.
        """
        # Per-row value range along the last axis.
        min_val = tensor.min(dim=-1, keepdim=True)[0]
        max_val = tensor.max(dim=-1, keepdim=True)[0]
        scale = (max_val - min_val) / (self.qmax - self.qmin)
        # A constant row has zero range; a zero scale would produce NaN/inf
        # below. Any non-zero scale reproduces a constant row exactly, so use 1.
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)
        zero_point = self.qmin - min_val / scale
        # Map to the integer grid and clamp into the representable range.
        quantized = torch.clamp(
            torch.round(tensor / scale + zero_point),
            self.qmin, self.qmax
        ).to(self._storage_dtype)
        return quantized, scale, zero_point

    def dequantize(
        self,
        quantized: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor
    ) -> torch.Tensor:
        """Invert ``quantize``: map integer codes back to floating-point values."""
        return (quantized.float() - zero_point) * scale
class QuantizedKVCache:
    """Stores one quantized snapshot of K and V and dequantizes on read."""

    def __init__(self, bits: int = 8):
        """Create an empty cache using a ``bits``-bit ``KVCacheQuantizer``."""
        self.quantizer = KVCacheQuantizer(bits)
        self.quantized_k = None
        self.quantized_v = None
        self.k_scale = None
        self.k_zero_point = None
        self.v_scale = None
        self.v_zero_point = None

    def update(self, k: torch.Tensor, v: torch.Tensor):
        """Quantize ``k`` and ``v`` and replace the stored snapshot with them."""
        self.quantized_k, self.k_scale, self.k_zero_point = self.quantizer.quantize(k)
        self.quantized_v, self.v_scale, self.v_zero_point = self.quantizer.quantize(v)

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the dequantized ``(k, v)`` pair.

        Returns ``(None, None)`` if nothing has been stored yet, instead of
        failing inside ``dequantize`` on a ``None`` tensor.
        """
        if self.quantized_k is None or self.quantized_v is None:
            return None, None
        k = self.quantizer.dequantize(
            self.quantized_k, self.k_scale, self.k_zero_point
        )
        v = self.quantizer.dequantize(
            self.quantized_v, self.v_scale, self.v_zero_point
        )
        return k, v
Continuous Batching
Continuous Batching(连续批处理)是 vLLM 的另一个核心优化,允许在批次处理过程中动态添加新请求。
# continuous_batching.py
class ContinuousBatcher:
    """Scheduler that refills the running batch from a waiting queue on every step."""

    def __init__(
        self,
        max_batch_size: int = 256,
        max_seq_len: int = 4096,
    ):
        """Create an empty scheduler.

        Args:
            max_batch_size: Maximum number of requests running at once.
            max_seq_len: Stored on the instance; not used by the methods
                visible in this class.
        """
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.waiting_queue = []  # requests not yet admitted
        self.running_batch = []  # requests currently being decoded

    def add_request(self, request: dict):
        """Enqueue a request; it is admitted by a later ``schedule`` call."""
        self.waiting_queue.append(request)

    def schedule(self) -> List[dict]:
        """Admit waiting requests in FIFO order until the batch is full.

        Returns:
            The live ``running_batch`` list (not a copy).
        """
        free_slots = self.max_batch_size - len(self.running_batch)
        if free_slots > 0 and self.waiting_queue:
            # Move the whole admitted prefix at once instead of pop(0) per item.
            admitted = self.waiting_queue[:free_slots]
            del self.waiting_queue[:free_slots]
            self.running_batch.extend(admitted)
        return self.running_batch

    def step(self, model, tokenizer):
        """Run one decoding step for every running request.

        Appends the arg-max token to each request's ``"tokens"`` list and
        removes finished requests from the running batch.

        Returns:
            The requests that completed during this step.
        """
        # Snapshot the batch: removing from running_batch while iterating it
        # skipped the request after each completed one and shifted the logits
        # index for the rest.
        batch = list(self.schedule())
        if not batch:
            return []
        # Build model inputs.
        inputs = self._prepare_inputs(batch)
        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = model(**inputs)
        completed = []
        still_running = []
        for i, request in enumerate(batch):
            # NOTE(review): .item() needs a single element, so this assumes
            # logits[i] holds one position per request — confirm against
            # _prepare_inputs.
            request["tokens"].append(outputs.logits[i].argmax(dim=-1).item())
            if self._is_complete(request):
                completed.append(request)
            else:
                still_running.append(request)
        # Update in place so references returned by schedule() stay valid.
        self.running_batch[:] = still_running
        return completed

    def _prepare_inputs(self, batch: List[dict]) -> dict:
        """Build the keyword arguments passed to the model.

        Placeholder: not implemented yet and currently returns ``None``. The
        intended implementation uses PagedAttention block management to handle
        sequences of different lengths.
        """
        pass

    def _is_complete(self, request: dict) -> bool:
        """Tell whether a request has finished generating.

        Placeholder: not implemented yet and currently returns ``None``
        (falsy). The intended check is for an end-of-sequence token or the
        maximum length.
        """
        pass
vLLM 中的 KV Cache 优化
1. 配置优化
# vllm_kv_cache_config.py
from vllm import LLM, SamplingParams
# Tune the KV-cache-related engine settings.
llm = LLM(
    # NOTE(review): confirm the repository id — Hugging Face-format Llama 2
    # weights are usually published with an "-hf" suffix.
    model="meta-llama/Llama-2-7b",
    # KV cache related settings
    gpu_memory_utilization=0.9,  # GPU memory utilization
    max_model_len=4096,  # maximum sequence length
    max_num_seqs=256,  # maximum number of concurrent sequences
    max_num_batched_tokens=4096,  # maximum number of batched tokens
    # Block settings
    block_size=16,  # block size (default 16)
    # Swap space (used for CPU offload)
    swap_space=4,  # GB
    # Attention backend
    # NOTE(review): confirm that the installed vLLM version accepts
    # `attention_backend` as an engine argument; some releases select the
    # backend through the VLLM_ATTENTION_BACKEND environment variable instead.
    attention_backend="FLASH_ATTN",  # or "XFORMERS"
)
# Sampling parameters
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=512,
)
2. 内存分析
# analyze_kv_cache_memory.py
def calculate_kv_cache_memory(
    model_config: dict,
    batch_size: int,
    seq_len: int,
    dtype_size: int = 2  # FP16
) -> int:
    """Compute and print the KV cache footprint in bytes.

    Args:
        model_config: Model hyper-parameters. Required keys are
            ``"num_hidden_layers"`` and ``"num_attention_heads"``, plus
            ``"hidden_size"`` unless ``"head_dim"`` is given. The optional
            ``"num_key_value_heads"`` (grouped/multi-query attention) and
            ``"head_dim"`` keys are honoured when present; without them the
            result equals the plain multi-head formula.
        batch_size: Number of sequences held in the cache.
        seq_len: Number of tokens per sequence.
        dtype_size: Bytes per element (2 for FP16).

    Returns:
        Total KV cache size in bytes.
    """
    num_layers = model_config["num_hidden_layers"]
    num_heads = model_config["num_attention_heads"]
    # Only the key/value heads are cached; with GQA/MQA there are fewer of
    # them than query heads.
    num_kv_heads = model_config.get("num_key_value_heads") or num_heads
    head_dim = model_config.get("head_dim") or model_config["hidden_size"] // num_heads
    # Per-token cost: K and V (factor 2) for every layer and KV head.
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_size
    # Total across all tokens and sequences.
    total_memory = bytes_per_token * seq_len * batch_size
    print(f"模型层数: {num_layers}")
    print(f"注意力头数: {num_heads}")
    print(f"头维度: {head_dim}")
    print(f"每个token占用: {bytes_per_token / 1024:.2f} KB")
    print(f"序列长度: {seq_len}")
    print(f"批大小: {batch_size}")
    print(f"总内存: {total_memory / 1024**3:.2f} GB")
    return total_memory
# Example usage
model_config = dict(
    num_hidden_layers=32,
    num_attention_heads=32,
    hidden_size=4096,
)
calculate_kv_cache_memory(model_config, batch_size=16, seq_len=4096)
性能对比
| 优化技术 | 内存节省 | 吞吐量提升 | 延迟降低 | 实现复杂度 |
|---|---|---|---|---|
| PagedAttention | 20-40% | 2-4x | 30% | 中 |
| Continuous Batching | - | 5-10x | - | 高 |
| KV Cache 量化 (INT8) | 50% | - | 5% | 低 |
| H2O 压缩 | 50-80% | - | 10% | 中 |
| 滑动窗口 | 70-90% | - | 20% | 低 |
最佳实践
1. 选择合适的块大小
# block_size_selection.py
def select_block_size(seq_len_distribution: List[int]) -> int:
    """Pick a block size from 8, 16 and 32 for the given sequence lengths.

    The cost of a candidate is the total number of padded (unused) token
    slots in the last block of every sequence. Among candidates with the
    same minimal cost the largest one is chosen, since it needs fewer blocks
    for the same waste. With a strict comparison the smallest candidate
    always won, because its padding can never exceed that of a multiple of
    it.

    Returns:
        The selected block size; 16 when ``seq_len_distribution`` is empty.
    """
    default_size = 16
    if not seq_len_distribution:
        # No data to optimise for: fall back to the default block size.
        return default_size
    # Common block sizes, in ascending order so ties go to the larger one.
    candidates = [8, 16, 32]
    best_size = default_size
    min_waste = float('inf')
    for block_size in candidates:
        # -n % b is the padding needed to round n up to a multiple of b.
        waste = sum(-seq_len % block_size for seq_len in seq_len_distribution)
        if waste <= min_waste:
            min_waste = waste
            best_size = block_size
    return best_size
2. 动态内存管理
# dynamic_memory_management.py
class DynamicMemoryManager:
    """Byte-level bookkeeping for KV cache memory against a fixed GPU budget."""

    def __init__(self, total_gpu_memory: int, max_utilization: float = 0.9):
        """Create a manager with nothing allocated.

        Args:
            total_gpu_memory: Total memory in bytes.
            max_utilization: Fraction of ``total_gpu_memory`` that allocations
                must stay below. Defaults to 0.9, the previously hard-coded
                value.
        """
        self.total_memory = total_gpu_memory
        self.max_utilization = max_utilization
        self.used_memory = 0
        self.kv_cache_memory = 0

    def can_allocate(self, num_tokens: int, bytes_per_token: int) -> bool:
        """Return True if the request would keep usage strictly below the budget."""
        required = num_tokens * bytes_per_token
        return (self.used_memory + required) < self.total_memory * self.max_utilization

    def allocate(self, num_tokens: int, bytes_per_token: int):
        """Record an allocation of ``num_tokens * bytes_per_token`` bytes.

        The budget is not enforced here; call ``can_allocate`` first.
        """
        requested = num_tokens * bytes_per_token
        self.kv_cache_memory += requested
        self.used_memory += requested

    def free(self, num_tokens: int, bytes_per_token: int):
        """Record the release of ``num_tokens * bytes_per_token`` bytes.

        The amount is capped at the currently tracked KV cache memory so the
        counters never go negative.
        """
        released = min(num_tokens * bytes_per_token, self.kv_cache_memory)
        self.kv_cache_memory -= released
        self.used_memory -= released

    def get_memory_stats(self) -> dict:
        """Return total / used / KV cache memory in GiB and the used fraction.

        ``"utilization"`` is 0.0 when the total is zero, avoiding a division
        by zero.
        """
        gib = 1024**3
        utilization = self.used_memory / self.total_memory if self.total_memory else 0.0
        return {
            "total": self.total_memory / gib,
            "used": self.used_memory / gib,
            "kv_cache": self.kv_cache_memory / gib,
            "utilization": utilization,
        }
3. 监控和调优
# kv_cache_monitoring.py
class KVCacheMonitor:
    """Collects per-step KV cache metrics and reports their averages."""

    def __init__(self):
        """Start with an empty history for each tracked metric."""
        metric_names = (
            "cache_hit_rate",
            "memory_utilization",
            "avg_seq_len",
            "batch_size",
        )
        self.metrics = {name: [] for name in metric_names}

    def record_step(
        self,
        cache_hit_rate: float,
        memory_util: float,
        avg_seq_len: float,
        batch_size: int
    ):
        """Append one step's measurements to the metric histories."""
        samples = {
            "cache_hit_rate": cache_hit_rate,
            "memory_utilization": memory_util,
            "avg_seq_len": avg_seq_len,
            "batch_size": batch_size,
        }
        for name, value in samples.items():
            self.metrics[name].append(value)

    def get_report(self) -> dict:
        """Return the mean of every recorded metric, keyed by report name."""
        import numpy as np
        # Report key -> name of the metric history it averages.
        report_sources = {
            "avg_cache_hit_rate": "cache_hit_rate",
            "avg_memory_util": "memory_utilization",
            "avg_seq_len": "avg_seq_len",
            "avg_batch_size": "batch_size",
        }
        return {
            report_key: np.mean(self.metrics[metric_name])
            for report_key, metric_name in report_sources.items()
        }
总结
KV Cache 优化是提升 LLM 推理性能的关键:
- PagedAttention:解决内存碎片化,支持内存共享
- Continuous Batching:提升吞吐量,支持动态批处理
- KV Cache 压缩:降低内存占用,支持更长序列
- 量化:减少内存和带宽压力
优化建议:
- 使用 vLLM 的 PagedAttention 和 Continuous Batching
- 根据序列长度分布选择合适的块大小
- 内存紧张时考虑 KV Cache 量化或压缩
- 监控内存使用情况,及时调整配置
相关资源: