LLM工程化 进阶 缓存策略 语义缓存 KV Cache 成本优化

LLM缓存策略设计:从语义缓存到分层缓存体系

AIEng Hub
阅读约 20 分钟

LLM缓存策略设计:从语义缓存到分层缓存体系

缓存是降低 LLM 应用成本和延迟最有效的手段之一。从简单的 API 结果缓存到 KV Cache 前缀缓存,不同的缓存策略适用于不同的场景。本文将构建一套完整的 LLM 缓存体系,帮助你在实践中最大化缓存收益。

一、为什么需要缓存?

1.1 缓存的收益矩阵

缓存类型延迟降低成本节省实现复杂度适用场景
精确结果缓存80-95%80-95%重复问题、FAQ
语义缓存60-80%60-80%相似问题聚类
Prompt 缓存50-70%0%系统 prompt 共享
KV Cache 前缀缓存30-50%0%长 prompt 复用
分层缓存70-90%60-80%综合场景

1.2 典型的 Token 浪费模式

无缓存时(浪费严重):

请求 1: "什么是RAG?" → API → "RAG是检索增强..."
请求 2: "请解释RAG是什么" → API → "RAG(检索增强生成)..."  ← 几乎一样的回答
请求 3: "RAG的定义" → API → "检索增强生成(RAG)..."  ← 再次浪费
请求 4: "RAG vs 微调" → API → "RAG和微调的主要区别..."  ← 相关的上下文

4 次请求,4 次 API 调用,3 次可以缓存避免

二、精确结果缓存

2.1 基础实现

最简单的缓存策略:以 prompt 原文为 key 缓存结果。

import hashlib
import redis
from typing import Optional

class ExactCache:
    """精确命中缓存"""
    
    def __init__(self, redis_url="redis://localhost:6379", ttl=3600):
        self.redis = redis.from_url(redis_url)
        self.ttl = ttl  # 缓存有效期(秒)
        self.hits = 0
        self.misses = 0
    
    def _make_key(self, messages, model, params) -> str:
        """生成缓存 key"""
        content = str(messages) + model + str(params)
        return f"llm_cache:exact:{hashlib.md5(content.encode()).hexdigest()}"
    
    def get(self, messages, model="gpt-4o-mini", **params) -> Optional[str]:
        key = self._make_key(messages, model, params)
        result = self.redis.get(key)
        
        if result:
            self.hits += 1
            return result.decode()
        
        self.misses += 1
        return None
    
    def set(self, messages, response, model="gpt-4o-mini", **params):
        key = self._make_key(messages, model, params)
        self.redis.setex(key, self.ttl, response)
    
    @property
    def hit_rate(self):
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0

2.2 适用场景

场景命中率说明
FAQ 系统60-80%用户反复询问相同问题
API 文档查询50-70%固定的知识点查询
代码片段生成40-60%常见代码模式
数据转换30-50%固定的格式转换
创意生成5-10%几乎无重复

三、语义缓存

3.1 核心设计

通过 embedding 相似度匹配,即使问题不完全相同也能命中。

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from openai import OpenAI

class SemanticCache:
    """
    语义缓存:基于 embedding 相似度的缓存
    
    流程:
    1. 用户查询 → 生成 embedding
    2. 在缓存中找到最相似的已缓存查询
    3. 相似度 > 阈值 → 返回缓存结果
    4. 否则 → API 调用 → 缓存新结果
    """
    
    def __init__(self, threshold=0.92, max_cache_size=10000):
        self.threshold = threshold
        self.max_cache_size = max_cache_size
        self.client = OpenAI()
        
        # 缓存存储
        self.cache = {}  # 嵌入向量 → 结果
        self.queries = []  # 查询文本
        self.embeddings = []  # 对应的嵌入向量
    
    def _get_embedding(self, text: str) -> np.ndarray:
        response = self.client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return np.array(response.data[0].embedding)
    
    def search_cache(self, query: str) -> Optional[str]:
        """在缓存中搜索语义相似的结果"""
        query_emb = self._get_embedding(query)
        
        if len(self.embeddings) == 0:
            return None
        
        # 计算所有缓存项的相似度
        similarities = cosine_similarity(
            [query_emb], 
            self.embeddings
        )[0]
        
        best_idx = np.argmax(similarities)
        best_score = similarities[best_idx]
        
        if best_score >= self.threshold:
            cache_key = self.queries[best_idx]
            return self.cache[cache_key]
        
        return None
    
    def add_to_cache(self, query: str, result: str):
        """添加新结果到缓存"""
        embedding = self._get_embedding(query)
        
        if len(self.queries) >= self.max_cache_size:
            # LRU 淘汰
            oldest = self.queries.pop(0)
            self.embeddings.pop(0)
            del self.cache[oldest]
        
        self.queries.append(query)
        self.embeddings.append(embedding)
        self.cache[query] = result

3.2 相似度阈值调优

def find_optimal_threshold(test_queries, sem_cache, exact_cache):
    """
    通过实验找到最优相似度阈值
    """
    thresholds = [0.80, 0.85, 0.88, 0.90, 0.92, 0.95, 0.98]
    
    for t in thresholds:
        sem_cache.threshold = t
        sem_hits = 0
        false_positives = 0
        total = len(test_queries)
        
        for query in test_queries:
            sem_result = sem_cache.search_cache(query)
            exact_result = exact_cache.get(query)
            
            if sem_result:
                sem_hits += 1
                # 如果语义缓存命中但精确缓存未命中
                # 需要人工评估是否有误报
                if not exact_result:
                    false_positives += 1  # 需要人工标注
        
        print(f"阈值 {t:.2f}: 命中率 {sem_hits/total:.1%}, "
              f"潜在误报 {false_positives}")
阈值命中率质量保障推荐场景
0.8545%宽松非关键场景
0.9030%中等通用场景
0.9518%严格关键业务
0.988%最严格金融/医疗

四、KV Cache 前缀缓存

4.1 原理

vLLM 等推理引擎支持 prefix caching:当多个请求共享相同的 prompt 前缀时,KV Cache 可以复用。

请求 A: "你是一个AI助手。请回答:什么是...?"
         └─────── 共享前缀 ────────┘
请求 B: "你是一个AI助手。请回答:如何...?"
         └─────── 共享前缀 ────────┘

KV Cache 共享:
┌─────────────────────────────────────────┐
│  ████████████████████░░░░░░░░░░░░░░░░░░░│
│  ↑ 共享前缀 (80 tokens)   ↑ 独有部分     │
│  无需重新计算              只需计算后面   │
└─────────────────────────────────────────┘

加速效果:TTFT 降低 30-50%

4.2 配置与使用

# vLLM 启用 prefix caching
vllm serve meta-llama/Llama-2-7b-hf \
    --enable-prefix-caching \
    --max-model-len 8192 \
    --gpu-memory-utilization 0.90

# SGLang 自动支持 RadixAttention(更高效的前缀缓存)
python -m sglang.launch_server \
    --model meta-llama/Llama-2-7b-hf \
    --context-length 8192

4.3 缓存命中分析

场景共享前缀缓存收益
相同 system prompt100%TTFT -40%
相同 few-shot 示例70-90%TTFT -35%
多轮对话相同历史60-80%TTFT -30%
结构化输出格式20-30%TTFT -15%

五、分层缓存架构

5.1 设计

class LayeredCache:
    """
    分层缓存架构
    
    L1: 精确匹配 (毫秒级)
    L2: 语义匹配 (10ms级, 需要 embedding)
    L3: KV Cache 前缀匹配 (服务端)
    L4: API 调用 (原始延迟)
    """
    
    def __init__(self):
        self.l1_exact = ExactCache(ttl=3600)     # 1小时
        self.l2_semantic = SemanticCache(threshold=0.92)  # 较长 TTL
        self.metrics = {
            'l1_hits': 0, 'l1_misses': 0,
            'l2_hits': 0, 'l2_misses': 0,
        }
    
    def get_response(self, messages, model="gpt-4o-mini", **params):
        """分层查询"""
        
        # L1: 精确匹配(最快)
        cached = self.l1_exact.get(messages, model, **params)
        if cached:
            self.metrics['l1_hits'] += 1
            return cached
        self.metrics['l1_misses'] += 1
        
        # L2: 语义匹配(中等速度)
        user_query = messages[-1]['content'] if messages else ""
        cached = self.l2_semantic.search_cache(user_query)
        if cached:
            self.metrics['l2_hits'] += 1
            return cached
        self.metrics['l2_misses'] += 1
        
        # L3+4: API 调用(最慢)
        # 如果服务端支持 KV Cache 前缀缓存,自动生效
        
        # 实际调用 API
        response = self._call_api(messages, model, **params)
        
        # 回填缓存
        self.l1_exact.set(messages, response, model, **params)
        self.l2_semantic.add_to_cache(user_query, response)
        
        return response

5.2 缓存分层效果

请求到达

    ├── L1 精确缓存 ──── 命中? ──→  5ms 响应 ✓
    │      │
    │     未命中
    │      ▼
    ├── L2 语义缓存 ──── 命中? ──→ 50ms 响应 ✓
    │      │
    │     未命中
    │      ▼
    ├── L3 KV Cache 前缀 ── 命中? ──→ 500ms TTFT
    │      │
    │     未命中
    │      ▼
    └── L4 完整 API 调用 ────→ 2000ms 完整响应

分层命中分布(典型场景):
- L1: 30%
- L2: 20%
- L3+4: 50%
- 综合加速: 节省约 60-70% 的完整 API 调用

六、缓存失效与更新

6.1 失效策略

策略适用场景实现复杂度数据新鲜度
TTL 过期通用场景中等
主动失效数据变更时
滑动窗口热点数据
版本标记多版本共存
class CacheInvalidation:
    """
    智能缓存失效策略
    """
    
    def __init__(self, default_ttl=3600):
        self.default_ttl = default_ttl
        
    def get_ttl(self, query_type, user_intent, data_freshness_req):
        """根据场景动态设置 TTL"""
        ttl_map = {
            'knowledge_faq': 86400 * 7,          # 7天
            'code_snippet': 86400 * 30,           # 30天
            'product_pricing': 3600,               # 1小时
            'real_time_data': 60,                  # 1分钟
            'creative_content': 0,                 # 不缓存
        }
        
        base_ttl = ttl_map.get(query_type, self.default_ttl)
        
        # 根据新鲜度要求调整
        if data_freshness_req == 'strict':
            base_ttl = min(base_ttl, 300)  # 最多5分钟
        elif data_freshness_req == 'loose':
            base_ttl = base_ttl * 3  # 3倍
        
        return base_ttl

6.2 缓存预热

def warm_up_cache(cache, frequent_queries, responses):
    """
    系统启动时预热高频查询到缓存
    """
    for query, response in zip(frequent_queries, responses):
        cache.l1_exact.set(query, response, "gpt-4o-mini")
        cache.l2_semantic.add_to_cache(query, response)

七、监控与调优

7.1 关键指标

指标计算方式健康值预警值
综合命中率总命中/总请求> 50%< 20%
L1 命中率L1命中/总请求> 30%< 10%
缓存占用缓存条目数< 80% 容量> 95%
平均响应加速(非缓存-缓存)/原始> 60%< 20%

7.2 常见问题

问题表现原因解决方案
低命中率缓存基本无效果查询多样性过高降低语义阈值;增加缓存容量
高误报率返回不相关内容语义阈值过低提高阈值;添加上下文过滤
缓存放大占用过多内存缓存过大LRU 淘汰;缩短 TTL
数据陈旧返回过期信息TTL过长缩短 TTL;主动失效

总结

高效缓存策略可以将 LLM 应用的成本降低 60-80%,同时显著改善用户体验。关键在于:

  1. 分层设计:精确缓存做快速命中,语义缓存扩大覆盖,KV Cache 前缀缓存加速推理
  2. 动态配置:根据业务场景设置不同 TTL,缓存失效策略与数据新鲜度匹配
  3. 持续监控:命中率、误报率、缓存放大效应需要持续优化
  4. 冷启动策略:通过预热和渐进式缓存上线,避免冷缓存期间的全量 API 调用