AI Agent 进阶 AI Agent 记忆管理 向量数据库 记忆压缩

AI Agent 记忆管理策略:从理论到实践

AIEng Hub
阅读约 35 分钟

引言

记忆是 AI Agent 的”灵魂”。一个具备良好记忆管理的 Agent 能够记住用户偏好、学习历史经验、保持对话连贯性,从而提供更加个性化和高效的服务。本文将深入探讨 AI Agent 记忆管理的核心策略,从短期记忆到长期记忆,从存储到检索,从压缩到优化,帮助你构建生产级的记忆系统。

┌─────────────────────────────────────────────────────────────────┐
│                    AI Agent 记忆管理全景图                       │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   ┌──────────────┐    ┌──────────────┐    ┌──────────────┐     │
│   │   短期记忆    │───→│   记忆处理    │───→│   长期记忆    │     │
│   │              │    │              │    │              │     │
│   │ • 对话历史   │    │ • 摘要提取   │    │ • 向量存储   │     │
│   │ • 上下文窗口 │    │ • 重要性评分 │    │ • 知识图谱   │     │
│   │ • 临时状态   │    │ • 压缩编码   │    │ • 数据库存储 │     │
│   └──────────────┘    └──────────────┘    └──────────────┘     │
│          ↑                                            │         │
│          └────────────────────────────────────────────┘         │
│                         检索与召回                              │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

一、短期记忆管理策略

1.1 对话历史的智能管理

短期记忆主要管理当前对话的上下文,核心挑战在于如何在有限的 Token 预算内保留最有价值的信息。

from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass, field
from datetime import datetime
import uuid
import tiktoken

@dataclass
class Message:
    """消息数据类"""
    role: str  # "system", "user", "assistant", "tool"
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)
    message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    importance_score: float = 0.5  # 重要性评分
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            "role": self.role,
            "content": self.content,
            "metadata": self.metadata,
            "timestamp": self.timestamp.isoformat(),
            "message_id": self.message_id,
            "importance_score": self.importance_score
        }

class SmartConversationBuffer:
    """智能对话缓冲区 - 支持重要性评分和动态裁剪"""
    
    def __init__(self, 
                 max_tokens: int = 4000,
                 model: str = "gpt-4",
                 importance_threshold: float = 0.3):
        self.max_tokens = max_tokens
        self.model = model
        self.importance_threshold = importance_threshold
        self.messages: List[Message] = []
        self.encoding = tiktoken.encoding_for_model(model)
        
        # 重要性评分函数
        self.importance_scorer: Optional[Callable] = None
    
    def count_tokens(self, text: str) -> int:
        """计算文本的 token 数"""
        return len(self.encoding.encode(text))
    
    def calculate_message_tokens(self, message: Message) -> int:
        """计算单条消息的 token 开销"""
        # 角色标记 + 内容 + 格式开销
        return self.count_tokens(message.content) + 4
    
    def set_importance_scorer(self, scorer: Callable[[Message], float]):
        """设置重要性评分函数"""
        self.importance_scorer = scorer
    
    def add_message(self, role: str, content: str, 
                    metadata: Dict = None,
                    importance: float = None) -> Message:
        """添加消息到缓冲区"""
        message = Message(
            role=role, 
            content=content, 
            metadata=metadata or {}
        )
        
        # 计算重要性评分
        if importance is not None:
            message.importance_score = importance
        elif self.importance_scorer:
            message.importance_score = self.importance_scorer(message)
        
        # 系统消息自动设为高重要性
        if role == "system":
            message.importance_score = 1.0
        
        self.messages.append(message)
        
        # 触发裁剪
        self._trim_if_needed()
        
        return message
    
    def _trim_if_needed(self) -> None:
        """智能裁剪策略"""
        total_tokens = sum(
            self.calculate_message_tokens(m) for m in self.messages
        )
        
        if total_tokens <= self.max_tokens:
            return
        
        # 策略1: 保留系统消息和高重要性消息
        system_messages = [m for m in self.messages if m.role == "system"]
        high_importance = [
            m for m in self.messages 
            if m.role != "system" and m.importance_score >= 0.8
        ]
        other_messages = [
            m for m in self.messages 
            if m not in system_messages and m not in high_importance
        ]
        
        # 按时间排序其他消息
        other_messages.sort(key=lambda m: m.timestamp)
        
        # 从最早的消息开始删除低重要性消息
        while total_tokens > self.max_tokens and other_messages:
            removed = other_messages.pop(0)
            if removed.importance_score < self.importance_threshold:
                self.messages.remove(removed)
                total_tokens -= self.calculate_message_tokens(removed)
            else:
                # 如果消息重要性高,尝试摘要
                break
        
        # 如果仍然超限,对旧消息进行摘要
        if total_tokens > self.max_tokens:
            self._summarize_old_messages()
    
    def _summarize_old_messages(self) -> None:
        """对旧消息进行摘要(占位符,实际需调用 LLM)"""
        # 保留最近的消息,对旧消息生成摘要
        pass
    
    def get_messages(self, last_n: int = None) -> List[Message]:
        """获取消息列表"""
        if last_n:
            return self.messages[-last_n:]
        return self.messages.copy()
    
    def to_llm_format(self) -> List[Dict[str, str]]:
        """转换为 LLM API 格式"""
        return [
            {"role": msg.role, "content": msg.content}
            for msg in self.messages
        ]
    
    def clear(self) -> None:
        """清空缓冲区"""
        self.messages = []
    
    def get_stats(self) -> Dict[str, Any]:
        """获取缓冲区统计信息"""
        total_tokens = sum(
            self.calculate_message_tokens(m) for m in self.messages
        )
        return {
            "message_count": len(self.messages),
            "total_tokens": total_tokens,
            "max_tokens": self.max_tokens,
            "utilization": total_tokens / self.max_tokens,
            "avg_importance": sum(m.importance_score for m in self.messages) / len(self.messages) if self.messages else 0
        }

1.2 滑动窗口与摘要结合策略

当对话历史过长时,我们需要一种混合策略:保留最近完整对话 + 历史摘要。

from langchain_openai import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage
import asyncio

class HybridMemory:
    """混合记忆策略:滑动窗口 + 摘要"""
    
    def __init__(self, 
                 llm: ChatOpenAI = None,
                 max_tokens: int = 6000,
                 recent_messages_limit: int = 6,
                 summary_token_limit: int = 1000):
        self.llm = llm or ChatOpenAI(model="gpt-3.5-turbo")
        self.max_tokens = max_tokens
        self.recent_limit = recent_messages_limit
        self.summary_limit = summary_token_limit
        
        self.all_messages: List[Message] = []
        self.summary = ""
        self.recent_buffer: List[Message] = []
    
    async def add_message(self, role: str, content: str, 
                         metadata: Dict = None) -> None:
        """添加消息"""
        message = Message(role=role, content=content, metadata=metadata or {})
        self.all_messages.append(message)
        self.recent_buffer.append(message)
        
        # 维护近期消息窗口
        if len(self.recent_buffer) > self.recent_limit * 2:
            # 触发摘要更新
            await self._update_summary()
            # 保留最近的对话
            self.recent_buffer = self.recent_buffer[-self.recent_limit:]
    
    async def _update_summary(self) -> None:
        """更新对话摘要"""
        # 获取需要摘要的消息
        messages_to_summarize = self.recent_buffer[:-self.recent_limit]
        if not messages_to_summarize:
            return
        
        conversation_text = "\n".join([
            f"{msg.role}: {msg.content}"
            for msg in messages_to_summarize
        ])
        
        prompt = f"""请将以下对话内容整合到现有摘要中。

当前摘要:
{self.summary}

新对话内容:
{conversation_text}

要求:
1. 保留关键信息和用户偏好
2. 合并重复或相关的内容
3. 控制在 {self.summary_limit} 字以内
4. 保持时间顺序

更新后的摘要:"""
        
        try:
            response = await self.llm.ainvoke([HumanMessage(content=prompt)])
            self.summary = response.content.strip()
        except Exception as e:
            print(f"摘要生成失败: {e}")
    
    def get_context(self) -> str:
        """获取完整上下文"""
        parts = []
        
        if self.summary:
            parts.append(f"【历史摘要】\n{self.summary}\n")
        
        if self.recent_buffer:
            parts.append("【近期对话】")
            for msg in self.recent_buffer:
                parts.append(f"{msg.role}: {msg.content}")
        
        return "\n".join(parts)
    
    def to_llm_messages(self) -> List[Dict[str, str]]:
        """转换为 LLM 消息格式"""
        messages = []
        
        # 添加摘要作为系统消息
        if self.summary:
            messages.append({
                "role": "system",
                "content": f"对话历史摘要: {self.summary}"
            })
        
        # 添加近期消息
        for msg in self.recent_buffer:
            messages.append({"role": msg.role, "content": msg.content})
        
        return messages


# 使用示例
async def demo_hybrid_memory():
    memory = HybridMemory()
    
    # 模拟多轮对话
    conversations = [
        ("user", "你好,我想学习 Python 编程"),
        ("assistant", "你好!Python 是一门非常优秀的编程语言,适合初学者。"),
        ("user", "我更喜欢通过项目实战来学习"),
        ("assistant", "很好的学习方式!项目驱动能让你更快掌握实际技能。"),
        ("user", "请推荐一些适合初学者的项目"),
        ("assistant", "对于初学者,我推荐:1. 计算器 2. 待办事项应用 3. 简单爬虫"),
        ("user", "我对 Web 开发比较感兴趣"),
        ("assistant", "Web 开发是 Python 的热门领域,可以学习 Flask 或 Django。"),
        ("user", "Flask 和 Django 哪个更适合新手?"),
        ("assistant", "Flask 更轻量、灵活,适合新手理解 Web 开发原理。"),
    ]
    
    for role, content in conversations:
        await memory.add_message(role, content)
        print(f"[{role}] {content[:50]}...")
    
    print("\n" + "="*60)
    print("最终上下文:")
    print("="*60)
    print(memory.get_context())

# 运行示例
# asyncio.run(demo_hybrid_memory())

1.3 Token 预算分配策略

合理分配 Token 预算是短期记忆管理的关键。

class TokenBudgetManager:
    """Token 预算管理器"""
    
    def __init__(self, total_budget: int = 8000):
        self.total_budget = total_budget
        
        # 默认预算分配
        self.allocations = {
            "system_prompt": 0.15,      # 15% 给系统提示
            "conversation_history": 0.40,  # 40% 给对话历史
            "retrieved_context": 0.30,   # 30% 给检索的上下文
            "user_input": 0.10,         # 10% 给用户输入
            "response_buffer": 0.05     # 5% 给响应缓冲
        }
    
    def set_allocation(self, category: str, percentage: float):
        """设置预算分配比例"""
        if 0 <= percentage <= 1:
            self.allocations[category] = percentage
    
    def get_budget(self, category: str) -> int:
        """获取指定类别的 Token 预算"""
        return int(self.total_budget * self.allocations.get(category, 0))
    
    def optimize_history(self, messages: List[Message], 
                        budget: int) -> List[Message]:
        """在预算内优化历史消息"""
        encoding = tiktoken.encoding_for_model("gpt-4")
        
        # 按重要性排序
        sorted_messages = sorted(
            messages,
            key=lambda m: (m.importance_score, m.timestamp),
            reverse=True
        )
        
        selected = []
        current_tokens = 0
        
        # 优先保留系统消息
        system_msgs = [m for m in sorted_messages if m.role == "system"]
        for msg in system_msgs:
            msg_tokens = len(encoding.encode(msg.content)) + 4
            if current_tokens + msg_tokens <= budget:
                selected.append(msg)
                current_tokens += msg_tokens
        
        # 按时间顺序选择其他高重要性消息
        other_msgs = [m for m in sorted_messages if m.role != "system"]
        other_msgs.sort(key=lambda m: m.timestamp)
        
        for msg in other_msgs:
            msg_tokens = len(encoding.encode(msg.content)) + 4
            if current_tokens + msg_tokens <= budget:
                selected.append(msg)
                current_tokens += msg_tokens
            else:
                break
        
        # 按时间排序返回
        selected.sort(key=lambda m: m.timestamp)
        return selected
    
    def calculate_available_for_response(self, 
                                        used_tokens: int) -> int:
        """计算响应可用的 Token 数"""
        return self.total_budget - used_tokens - 100  # 100 为安全余量

二、长期记忆持久化策略

2.1 分层存储架构

长期记忆需要分层存储,根据访问频率和重要性选择不同的存储介质。

from enum import Enum
from typing import Optional
import json
import sqlite3
from datetime import datetime, timedelta

class MemoryTier(Enum):
    """记忆层级"""
    HOT = "hot"      # 热数据 - 内存
    WARM = "warm"    # 温数据 - Redis
    COLD = "cold"    # 冷数据 - 向量数据库
    ARCHIVE = "archive"  # 归档 - 对象存储

class TieredMemoryStore:
    """分层记忆存储"""
    
    def __init__(self):
        # 热数据缓存(内存)
        self.hot_cache: Dict[str, Dict] = {}
        self.hot_max_size = 100
        
        # 温数据(SQLite 模拟)
        self.warm_db_path = "./warm_memory.db"
        self._init_warm_db()
        
        # 访问统计
        self.access_stats: Dict[str, Dict] = {}
    
    def _init_warm_db(self):
        """初始化温数据数据库"""
        conn = sqlite3.connect(self.warm_db_path)
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS memories (
                id TEXT PRIMARY KEY,
                content TEXT,
                memory_type TEXT,
                importance_score REAL,
                access_count INTEGER DEFAULT 0,
                last_accessed TIMESTAMP,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                metadata TEXT
            )
        """)
        conn.commit()
        conn.close()
    
    def _determine_tier(self, memory: Dict) -> MemoryTier:
        """根据访问模式确定存储层级"""
        memory_id = memory.get("id")
        stats = self.access_stats.get(memory_id, {})
        
        access_count = stats.get("count", 0)
        last_access = stats.get("last_access")
        importance = memory.get("importance_score", 0.5)
        
        # 高频访问 + 高重要性 = 热数据
        if access_count > 10 and importance > 0.7:
            return MemoryTier.HOT
        
        # 中频访问 = 温数据
        if access_count > 3:
            return MemoryTier.WARM
        
        # 低频访问 = 冷数据
        if access_count > 0:
            return MemoryTier.COLD
        
        # 从未访问 = 归档
        return MemoryTier.ARCHIVE
    
    def store(self, memory_id: str, content: str, 
              memory_type: str = "fact",
              metadata: Dict = None) -> MemoryTier:
        """存储记忆到合适的层级"""
        memory = {
            "id": memory_id,
            "content": content,
            "type": memory_type,
            "metadata": metadata or {},
            "importance_score": metadata.get("importance", 0.5),
            "created_at": datetime.now().isoformat()
        }
        
        tier = self._determine_tier(memory)
        
        if tier == MemoryTier.HOT:
            self._store_hot(memory_id, memory)
        elif tier == MemoryTier.WARM:
            self._store_warm(memory_id, memory)
        else:
            self._store_cold(memory_id, memory)
        
        return tier
    
    def _store_hot(self, memory_id: str, memory: Dict):
        """存储到热缓存"""
        # LRU 淘汰
        if len(self.hot_cache) >= self.hot_max_size:
            oldest = min(self.hot_cache.keys(), 
                        key=lambda k: self.hot_cache[k].get("last_accessed", 0))
            del self.hot_cache[oldest]
        
        self.hot_cache[memory_id] = {
            **memory,
            "last_accessed": datetime.now().isoformat()
        }
    
    def _store_warm(self, memory_id: str, memory: Dict):
        """存储到温数据库"""
        conn = sqlite3.connect(self.warm_db_path)
        cursor = conn.cursor()
        cursor.execute("""
            INSERT OR REPLACE INTO memories 
            (id, content, memory_type, importance_score, metadata)
            VALUES (?, ?, ?, ?, ?)
        """, (
            memory_id,
            memory["content"],
            memory["type"],
            memory["importance_score"],
            json.dumps(memory["metadata"])
        ))
        conn.commit()
        conn.close()
    
    def _store_cold(self, memory_id: str, memory: Dict):
        """存储到冷数据(向量数据库)"""
        # 实际实现中应存入 Chroma/Pinecone 等
        pass
    
    def retrieve(self, memory_id: str) -> Optional[Dict]:
        """检索记忆(自动层级提升)"""
        # 更新访问统计
        if memory_id not in self.access_stats:
            self.access_stats[memory_id] = {"count": 0}
        self.access_stats[memory_id]["count"] += 1
        self.access_stats[memory_id]["last_access"] = datetime.now()
        
        # 按层级查找
        if memory_id in self.hot_cache:
            self.hot_cache[memory_id]["last_accessed"] = datetime.now().isoformat()
            return self.hot_cache[memory_id]
        
        # 查找温数据
        conn = sqlite3.connect(self.warm_db_path)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT * FROM memories WHERE id = ?", (memory_id,)
        )
        row = cursor.fetchone()
        conn.close()
        
        if row:
            memory = {
                "id": row[0],
                "content": row[1],
                "type": row[2],
                "importance_score": row[3],
                "metadata": json.loads(row[7]) if row[7] else {}
            }
            
            # 访问频繁则提升到热缓存
            if self.access_stats[memory_id]["count"] > 5:
                self._store_hot(memory_id, memory)
            
            return memory
        
        return None

2.2 记忆生命周期管理

记忆也有生命周期,需要定期清理过期和低价值的记忆。

class MemoryLifecycleManager:
    """记忆生命周期管理"""
    
    def __init__(self):
        self.ttl_config = {
            "fact": timedelta(days=365),      # 事实长期有效
            "preference": timedelta(days=180), # 偏好半年
            "context": timedelta(days=7),      # 上下文一周
            "temporary": timedelta(hours=1)    # 临时一小时
        }
        
        self.decay_rates = {
            "fact": 0.99,       # 事实几乎不衰减
            "preference": 0.95,  # 偏好缓慢衰减
            "context": 0.80,     # 上下文快速衰减
            "temporary": 0.50    # 临时极速衰减
        }
    
    def calculate_memory_value(self, memory: Dict) -> float:
        """计算记忆当前价值"""
        base_importance = memory.get("importance_score", 0.5)
        memory_type = memory.get("type", "temporary")
        created_at = memory.get("created_at")
        access_count = memory.get("access_count", 0)
        
        if not created_at:
            return base_importance
        
        # 计算时间衰减
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        
        age = datetime.now() - created_at
        age_days = age.days
        
        decay_rate = self.decay_rates.get(memory_type, 0.9)
        time_decay = decay_rate ** age_days
        
        # 访问频率加成
        access_boost = min(0.3, access_count * 0.01)
        
        # 最终价值
        current_value = base_importance * time_decay + access_boost
        return min(1.0, current_value)
    
    def should_retain(self, memory: Dict) -> bool:
        """判断是否应该保留记忆"""
        memory_type = memory.get("type", "temporary")
        created_at = memory.get("created_at")
        
        if not created_at:
            return True
        
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        
        # 检查 TTL
        ttl = self.ttl_config.get(memory_type, timedelta(days=30))
        if datetime.now() - created_at > ttl:
            return False
        
        # 检查当前价值
        current_value = self.calculate_memory_value(memory)
        return current_value > 0.1
    
    async def cleanup_memories(self, 
                              memory_store: TieredMemoryStore) -> Dict[str, int]:
        """清理过期记忆"""
        stats = {"removed": 0, "archived": 0, "retained": 0}
        
        # 这里应该遍历所有记忆进行检查
        # 简化示例:仅展示逻辑
        
        return stats

三、向量记忆与语义检索

3.1 向量存储策略

向量记忆是长期记忆的核心,支持语义检索。

import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import hashlib

try:
    import chromadb
    from chromadb.config import Settings
    CHROMADB_AVAILABLE = True
except ImportError:
    CHROMADB_AVAILABLE = False

class VectorMemoryManager:
    """向量记忆管理器"""
    
    def __init__(self, 
                 collection_name: str = "agent_memory",
                 embedding_model = None,
                 persist_directory: str = "./vector_memory"):
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        
        if CHROMADB_AVAILABLE:
            self.client = chromadb.Client(Settings(
                chroma_db_impl="duckdb+parquet",
                persist_directory=persist_directory,
                anonymized_telemetry=False
            ))
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"}
            )
        else:
            # 内存存储回退
            self._memory_store: Dict[str, Dict] = {}
        
        self.cache: Dict[str, np.ndarray] = {}
    
    async def embed_text(self, text: str) -> List[float]:
        """生成文本嵌入"""
        if self.embedding_model:
            # 使用外部 embedding 模型
            embedding = await self.embedding_model.embed_query(text)
            return embedding
        else:
            # 简单的哈希 embedding(仅用于演示)
            return self._simple_embed(text)
    
    def _simple_embed(self, text: str, dim: int = 384) -> List[float]:
        """简单的确定性 embedding"""
        # 使用哈希生成伪随机但确定的向量
        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        np.random.seed(hash_val % (2**32))
        vec = np.random.randn(dim)
        # 归一化
        vec = vec / np.linalg.norm(vec)
        return vec.tolist()
    
    async def add_memory(self,
                        content: str,
                        memory_type: str = "fact",
                        metadata: Dict[str, Any] = None) -> str:
        """添加记忆"""
        # 生成唯一 ID
        memory_id = hashlib.md5(
            f"{content}:{datetime.now().isoformat()}".encode()
        ).hexdigest()
        
        # 生成 embedding
        embedding = await self.embed_text(content)
        
        # 构建元数据
        full_metadata = {
            "type": memory_type,
            "timestamp": datetime.now().isoformat(),
            "content_hash": hashlib.md5(content.encode()).hexdigest()[:8],
            **(metadata or {})
        }
        
        if CHROMADB_AVAILABLE:
            self.collection.add(
                ids=[memory_id],
                embeddings=[embedding],
                documents=[content],
                metadatas=[full_metadata]
            )
        else:
            self._memory_store[memory_id] = {
                "id": memory_id,
                "embedding": embedding,
                "content": content,
                "metadata": full_metadata
            }
        
        return memory_id
    
    async def search(self,
                    query: str,
                    top_k: int = 5,
                    memory_type: str = None,
                    min_similarity: float = 0.7) -> List[Dict]:
        """语义搜索记忆"""
        query_embedding = await self.embed_text(query)
        
        results = []
        
        if CHROMADB_AVAILABLE:
            # 构建过滤条件
            where_clause = {}
            if memory_type:
                where_clause["type"] = memory_type
            
            db_results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k * 2,  # 获取更多结果用于过滤
                where=where_clause if where_clause else None
            )
            
            # 格式化结果
            for i in range(len(db_results["ids"][0])):
                distance = db_results["distances"][0][i]
                similarity = 1 - distance  # 余弦距离转相似度
                
                if similarity >= min_similarity:
                    results.append({
                        "id": db_results["ids"][0][i],
                        "content": db_results["documents"][0][i],
                        "similarity": similarity,
                        "metadata": db_results["metadatas"][0][i]
                    })
        else:
            # 内存搜索
            query_vec = np.array(query_embedding)
            for memory_id, memory in self._memory_store.items():
                mem_vec = np.array(memory["embedding"])
                similarity = np.dot(query_vec, mem_vec) / (
                    np.linalg.norm(query_vec) * np.linalg.norm(mem_vec)
                )
                
                if similarity >= min_similarity:
                    if memory_type and memory["metadata"].get("type") != memory_type:
                        continue
                    
                    results.append({
                        "id": memory_id,
                        "content": memory["content"],
                        "similarity": float(similarity),
                        "metadata": memory["metadata"]
                    })
            
            # 按相似度排序
            results.sort(key=lambda x: x["similarity"], reverse=True)
            results = results[:top_k]
        
        return results
    
    async def hybrid_search(self,
                           query: str,
                           keywords: List[str] = None,
                           top_k: int = 5) -> List[Dict]:
        """混合搜索:语义 + 关键词"""
        # 语义搜索
        semantic_results = await self.search(query, top_k=top_k * 2)
        
        # 关键词匹配
        if keywords:
            keyword_scores = {}
            for result in semantic_results:
                content = result["content"].lower()
                score = sum(1 for kw in keywords if kw.lower() in content)
                keyword_scores[result["id"]] = score / len(keywords)
            
            # 混合评分
            for result in semantic_results:
                sem_score = result["similarity"]
                kw_score = keyword_scores.get(result["id"], 0)
                result["hybrid_score"] = 0.7 * sem_score + 0.3 * kw_score
            
            # 按混合评分排序
            semantic_results.sort(key=lambda x: x["hybrid_score"], reverse=True)
        
        return semantic_results[:top_k]
    
    def delete_memory(self, memory_id: str) -> bool:
        """删除记忆"""
        try:
            if CHROMADB_AVAILABLE:
                self.collection.delete(ids=[memory_id])
            else:
                if memory_id in self._memory_store:
                    del self._memory_store[memory_id]
            return True
        except Exception as e:
            print(f"删除记忆失败: {e}")
            return False
    
    async def update_memory(self, 
                           memory_id: str,
                           new_content: str,
                           metadata: Dict = None) -> bool:
        """更新记忆"""
        try:
            # 删除旧记忆
            self.delete_memory(memory_id)
            # 添加新记忆
            await self.add_memory(
                new_content,
                metadata=metadata
            )
            return True
        except Exception as e:
            print(f"更新记忆失败: {e}")
            return False

3.2 上下文检索策略

检索到的记忆需要智能地整合到上下文中。

class ContextAssembler:
    """上下文组装器"""
    
    def __init__(self, max_context_tokens: int = 2000):
        self.max_tokens = max_context_tokens
        self.encoding = tiktoken.encoding_for_model("gpt-4")
    
    def assemble_context(self,
                        query: str,
                        retrieved_memories: List[Dict],
                        current_conversation: List[Message] = None) -> str:
        """组装检索到的记忆为上下文"""
        
        parts = []
        current_tokens = 0
        
        # 1. 添加查询意图说明
        intent_text = f"用户当前问题: {query}\n\n相关历史信息:\n"
        parts.append(intent_text)
        current_tokens += len(self.encoding.encode(intent_text))
        
        # 2. 按相关性和重要性排序记忆
        sorted_memories = sorted(
            retrieved_memories,
            key=lambda m: (m.get("similarity", 0), 
                          m.get("metadata", {}).get("importance", 0.5)),
            reverse=True
        )
        
        # 3. 分类组织记忆
        categories = {
            "fact": [],
            "preference": [],
            "experience": [],
            "other": []
        }
        
        for memory in sorted_memories:
            mem_type = memory.get("metadata", {}).get("type", "other")
            if mem_type in categories:
                categories[mem_type].append(memory)
            else:
                categories["other"].append(memory)
        
        # 4. 按优先级添加各类记忆
        priority_order = ["preference", "fact", "experience", "other"]
        
        for category in priority_order:
            if not categories[category]:
                continue
            
            category_header = f"\n{self._get_category_name(category)}\n"
            if current_tokens + len(self.encoding.encode(category_header)) > self.max_tokens:
                break
            
            parts.append(category_header)
            current_tokens += len(self.encoding.encode(category_header))
            
            for memory in categories[category]:
                content = memory["content"]
                mem_text = f"- {content}\n"
                mem_tokens = len(self.encoding.encode(mem_text))
                
                if current_tokens + mem_tokens > self.max_tokens:
                    break
                
                parts.append(mem_text)
                current_tokens += mem_tokens
        
        return "".join(parts)
    
    def _get_category_name(self, category: str) -> str:
        """获取分类显示名称"""
        names = {
            "fact": "相关事实",
            "preference": "用户偏好",
            "experience": "过往经验",
            "other": "其他信息"
        }
        return names.get(category, "其他")
    
    def create_rag_prompt(self,
                         query: str,
                         context: str,
                         system_prompt: str = None) -> List[Dict[str, str]]:
        """创建 RAG 提示"""
        messages = []
        
        # 系统提示
        sys_prompt = system_prompt or """你是一个智能助手。请基于提供的历史信息回答用户问题。
如果历史信息不足以回答问题,请明确说明。
不要编造不存在的信息。"""
        
        messages.append({"role": "system", "content": sys_prompt})
        
        # 添加上下文作为系统消息的一部分
        full_context = f"{sys_prompt}\n\n{context}"
        messages[0]["content"] = full_context
        
        # 用户查询
        messages.append({"role": "user", "content": query})
        
        return messages

四、记忆压缩与摘要策略

4.1 渐进式摘要

随着对话进行,定期对历史进行摘要。

class ProgressiveSummarizer:
    """渐进式摘要器"""
    
    def __init__(self, llm: ChatOpenAI = None):
        self.llm = llm or ChatOpenAI(model="gpt-3.5-turbo")
        
        # 摘要层级
        self.summary_levels = {
            0: {"trigger": 10, "max_length": 200},    # 10条消息触发
            1: {"trigger": 5, "max_length": 300},     # 5个L0摘要触发
            2: {"trigger": 3, "max_length": 400},     # 3个L1摘要触发
        }
        
        self.messages_buffer: List[Message] = []
        self.summaries: Dict[int, List[str]] = {0: [], 1: [], 2: []}
    
    async def add_message(self, message: Message) -> Optional[str]:
        """添加消息,可能触发摘要"""
        self.messages_buffer.append(message)
        
        # 检查是否需要生成 L0 摘要
        if len(self.messages_buffer) >= self.summary_levels[0]["trigger"]:
            summary = await self._generate_summary(
                self.messages_buffer,
                self.summary_levels[0]["max_length"]
            )
            self.summaries[0].append(summary)
            self.messages_buffer = []
            
            # 检查高层摘要
            await self._check_higher_summaries()
            
            return summary
        
        return None
    
    async def _generate_summary(self, 
                               messages: List[Message],
                               max_length: int) -> str:
        """生成摘要"""
        conversation = "\n".join([
            f"{msg.role}: {msg.content}"
            for msg in messages
        ])
        
        prompt = f"""请将以下对话总结为简洁的摘要,不超过{max_length}字:

对话内容:
{conversation}

要求:
1. 保留关键信息和用户明确表达的偏好
2. 去除寒暄和重复内容
3. 按时间顺序组织
4. 使用第三人称客观描述

摘要:"""
        
        try:
            response = await self.llm.ainvoke([HumanMessage(content=prompt)])
            return response.content.strip()[:max_length]
        except Exception as e:
            print(f"摘要生成失败: {e}")
            return "摘要生成失败"
    
    async def _check_higher_summaries(self):
        """检查并生成高层摘要"""
        for level in [1, 2]:
            prev_level = level - 1
            if len(self.summaries[prev_level]) >= self.summary_levels[level]["trigger"]:
                # 将下层摘要作为消息处理
                summary_messages = [
                    Message(role="assistant", content=s)
                    for s in self.summaries[prev_level]
                ]
                
                higher_summary = await self._generate_summary(
                    summary_messages,
                    self.summary_levels[level]["max_length"]
                )
                
                self.summaries[level].append(higher_summary)
                self.summaries[prev_level] = []
    
    def get_full_summary(self) -> str:
        """获取完整摘要层级"""
        parts = []
        
        # 从高到低添加摘要
        for level in [2, 1, 0]:
            if self.summaries[level]:
                parts.append(f"【历史摘要 L{level}】")
                for summary in self.summaries[level]:
                    parts.append(f"- {summary}")
                parts.append("")
        
        # 添加未摘要的消息
        if self.messages_buffer:
            parts.append("【最新对话】")
            for msg in self.messages_buffer:
                parts.append(f"{msg.role}: {msg.content}")
        
        return "\n".join(parts)

4.2 选择性记忆压缩

不是所有记忆都值得保留,需要选择性压缩。

class SelectiveCompressor:
    """选择性记忆压缩器"""
    
    def __init__(self, llm: ChatOpenAI = None):
        self.llm = llm or ChatOpenAI(model="gpt-3.5-turbo")
    
    async def compress_memory(self, 
                             original: str,
                             target_ratio: float = 0.3) -> str:
        """压缩记忆内容"""
        original_tokens = len(tiktoken.encoding_for_model("gpt-4").encode(original))
        target_tokens = int(original_tokens * target_ratio)
        
        prompt = f"""请将以下内容压缩为更简洁的形式,保留所有关键信息。

原始内容(约{original_tokens} tokens):
{original}

要求:
1. 压缩到约{target_tokens} tokens
2. 使用简洁的表达
3. 保留关键实体、关系和约束
4. 可以改用列表、表格等结构化形式

压缩后:"""
        
        try:
            response = await self.llm.ainvoke([HumanMessage(content=prompt)])
            compressed = response.content.strip()
            
            # 验证压缩率
            compressed_tokens = len(
                tiktoken.encoding_for_model("gpt-4").encode(compressed)
            )
            actual_ratio = compressed_tokens / original_tokens
            
            return compressed, actual_ratio
        except Exception as e:
            print(f"压缩失败: {e}")
            return original, 1.0
    
    async def extract_key_facts(self, content: str) -> List[str]:
        """提取关键事实"""
        prompt = f"""从以下内容中提取关键事实,每条事实一行:

内容:
{content}

要求:
1. 每条事实独立完整
2. 去除冗余和重复
3. 保留数值、时间、名称等具体信息
4. 最多提取10条最重要的事实

关键事实:"""
        
        try:
            response = await self.llm.ainvoke([HumanMessage(content=prompt)])
            facts = [f.strip() for f in response.content.strip().split("\n") if f.strip()]
            return facts[:10]
        except Exception as e:
            print(f"事实提取失败: {e}")
            return [content[:100]]
    
    async def merge_memories(self, 
                            memories: List[str]) -> str:
        """合并多个记忆"""
        combined = "\n\n".join([f"记忆{i+1}: {m}" for i, m in enumerate(memories)])
        
        prompt = f"""请将以下多条记忆整合为一条连贯的记忆:

{combined}

要求:
1. 合并重复或相关的信息
2. 解决矛盾(以最新信息为准)
3. 保持逻辑连贯
4. 去除冗余表达

整合后记忆:"""
        
        try:
            response = await self.llm.ainvoke([HumanMessage(content=prompt)])
            return response.content.strip()
        except Exception as e:
            print(f"合并失败: {e}")
            return "\n".join(memories)

五、完整实战示例

5.1 生产级记忆系统

import asyncio
from typing import AsyncGenerator

class ProductionMemorySystem:
    """生产级记忆系统"""
    
    def __init__(self, 
                 openai_api_key: str = None,
                 vector_db_path: str = "./production_memory"):
        # 初始化 LLM
        self.llm = ChatOpenAI(
            model="gpt-4",
            api_key=openai_api_key,
            temperature=0
        )
        
        # 初始化各组件
        self.short_term = SmartConversationBuffer(max_tokens=4000)
        self.hybrid_memory = HybridMemory(llm=self.llm)
        self.vector_memory = VectorMemoryManager(
            persist_directory=vector_db_path
        )
        self.tiered_store = TieredMemoryStore()
        self.context_assembler = ContextAssembler()
        self.summarizer = ProgressiveSummarizer(llm=self.llm)
        self.compressor = SelectiveCompressor(llm=self.llm)
        
        # 统计信息
        self.stats = {
            "total_interactions": 0,
            "memories_stored": 0,
            "memories_retrieved": 0
        }
    
    async def process_interaction(self,
                                 user_id: str,
                                 user_input: str,
                                 agent_response: str) -> Dict:
        """处理一次完整交互"""
        self.stats["total_interactions"] += 1
        
        # 1. 更新短期记忆
        self.short_term.add_message("user", user_input)
        self.short_term.add_message("assistant", agent_response)
        
        # 2. 更新混合记忆
        await self.hybrid_memory.add_message("user", user_input)
        await self.hybrid_memory.add_message("assistant", agent_response)
        
        # 3. 提取并存储关键记忆
        await self._extract_and_store_memories(user_id, user_input, agent_response)
        
        # 4. 触发摘要
        await self.summarizer.add_message(
            Message(role="user", content=user_input)
        )
        
        return {
            "status": "success",
            "short_term_stats": self.short_term.get_stats()
        }
    
    async def _extract_and_store_memories(self,
                                         user_id: str,
                                         user_input: str,
                                         agent_response: str):
        """提取并存储关键记忆"""
        # 使用 LLM 提取关键信息
        prompt = f"""从以下对话中提取需要长期记忆的关键信息:

用户: {user_input}
助手: {agent_response}

请提取(JSON格式):
{{
    "facts": ["事实1", "事实2"],
    "preferences": ["偏好1"],
    "context": ["上下文信息"]
}}

只输出JSON,不要其他内容。"""
        
        try:
            response = await self.llm.ainvoke([HumanMessage(content=prompt)])
            extracted = json.loads(response.content)
            
            # 存储到向量数据库
            for category, items in extracted.items():
                for item in items:
                    memory_id = await self.vector_memory.add_memory(
                        content=item,
                        memory_type=category,
                        metadata={
                            "user_id": user_id,
                            "source": "conversation_extraction"
                        }
                    )
                    self.stats["memories_stored"] += 1
                    
                    # 同时存储到分层存储
                    self.tiered_store.store(
                        memory_id, item, category,
                        metadata={"user_id": user_id}
                    )
        
        except Exception as e:
            print(f"记忆提取失败: {e}")
    
    async def retrieve_context(self,
                              query: str,
                              user_id: str = None) -> str:
        """检索上下文"""
        # 1. 检索向量记忆
        vector_results = await self.vector_memory.search(
            query, top_k=5
        )
        self.stats["memories_retrieved"] += len(vector_results)
        
        # 2. 获取短期记忆
        short_term_messages = self.short_term.get_messages(last_n=6)
        
        # 3. 组装上下文
        context = self.context_assembler.assemble_context(
            query=query,
            retrieved_memories=vector_results,
            current_conversation=short_term_messages
        )
        
        return context
    
    async def get_memory_insights(self, user_id: str) -> Dict:
        """获取记忆洞察"""
        return {
            "total_interactions": self.stats["total_interactions"],
            "memories_stored": self.stats["memories_stored"],
            "short_term_stats": self.short_term.get_stats(),
            "summary": self.summarizer.get_full_summary()
        }


# 使用示例
async def demo_production_system():
    """演示生产级记忆系统"""
    system = ProductionMemorySystem()
    user_id = "user_demo_001"
    
    # 模拟多轮对话
    interactions = [
        ("你好,我叫张三,是一名软件工程师", 
         "你好张三!很高兴认识你。作为软件工程师,你对什么技术最感兴趣?"),
        ("我主要做后端开发,用 Python 和 Go", 
         "Python 和 Go 都是很棒的后端语言!Python 适合快速开发,Go 适合高并发。"),
        ("是的,我喜欢 Python 的简洁,但 Go 的性能更好", 
         "很好的观察。你平时用什么框架?Django、Flask 还是 FastAPI?"),
        ("我更喜欢 FastAPI,类型提示和异步支持很好", 
         "FastAPI 确实很棒!类型安全和自动文档生成是它的亮点。"),
        ("对了,能推荐一些学习资源吗?", 
         "当然可以!FastAPI 官方文档很详细,还有 TestDriven.io 的教程很不错。"),
    ]
    
    print("="*60)
    print("开始模拟对话...")
    print("="*60)
    
    for user_input, agent_response in interactions:
        result = await system.process_interaction(
            user_id, user_input, agent_response
        )
        print(f"\n用户: {user_input}")
        print(f"助手: {agent_response[:50]}...")
    
    print("\n" + "="*60)
    print("检索测试...")
    print("="*60)
    
    # 测试检索
    queries = [
        "用户叫什么名字?",
        "用户喜欢什么编程语言?",
        "推荐了什么学习资源?"
    ]
    
    for query in queries:
        context = await system.retrieve_context(query, user_id)
        print(f"\n查询: {query}")
        print(f"检索到的上下文:\n{context[:500]}...")
    
    print("\n" + "="*60)
    print("记忆统计")
    print("="*60)
    insights = await system.get_memory_insights(user_id)
    print(json.dumps(insights, indent=2, ensure_ascii=False))


# 运行演示
if __name__ == "__main__":
    asyncio.run(demo_production_system())

5.2 记忆管理最佳实践清单

MEMORY_MANAGEMENT_CHECKLIST = {
    "短期记忆": [
        "✓ 实现 Token 预算管理,避免超出模型限制",
        "✓ 使用重要性评分,优先保留关键消息",
        "✓ 采用滑动窗口 + 摘要的混合策略",
        "✓ 系统消息始终保留,确保行为一致性",
        "✓ 定期清理过期的临时状态"
    ],
    
    "长期记忆": [
        "✓ 使用分层存储(热/温/冷/归档)",
        "✓ 实现记忆生命周期管理(TTL + 衰减)",
        "✓ 定期清理低价值记忆",
        "✓ 重要记忆多副本存储",
        "✓ 支持记忆版本控制"
    ],
    
    "向量记忆": [
        "✓ 选择合适的 Embedding 模型",
        "✓ 设置合理的相似度阈值",
        "✓ 实现混合搜索(语义 + 关键词)",
        "✓ 定期更新向量索引",
        "✓ 考虑向量量化和降维"
    ],
    
    "记忆压缩": [
        "✓ 渐进式摘要,避免信息丢失",
        "✓ 选择性压缩,保留关键事实",
        "✓ 合并相似记忆,减少冗余",
        "✓ 监控压缩率,确保质量",
        "✓ 支持记忆解压查看原始内容"
    ],
    
    "性能优化": [
        "✓ 使用缓存加速频繁访问",
        "✓ 异步处理记忆存储",
        "✓ 批量操作减少 IO",
        "✓ 监控记忆系统性能指标",
        "✓ 实现降级策略应对高负载"
    ],
    
    "数据安全": [
        "✓ 敏感信息加密存储",
        "✓ 实现访问控制和审计",
        "✓ 支持记忆导出和删除(GDPR)",
        "✓ 定期备份重要记忆",
        "✓ 隔离不同用户的记忆数据"
    ]
}

六、总结

本文系统介绍了 AI Agent 记忆管理的核心策略:

核心策略回顾

策略类型关键要点适用场景
短期记忆Token 预算、重要性评分、滑动窗口当前对话上下文
长期记忆分层存储、生命周期、冷热分离持久化知识存储
向量记忆语义检索、混合搜索、相似度阈值语义相关的记忆召回
记忆压缩渐进摘要、选择性压缩、智能合并长对话历史处理

关键设计原则

  1. 分层管理:根据访问频率和重要性分层存储
  2. 动态平衡:在完整性和效率之间动态调整
  3. 智能压缩:保留关键信息,去除冗余
  4. 安全优先:敏感数据加密,隐私合规
  5. 可观测性:监控记忆系统状态,及时优化

下一步学习


本文最后更新于 2025-05-07,如有问题欢迎在社区讨论。