AI Agent Tool Design and Registration: From Principles to Practice

AIEng Hub
About a 25-minute read

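When building an AI agent, tools are the bridge between the large language model (LLM) and the outside world. A well-designed tool system extends what the agent can do while keeping the system secure and maintainable. This article covers the core principles of tool design, the registration mechanisms of mainstream frameworks, and an emerging industry standard.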

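1. Tool Design Principles

1.1 Naming Conventions

The tool name is the first thing the LLM sees when working out what a tool does, so good naming noticeably improves tool-selection accuracy.

Naming best practices

Principle | Description | Example
Verb first | Start with a verb that describes the action | search_documents, send_email
snake_case | Use lowercase words separated by underscores | get_weather_data ✓ vs getWeatherData ✗
Specific | Avoid vague words; describe the function precisely | calculate_mortgage_payment vs do_math
No collisions | Keep every name unique within the tool set | search_web vs search_database

Tip

Avoid overly generic names such as process, handle, or do; they give the model too little to base a choice on.

Naming examples compared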

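# ❌ Poor names
"tool1"           # meaningless
"process"         # too generic
"getData"         # camelCase, breaks the convention
"weather-stuff"   # vague and non-standard

# ✅ Good names
"get_current_weather"      # clear and conventional
"search_semantic_documents" # describes the function specifically
"send_slack_notification"   # action + target + object
"calculate_currency_exchange" # calculation + business domain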
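
The naming rules above can be checked mechanically before a tool is registered. The following is a minimal sketch, not part of any framework: lint_tool_name, COMMON_VERBS, and GENERIC_NAMES are hypothetical names introduced here, and the verb list is only a small illustrative sample.

```python
import re

# Illustrative sample only; a real project would maintain its own verb list
COMMON_VERBS = {"get", "search", "send", "calculate", "create", "update", "delete", "list", "fetch"}
GENERIC_NAMES = {"process", "handle", "do", "run", "tool"}
SNAKE_CASE = re.compile(r"^[a-z][a-z0-9]*(_[a-z0-9]+)*$")

def lint_tool_name(name, existing=frozenset()):
    """Return the list of naming-convention violations for a proposed tool name."""
    if not SNAKE_CASE.match(name):
        return ["not snake_case"]
    problems = []
    words = name.split("_")
    # A single word (or a known filler word) says nothing about the target
    if name in GENERIC_NAMES or len(words) == 1:
        problems.append("too generic")
    if words[0] not in COMMON_VERBS:
        problems.append("does not start with a known verb")
    if name in existing:
        problems.append("duplicate name")
    return problems

for candidate in ["tool1", "process", "getData", "weather-stuff", "get_current_weather"]:
    print(candidate, lint_tool_name(candidate))
```

Running a check like this in a registration hook or a unit test catches naming drift early, before the model ever sees the tool list.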

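1.2 Description Design

The tool description is what the LLM relies on to decide whether to call the tool. A description should:

  1. State the function: what does the tool do?
  2. State the input: which parameters does it need?
  3. State the output: what does it return?
  4. State the scenario: when should it be used?

# Example: a good tool description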
tool_description = """
Search for documents in the vector database using semantic similarity.

Use this tool when you need to find relevant information from the company's
knowledge base based on a natural language query. Do NOT use this for real-time
information like current weather or stock prices.

Args:
    query: The search query in natural language
    top_k: Number of results to return (default: 5)
    filter: Optional metadata filter (e.g., {"department": "engineering"})

Returns:
    List of documents with content and relevance scores
"""
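
One way to keep every description complete is to assemble it from the four parts rather than writing free text. This is a small sketch under that assumption; build_description is a hypothetical helper, not a LangChain or MCP function.

```python
def build_description(summary: str, when_to_use: str, args: dict, returns: str) -> str:
    """Assemble a tool description from function, scenario, inputs, and output."""
    lines = [summary, "", when_to_use, "", "Args:"]
    lines += [f"    {name}: {text}" for name, text in args.items()]
    lines += ["", "Returns:", f"    {returns}"]
    return "\n".join(lines)

description = build_description(
    summary="Search for documents in the vector database using semantic similarity.",
    when_to_use=(
        "Use this tool to find information in the company's knowledge base. "
        "Do NOT use it for real-time information."
    ),
    args={
        "query": "The search query in natural language",
        "top_k": "Number of results to return (default: 5)",
    },
    returns="List of documents with content and relevance scores",
)
print(description)
```

Because each part is a required argument, a missing "when to use" or "returns" section fails at definition time instead of silently degrading tool selection.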

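1.3 Parameter Design

Parameter design directly determines whether the LLM can call a tool correctly. Follow the JSON Schema specification while keeping the LLM's comprehension in mind.

Parameter design principles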

from pydantic import BaseModel, Field
from typing import Optional, Literal

class WeatherInput(BaseModel):
    """Input schema for weather tool."""
    
    location: str = Field(
        description="The city and state/country, e.g., 'San Francisco, CA' or 'Paris, France'",
        examples=["Beijing, China", "New York, NY"]
    )
    
    unit: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="Temperature unit to use for the forecast"
    )
    
    days: int = Field(
        default=1,
        ge=1,
        le=14,
        description="Number of days to forecast (1-14)"
    )
    
    include_humidity: Optional[bool] = Field(
        default=True,
        description="Whether to include humidity data in the response"
    )

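Parameter type best practices

Type | Use case | Example
string | Text input, enumerated values | "search_query", "user_email"
integer | Counts, IDs, quantities | top_k, page_number
number | Floating-point values | temperature, confidence_score
boolean | On/off switches | include_metadata, async_mode
array | List input | tags, document_ids
object | Structured data | filters, config

Tip

For complex object parameters, provide an examples field that shows the expected input format; this noticeably improves how accurately the LLM constructs arguments.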
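
To see why constraints such as enums and ranges matter, here is a rough sketch of what happens to a tool call before it runs. It uses only the standard library and covers a small subset of JSON Schema; validate_args and WEATHER_SCHEMA are hypothetical names, and in a real project Pydantic or a JSON Schema library would do this work.

```python
# Map JSON Schema type names to Python types (a small subset)
JSON_TYPES = {
    "string": str,
    "integer": int,
    "number": (int, float),
    "boolean": bool,
    "array": list,
    "object": dict,
}

# Hand-written schema mirroring part of the WeatherInput model
WEATHER_SCHEMA = {
    "type": "object",
    "properties": {
        "location": {"type": "string"},
        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "default": "celsius"},
        "days": {"type": "integer", "minimum": 1, "maximum": 14, "default": 1},
    },
    "required": ["location"],
}

def validate_args(schema: dict, args: dict) -> dict:
    """Validate LLM-supplied arguments and fill in defaults."""
    props = schema["properties"]
    missing = [name for name in schema.get("required", []) if name not in args]
    if missing:
        raise ValueError(f"missing required parameters: {missing}")
    unknown = sorted(set(args) - set(props))
    if unknown:
        raise ValueError(f"unknown parameters: {unknown}")

    result = {}
    for name, spec in props.items():
        if name not in args:
            if "default" in spec:
                result[name] = spec["default"]
            continue
        value = args[name]
        # bool is a subclass of int, so reject it unless a boolean is expected
        is_bool_mismatch = isinstance(value, bool) and spec["type"] != "boolean"
        if is_bool_mismatch or not isinstance(value, JSON_TYPES[spec["type"]]):
            raise ValueError(f"{name} must be of type {spec['type']}")
        if "enum" in spec and value not in spec["enum"]:
            raise ValueError(f"{name} must be one of {spec['enum']}")
        if "minimum" in spec and value < spec["minimum"]:
            raise ValueError(f"{name} must be >= {spec['minimum']}")
        if "maximum" in spec and value > spec["maximum"]:
            raise ValueError(f"{name} must be <= {spec['maximum']}")
        result[name] = value
    return result

print(validate_args(WEATHER_SCHEMA, {"location": "Beijing, China", "days": 3}))
```

The error messages name the offending parameter and the allowed values, which gives the model something concrete to correct on its next attempt.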

2. LangChain Tool Registration

2.1 Registering with a Decorator

LangChain provides a concise decorator syntax for registering tools:

from langchain.tools import tool
from typing import Annotated
import requests

@tool
def get_weather(
    location: Annotated[str, "The city name, e.g., 'Beijing' or 'New York'"],
    unit: Annotated[str, "Temperature unit: 'celsius' or 'fahrenheit'"] = "celsius"
) -> str:
    """Get current weather information for a specified location.
    
    Use this tool when the user asks about weather conditions.
    Returns temperature, humidity, and weather description.
    """
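    # Illustrative endpoint; passing values via params keeps them URL-encoded
    response = requests.get(
        "https://api.weather.com/v1/current",
        params={"city": location, "unit": unit},
        timeout=10,
    )
    response.raise_for_status()
    data = response.json()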
    
    return f"Current weather in {location}: {data['temperature']}°{unit.upper()}, {data['description']}"

# Inspect the tool's metadata
print(get_weather.name)
print(get_weather.description)
print(get_weather.args)
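
To build intuition for where name, description, and args come from, the sketch below derives similar metadata from a plain function using only the standard library. It is an approximation for illustration: describe_function is a hypothetical helper, and LangChain's actual implementation builds a Pydantic schema rather than this dictionary.

```python
import inspect
from typing import Annotated, get_args, get_origin, get_type_hints

def describe_function(func) -> dict:
    """Derive tool-like metadata from a function's signature and docstring."""
    hints = get_type_hints(func, include_extras=True)
    params = {}
    for name, param in inspect.signature(func).parameters.items():
        hint = hints.get(name, str)
        description = ""
        if get_origin(hint) is Annotated:
            # Annotated[str, "text"] -> base type plus free-form metadata
            base, *metadata = get_args(hint)
            description = next((m for m in metadata if isinstance(m, str)), "")
        else:
            base = hint
        params[name] = {
            "type": base.__name__,
            "description": description,
            "required": param.default is inspect.Parameter.empty,
        }
    return {
        "name": func.__name__,
        "description": inspect.getdoc(func) or "",
        "args": params,
    }

def get_weather(
    location: Annotated[str, "The city name"],
    unit: Annotated[str, "Temperature unit"] = "celsius",
) -> str:
    """Get current weather information for a specified location."""
    return f"{location}: {unit}"

spec = describe_function(get_weather)
print(spec["name"], spec["args"])
```

The takeaway is that the function name, docstring, and type annotations are the tool's public interface, so they deserve the same care as the implementation.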

2.2 Using the BaseTool Class

For more complex tools, subclass BaseTool:

from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from typing import Any

class SearchInput(BaseModel):
    query: str = Field(description="Search query string")
    max_results: int = Field(default=10, ge=1, le=100, description="Maximum results to return")

class DocumentSearchTool(BaseTool):
    name: str = "semantic_document_search"
    description: str = """
    Search documents using semantic similarity in the vector database.
    Use this for finding relevant information from internal knowledge base.
    """
    args_schema: type[BaseModel] = SearchInput
    
    vector_store: Any = Field(default=None, exclude=True)
    
    def __init__(self, vector_store, **kwargs):
        super().__init__(**kwargs)
        self.vector_store = vector_store
    
    def _run(self, query: str, max_results: int = 10) -> str:
        """Synchronous execution."""
        results = self.vector_store.similarity_search(query, k=max_results)
        return self._format_results(results)
    
    async def _arun(self, query: str, max_results: int = 10) -> str:
        """Asynchronous execution."""
        results = await self.vector_store.asimilarity_search(query, k=max_results)
        return self._format_results(results)
    
    def _format_results(self, results) -> str:
        formatted = []
        for i, doc in enumerate(results, 1):
            formatted.append(f"{i}. {doc.page_content[:200]}... (Score: {doc.metadata.get('score', 'N/A')})")
        return "\n".join(formatted)

# Use the tool
search_tool = DocumentSearchTool(vector_store=your_vector_store)

2.3 Managing a Tool Set

from langchain_core.tools import Tool
from langchain_core.utils.function_calling import convert_to_openai_function

# Build the tool list
tools = [
    Tool(
        name="web_search",
        func=search_engine.run,
        description="Search the internet for current information"
    ),
    get_weather,
    DocumentSearchTool(vector_store=vs)
]

# Convert to the OpenAI function format
functions = [convert_to_openai_function(t) for t in tools]
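
To make the target format concrete, here is a hand-built definition in the shape that OpenAI-style function calling expects. to_openai_function is a hypothetical helper written for this sketch; the framework's conversion utility derives an equivalent structure from the tool's argument schema.

```python
import json

def to_openai_function(name: str, description: str, parameters: dict) -> dict:
    """Build a function definition: name, description, and a JSON Schema object."""
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": parameters,
            # In this sketch, a parameter without a default is treated as required
            "required": [key for key, spec in parameters.items() if "default" not in spec],
        },
    }

weather_function = to_openai_function(
    name="get_weather",
    description="Get current weather information for a specified location.",
    parameters={
        "location": {"type": "string", "description": "The city name"},
        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "default": "celsius"},
    },
)

# Tool-calling APIs wrap the same definition in a "function" tool entry
weather_tool = {"type": "function", "function": weather_function}
print(json.dumps(weather_tool, indent=2))
```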

3. LangGraph Tool Registration

LangGraph offers a more flexible tool-management mechanism that is especially well suited to complex agent workflows.

3.1 Basic Tool Node

from langchain.tools import tool
from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
from typing import TypedDict, Annotated, Sequence
import operator

class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    next_step: str

# Define the tools
@tool
def calculator(expression: str) -> str:
    """Evaluate a mathematical expression safely."""
    try:
        # safe_eval stands in for a safe alternative to eval()
        result = safe_eval(expression)
        return f"Result: {result}"
    except Exception as e:
        return f"Error: {str(e)}"

@tool
def code_executor(code: str, language: str = "python") -> str:
    """Execute code in a sandboxed environment."""
    # Sandbox execution logic goes here
    return sandbox.execute(code, language=language)

# Create the tool node
tools = [calculator, code_executor]
tool_node = ToolNode(tools)

# Build the graph
workflow = StateGraph(AgentState)

# Add the tool node (the "agent" node that calls the model is added the same way)
workflow.add_node("tools", tool_node)

# Conditional edge: decide whether tools need to be called
def should_call_tools(state: AgentState):
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "tools"
    return END

workflow.add_conditional_edges(
    "agent",
    should_call_tools,
    {"tools": "tools", END: END}
)

workflow.add_edge("tools", "agent")
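
The control flow that the graph encodes is a simple loop: call the model, run any requested tools, feed the results back, and stop when no tool is requested. The sketch below reproduces that loop in plain Python with a scripted stand-in for the model; fake_model, TOOLS, and run_agent are hypothetical names, and the message dictionaries only imitate the shape of real message objects.

```python
def fake_model(messages):
    """Stand-in for an LLM: requests the add tool once, then answers."""
    if not any(m["role"] == "tool" for m in messages):
        return {
            "role": "assistant",
            "content": "",
            "tool_calls": [{"id": "call_1", "name": "add", "args": {"a": 2, "b": 3}}],
        }
    return {
        "role": "assistant",
        "content": f"The result is {messages[-1]['content']}.",
        "tool_calls": [],
    }

TOOLS = {"add": lambda a, b: a + b}

def run_agent(question: str, max_steps: int = 5) -> list:
    messages = [{"role": "user", "content": question}]
    for _ in range(max_steps):
        reply = fake_model(messages)          # the "agent" node
        messages.append(reply)
        if not reply["tool_calls"]:           # conditional edge -> END
            return messages
        for call in reply["tool_calls"]:      # the "tools" node
            result = TOOLS[call["name"]](**call["args"])
            messages.append(
                {"role": "tool", "tool_call_id": call["id"], "content": str(result)}
            )
        # loop back: the edge from "tools" to "agent"
    raise RuntimeError("agent did not finish within max_steps")

history = run_agent("What is 2 + 3?")
print([m["role"] for m in history])
print(history[-1]["content"])
```

The max_steps guard is a design choice worth keeping in any real loop, since a model that keeps requesting tools would otherwise never terminate.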

3.2 Custom Tool Node

For tools that need special handling, you can write a custom node:

import logging

from langchain_core.messages import ToolMessage

class CustomToolNode:
    """Custom tool node with logging and error handling."""
    
    def __init__(self, tools: list, logger=None):
        self.tools_by_name = {tool.name: tool for tool in tools}
        self.logger = logger or logging.getLogger(__name__)
    
    def __call__(self, state: AgentState):
        outputs = []
        
        for tool_call in state["messages"][-1].tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool_call_id = tool_call["id"]
            
            self.logger.info(f"Executing tool: {tool_name} with args: {tool_args}")
            
            try:
                tool = self.tools_by_name.get(tool_name)
                if not tool:
                    raise ValueError(f"Tool '{tool_name}' not found")
                
                # Check permissions before execution
                if not self._check_permissions(tool_name, state):
                    raise PermissionError(f"No permission to use tool: {tool_name}")
                
                # Run the tool
                observation = tool.invoke(tool_args)
                
                # Record the result
                self.logger.info(f"Tool {tool_name} completed successfully")
                
                outputs.append(
                    ToolMessage(
                        content=str(observation),
                        name=tool_name,
                        tool_call_id=tool_call_id
                    )
                )
                
            except Exception as e:
                self.logger.error(f"Tool {tool_name} failed: {str(e)}")
                outputs.append(
                    ToolMessage(
                        content=f"Error: {str(e)}",
                        name=tool_name,
                        tool_call_id=tool_call_id
                    )
                )
        
        return {"messages": outputs}
    
    def _check_permissions(self, tool_name: str, state: AgentState) -> bool:
        # Permission check: an empty or missing allowed_tools list permits every tool
        allowed_tools = state.get("allowed_tools", [])
        return tool_name in allowed_tools or not allowed_tools

3.3 Composing Tool Chains

from langgraph.graph import MessageGraph

# Build a tool chain that runs sequentially
def create_tool_chain(tools: list):
    """Create a sequential tool execution chain."""
    builder = MessageGraph()
    
    # Add each tool as a node
    for i, tool in enumerate(tools):
        node_name = f"tool_{i}"
        builder.add_node(node_name, tool)
        
        if i > 0:
            builder.add_edge(f"tool_{i-1}", node_name)
    
    builder.set_entry_point("tool_0")
    builder.set_finish_point(f"tool_{len(tools)-1}")
    
    return builder.compile()

# Usage example
tool_chain = create_tool_chain([fetch_data_tool, process_data_tool, save_result_tool])
result = tool_chain.invoke("initial_input")

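4. The MCP Tool Standard

Model Context Protocol (MCP) is an open standard introduced by Anthropic that aims to standardize how AI models interact with external tools.

4.1 MCP Architecture Overview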

┌─────────────────────────────────────────────────────────┐
│                    AI Application                        │
│                   (Host/Client)                          │
└────────────────────┬────────────────────────────────────┘
                     │ MCP Protocol
                     ▼
┌─────────────────────────────────────────────────────────┐
│                   MCP Server                             │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────┐  │
│  │   Tools     │  │  Resources  │  │    Prompts      │  │
│  │  (Functions)│  │  (Context)  │  │  (Templates)    │  │
│  └─────────────┘  └─────────────┘  └─────────────────┘  │
└─────────────────────────────────────────────────────────┘

4.2 Implementing an MCP Server

from mcp.server import Server
from mcp.types import Tool, TextContent
import asyncio

# Create the MCP server
app = Server("my-ai-tools")

# Declare the available tools
@app.list_tools()
async def list_tools() -> list[Tool]:
    return [
        Tool(
            name="file_reader",
            description="Read contents of a file",
            inputSchema={
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "Absolute path to the file"
                    },
                    "encoding": {
                        "type": "string",
                        "enum": ["utf-8", "ascii", "latin-1"],
                        "default": "utf-8"
                    }
                },
                "required": ["path"]
            }
        ),
        Tool(
            name="database_query",
            description="Execute a read-only SQL query",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query"
                    },
                    "limit": {
                        "type": "integer",
                        "maximum": 1000,
                        "default": 100
                    }
                },
                "required": ["query"]
            }
        )
    ]

# Handle tool calls
@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    if name == "file_reader":
        content = await read_file_safe(
            arguments["path"],
            encoding=arguments.get("encoding", "utf-8")
        )
        return [TextContent(type="text", text=content)]
    
    elif name == "database_query":
        results = await execute_readonly_query(
            arguments["query"],
            limit=arguments.get("limit", 100)
        )
        return [TextContent(type="text", text=format_results(results))]
    
    else:
        raise ValueError(f"Unknown tool: {name}")

# Start the server
async def main():
    from mcp.server.stdio import stdio_server
    
    async with stdio_server() as (read_stream, write_stream):
        await app.run(
            read_stream,
            write_stream,
            app.create_initialization_options()
        )

if __name__ == "__main__":
    asyncio.run(main())
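
Underneath the SDK, MCP messages are JSON-RPC 2.0, and the two handlers above answer the tools/list and tools/call methods. The sketch below builds such messages by hand to show roughly what crosses the wire. It is simplified: jsonrpc_request is a hypothetical helper, the file path is made up, and a real session starts with an initialize handshake that is omitted here.

```python
import json
from itertools import count
from typing import Optional

_request_ids = count(1)

def jsonrpc_request(method: str, params: Optional[dict] = None) -> dict:
    """Build a JSON-RPC 2.0 request with a fresh id."""
    message = {"jsonrpc": "2.0", "id": next(_request_ids), "method": method}
    if params is not None:
        message["params"] = params
    return message

# Discovery: ask the server which tools it offers
list_request = jsonrpc_request("tools/list")

# Invocation: call one tool by name with its arguments
call_request = jsonrpc_request(
    "tools/call",
    {"name": "file_reader", "arguments": {"path": "/tmp/notes.txt", "encoding": "utf-8"}},
)

# A successful result carries a list of content items, mirroring TextContent
example_response = {
    "jsonrpc": "2.0",
    "id": call_request["id"],
    "result": {"content": [{"type": "text", "text": "file contents here"}], "isError": False},
}

for message in (list_request, call_request, example_response):
    print(json.dumps(message))
```

Seeing the raw messages makes it clearer why any client that speaks the protocol can use the server, regardless of the agent framework on top.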

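4.3 MCP Client Integration

from contextlib import AsyncExitStack
from typing import Optional

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

class MCPToolManager:
    """Manage MCP tool connections and execution."""
    
    def __init__(self):
        self.sessions: dict[str, ClientSession] = {}
        self.tools_cache: dict[str, list] = {}
        # Keeps transports and sessions open until close() is called
        self.exit_stack = AsyncExitStack()
    
    async def connect_server(self, server_name: str, command: str, args: Optional[list] = None):
        """Connect to an MCP server."""
        server_params = StdioServerParameters(
            command=command,
            args=args or [],
            env=None
        )
        
        # Enter both contexts on the exit stack so the connection outlives this method
        read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
        session = await self.exit_stack.enter_async_context(ClientSession(read, write))
        await session.initialize()
        
        self.sessions[server_name] = session
        # list_tools() returns a result object; the tool list is in .tools
        self.tools_cache[server_name] = (await session.list_tools()).tools
    
    async def close(self):
        """Close every session and transport opened by connect_server."""
        await self.exit_stack.aclose()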
    
    async def execute_tool(self, server_name: str, tool_name: str, arguments: dict):
        """Execute a tool on a specific server."""
        session = self.sessions.get(server_name)
        if not session:
            raise ValueError(f"Server '{server_name}' not connected")
        
        result = await session.call_tool(tool_name, arguments)
        return result
    
    def get_all_tools(self) -> list[dict]:
        """Get all available tools from all connected servers."""
        all_tools = []
        for server_name, tools in self.tools_cache.items():
            for tool in tools:
                all_tools.append({
                    "server": server_name,
                    "name": tool.name,
                    "description": tool.description,
                    "schema": tool.inputSchema
                })
        return all_tools

# Usage example
async def setup_mcp_tools():
    manager = MCPToolManager()
    
    # Connect the filesystem MCP server
    await manager.connect_server(
        "filesystem",
        "python",
        ["-m", "mcp_server_filesystem", "/home/user/documents"]
    )
    
    # Connect the database MCP server
    await manager.connect_server(
        "sqlite",
        "python",
        ["-m", "mcp_server_sqlite", "--db-path", "/path/to/db.sqlite"]
    )
    
    # Collect every available tool
    tools = manager.get_all_tools()
    print(f"Available MCP tools: {len(tools)}")
    
    return manager
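
When several servers are connected, two of them may expose a tool with the same name, which conflicts with the uniqueness rule from section 1.1. One possible approach, sketched here under assumptions, is to give the model server-qualified names and keep a routing table back to the original pair. The double-underscore separator and the tool names are illustrative choices, not part of MCP.

```python
def qualify(server_name: str, tool_name: str) -> str:
    """Build a model-facing name that is unique across servers."""
    return f"{server_name}__{tool_name}"

def build_routing_table(tools_by_server: dict) -> dict:
    """Map each qualified name back to its (server, tool) pair."""
    table = {}
    for server_name, tool_names in tools_by_server.items():
        for tool_name in tool_names:
            table[qualify(server_name, tool_name)] = (server_name, tool_name)
    return table

# Hypothetical tool names; both servers happen to expose "search"
routes = build_routing_table({
    "filesystem": ["read_file", "search"],
    "sqlite": ["read_query", "search"],
})

# When the model calls "sqlite__search", route it to the right server
server_name, tool_name = routes["sqlite__search"]
print(server_name, tool_name)
```

The resolved pair can then be passed to a method such as execute_tool(server_name, tool_name, arguments).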

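4.4 MCP vs. Traditional Tools

Feature | Traditional tools (LangChain) | MCP standard
Protocol | Framework-specific | Standardized protocol
Discovery | Registered in code | Dynamic discovery
Transport | In-process / HTTP | stdio / SSE
Security | Controlled at the application layer | Built-in permissions
Ecosystem | Bound to one framework | Cross-framework compatible
Resources | Tools only | Tools + resources + prompts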

5. Tool Security and Permission Control

5.1 Input Validation and Sanitization

import os
import re

from pydantic import BaseModel, Field, field_validator

class SafeSearchInput(BaseModel):
    query: str = Field(..., max_length=500)
    max_results: int = Field(default=10, ge=1, le=100)
    
    @field_validator('query')
    @classmethod
    def sanitize_query(cls, v: str) -> str:
        # Strip potentially dangerous characters
        v = re.sub(r'[<>\"\']', '', v)
        # Reject common command-injection patterns
        dangerous_patterns = [
            r';\s*rm\s+-rf',
            r'\|\s*sh',
            r'`.*?`',
            r'\$\(.*?\)'
        ]
        for pattern in dangerous_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError("Potentially dangerous pattern detected")
        return v.strip()

# Directory that file tools are allowed to access
BASE_DIR = os.path.abspath('/allowed/path')

class FileOperationInput(BaseModel):
    path: str = Field(...)
    
    @field_validator('path')
    @classmethod
    def validate_path(cls, v: str) -> str:
        # Reject path traversal attempts early
        if '..' in v or '~' in v:
            raise ValueError("Path traversal attempt detected")
        
        # Resolve and normalize the path
        full_path = os.path.abspath(os.path.join(BASE_DIR, v))
        
        # Compare whole path components; a plain startswith() check would
        # also accept sibling directories such as /allowed/path_evil
        if os.path.commonpath([BASE_DIR, full_path]) != BASE_DIR:
            raise ValueError("Access denied: path outside allowed directory")
        
        return full_path
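
Directory containment is easy to get subtly wrong, because comparing paths as strings treats /allowed/path_evil as if it were inside /allowed/path. The sketch below shows a component-wise check with pathlib; resolve_inside is a hypothetical helper, and the paths are examples on a POSIX-style filesystem.

```python
from pathlib import Path

def resolve_inside(base_dir: str, user_path: str) -> Path:
    """Resolve user_path under base_dir, rejecting anything that escapes it."""
    base = Path(base_dir).resolve()
    candidate = (base / user_path).resolve()
    # Compare path components, not raw string prefixes
    if candidate != base and base not in candidate.parents:
        raise ValueError(f"Access denied: {user_path!r} is outside {base}")
    return candidate

base = "/allowed/path"
print(resolve_inside(base, "reports/q4.txt"))

for attempt in ["../path_evil/secret.txt", "/etc/passwd"]:
    try:
        resolve_inside(base, attempt)
    except ValueError as exc:
        print(exc)
```

Because resolve() also follows symbolic links, a link inside the allowed directory that points elsewhere is rejected as well, which a purely lexical check would miss.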

5.2 Permission Control System

import os
from enum import Enum
from functools import wraps
from typing import Callable

class PermissionLevel(Enum):
    READ_ONLY = "read_only"
    STANDARD = "standard"
    PRIVILEGED = "privileged"
    ADMIN = "admin"

class ToolPermission:
    """Tool permission decorator."""
    
    def __init__(self, required_level: PermissionLevel):
        self.required_level = required_level
        self.levels = list(PermissionLevel)
    
    def __call__(self, func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Read the current user's permission level from the call context
            current_level = kwargs.get('_user_permission', PermissionLevel.READ_ONLY)
            
            if self.levels.index(current_level) < self.levels.index(self.required_level):
                raise PermissionError(
                    f"Tool '{func.__name__}' requires {self.required_level.value} permission, "
                    f"but user has {current_level.value}"
                )
            
            return func(*args, **kwargs)
        
        return wrapper

# Usage example
class SecureTools:
    
    @ToolPermission(PermissionLevel.READ_ONLY)
    def search_documents(self, query: str, **kwargs):
        """Search documents - available to all users."""
        return document_store.search(query)
    
    @ToolPermission(PermissionLevel.STANDARD)
    def send_notification(self, message: str, channel: str, **kwargs):
        """Send notifications - requires standard permission."""
        return notification_service.send(channel, message)
    
    @ToolPermission(PermissionLevel.PRIVILEGED)
    def modify_database(self, operation: str, data: dict, **kwargs):
        """Modify database - requires privileged permission."""
        return db.execute(operation, data)
    
    @ToolPermission(PermissionLevel.ADMIN)
    def system_command(self, command: str, **kwargs):
        """Execute system commands - admin only."""
        return os.system(command)

5.3 Execution Sandbox

import subprocess
import tempfile
import os
from pathlib import Path

class SandboxExecutor:
    """Execute code in an isolated sandbox environment."""
    
    def __init__(self, timeout: int = 30, memory_limit: str = "512m"):
        self.timeout = timeout
        self.memory_limit = memory_limit
        self.allowed_imports = {
            'math', 'random', 'datetime', 'json', 're', 
            'collections', 'itertools', 'statistics'
        }
    
    def execute_python(self, code: str) -> dict:
        """Execute Python code in sandbox."""
        # Static code analysis
        self._validate_code(code)
        
        with tempfile.TemporaryDirectory() as tmpdir:
            # Write the code to a file
            code_file = Path(tmpdir) / "script.py"
            code_file.write_text(self._wrap_code(code))
            
            # Run in Docker (or another restricted environment) with no network access
            try:
                result = subprocess.run(
                    [
                        "docker", "run", "--rm",
                        "-v", f"{tmpdir}:/code:ro",
                        "--network", "none",
                        "--memory", self.memory_limit,
                        "--cpus", "1.0",
                        "python:3.11-slim",
                        "python", "/code/script.py"
                    ],
                    capture_output=True,
                    text=True,
                    timeout=self.timeout
                )
                
                return {
                    "success": result.returncode == 0,
                    "stdout": result.stdout,
                    "stderr": result.stderr,
                    "returncode": result.returncode
                }
                
            except subprocess.TimeoutExpired:
                return {
                    "success": False,
                    "error": f"Execution timed out after {self.timeout} seconds"
                }
    
    def _validate_code(self, code: str):
        """Validate code for dangerous patterns."""
        import ast
        
        try:
            tree = ast.parse(code)
        except SyntaxError as e:
            raise ValueError(f"Invalid Python syntax: {e}")
        
        for node in ast.walk(tree):
            # Allow imports only from the allow-list, for both
            # `import x` and `from x import y`
            if isinstance(node, ast.Import):
                modules = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom):
                modules = [node.module or ""]
            else:
                modules = []
            for module in modules:
                if module.split(".")[0] not in self.allowed_imports:
                    raise ValueError(f"Import '{module}' not allowed")
            
            if isinstance(node, ast.Call):
                # Block exec/eval and the built-in open()
                if isinstance(node.func, ast.Name):
                    if node.func.id in ('exec', 'eval', '__import__', 'open'):
                        raise ValueError(f"Function '{node.func.id}' not allowed")
                
                # Block file operations invoked as methods
                if isinstance(node.func, ast.Attribute):
                    if node.func.attr in ('open', 'read', 'write'):
                        raise ValueError("File operations not allowed")
    
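    def _wrap_code(self, code: str) -> str:
        """Wrap code before execution.

        sys.path is left untouched: clearing it would also break the
        allow-listed standard-library imports. Isolation comes from the container.
        """
        return f"{code}\n"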

5.4 Audit Logging

import json
from datetime import datetime, timezone
from typing import Optional
import hashlib

class ToolAuditLogger:
    """Comprehensive audit logging for tool execution."""
    
    def __init__(self, log_file: str = "tool_audit.log"):
        self.log_file = log_file
    
    def log_execution(
        self,
        tool_name: str,
        user_id: str,
        session_id: str,
        inputs: dict,
        outputs: Optional[str],
        duration_ms: float,
        success: bool,
        error: Optional[Exception] = None
    ):
        """Log a tool execution event."""
        # Mask sensitive input fields
        sanitized_inputs = self._sanitize_inputs(inputs)
        
        log_entry = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "event_type": "tool_execution",
            "tool_name": tool_name,
            "user_id": hashlib.sha256(user_id.encode()).hexdigest()[:16],
            "session_id": session_id,
            "inputs_hash": hashlib.sha256(json.dumps(inputs, sort_keys=True).encode()).hexdigest(),
            "inputs": sanitized_inputs,
            "output_preview": outputs[:500] if outputs else None,
            "duration_ms": duration_ms,
            "success": success,
            "error_type": type(error).__name__ if error else None
        }
        
        with open(self.log_file, "a") as f:
            f.write(json.dumps(log_entry) + "\n")
    
    def _sanitize_inputs(self, inputs: dict) -> dict:
        """Remove or mask sensitive fields."""
        sensitive_keys = {'password', 'token', 'secret', 'key', 'api_key', 'credential'}
        sanitized = {}
        
        for key, value in inputs.items():
            if any(s in key.lower() for s in sensitive_keys):
                sanitized[key] = "***REDACTED***"
            elif isinstance(value, dict):
                sanitized[key] = self._sanitize_inputs(value)
            else:
                sanitized[key] = value
        
        return sanitized

# Usage example
audit_logger = ToolAuditLogger()

@tool
def secure_api_call(endpoint: str, api_key: str, data: dict) -> str:
    """Make secure API call with full audit logging."""
    import time
    
    start_time = time.time()
    user_id = get_current_user_id()
    session_id = get_current_session()
    
    try:
        result = make_api_request(endpoint, api_key, data)
        
        audit_logger.log_execution(
            tool_name="secure_api_call",
            user_id=user_id,
            session_id=session_id,
            inputs={"endpoint": endpoint, "data": data},
            outputs=result,
            duration_ms=(time.time() - start_time) * 1000,
            success=True
        )
        
        return result
        
    except Exception as e:
        audit_logger.log_execution(
            tool_name="secure_api_call",
            user_id=user_id,
            session_id=session_id,
            inputs={"endpoint": endpoint, "data": data},
            outputs=None,
            duration_ms=(time.time() - start_time) * 1000,
            success=False,
            error=e
        )
        raise
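
The try/except bookkeeping above has to be repeated in every tool. A decorator can factor it out, as in this minimal sketch: audited is a hypothetical name, and the log here is an in-memory list standing in for a real audit sink such as the logger class.

```python
import functools
import time

def audited(log: list):
    """Decorator factory: record one audit entry per call, success or failure."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            entry = {"tool_name": func.__name__, "success": False, "error_type": None}
            try:
                result = func(*args, **kwargs)
                entry["success"] = True
                return result
            except Exception as exc:
                entry["error_type"] = type(exc).__name__
                raise
            finally:
                # Runs on both paths, so every call leaves a record
                entry["duration_ms"] = (time.perf_counter() - start) * 1000
                log.append(entry)
        return wrapper
    return decorator

audit_entries = []

@audited(audit_entries)
def divide(a: float, b: float) -> float:
    """Divide a by b."""
    return a / b

divide(4, 2)
try:
    divide(1, 0)
except ZeroDivisionError:
    pass

for entry in audit_entries:
    print(entry["tool_name"], entry["success"], entry["error_type"])
```

Using finally guarantees that a failed call is logged before the exception propagates, so the audit trail has no gaps.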

6. Complete Example: An Enterprise-Grade Tool System

# enterprise_tools.py
from langchain.tools import tool, BaseTool
from langgraph.graph import StateGraph, END
from pydantic import BaseModel, Field
from typing import Annotated, Optional, Literal, TypedDict
import asyncio
import logging
import operator

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ==================== Input models ====================

class SearchInput(BaseModel):
    query: str = Field(description="Search query", max_length=500)
    filters: Optional[dict] = Field(default=None, description="Metadata filters")
    top_k: int = Field(default=5, ge=1, le=20, description="Number of results")

class DatabaseQueryInput(BaseModel):
    query: str = Field(description="SQL query (SELECT only)", max_length=1000)
    params: Optional[dict] = Field(default=None, description="Query parameters")

class NotificationInput(BaseModel):
    channel: Literal["email", "slack", "teams"] = Field(description="Notification channel")
    recipient: str = Field(description="Recipient address/user")
    message: str = Field(description="Message content", max_length=2000)
    priority: Literal["low", "normal", "high"] = Field(default="normal")

# ==================== Tool implementations ====================

class EnterpriseToolRegistry:
    """Enterprise-grade tool registry with security and audit features."""
    
    def __init__(self):
        self.tools = {}
        self.permissions = {}
        self.audit_log = []
    
    def register(self, tool_func, permission_level="standard"):
        """Register a tool with permission level."""
        self.tools[tool_func.name] = tool_func
        self.permissions[tool_func.name] = permission_level
        logger.info(f"Registered tool: {tool_func.name} ({permission_level})")
        return tool_func
    
    def get_tools(self, user_permission: str = "standard"):
        """Get tools accessible to user permission level."""
        permission_hierarchy = ["read_only", "standard", "privileged", "admin"]
        user_level = permission_hierarchy.index(user_permission)
        
        accessible = []
        for name, tool in self.tools.items():
            tool_level = permission_hierarchy.index(self.permissions[name])
            if tool_level <= user_level:
                accessible.append(tool)
        
        return accessible

# Initialize the registry
registry = EnterpriseToolRegistry()

@tool(args_schema=SearchInput)
def enterprise_search(
    query: str,
    filters: Optional[dict] = None,
    top_k: int = 5
) -> str:
    """
    Search enterprise knowledge base using semantic similarity.
    
    Use this tool to find relevant documents, policies, and information
    from the company's internal knowledge repository.
    """
    logger.info(f"Executing search: {query}")
    
    # Simulated search implementation
    results = [
        {"title": f"Result {i}", "content": f"Content for {query}", "score": 0.95 - i*0.05}
        for i in range(top_k)
    ]
    
    formatted = "\n\n".join([
        f"{i+1}. {r['title']} (relevance: {r['score']:.2f})\n{r['content'][:200]}"
        for i, r in enumerate(results)
    ])
    
    return f"Found {len(results)} results:\n\n{formatted}"

@tool(args_schema=DatabaseQueryInput)
def readonly_database_query(
    query: str,
    params: Optional[dict] = None
) -> str:
    """
    Execute read-only database queries for analytics.
    
    Use this to retrieve business metrics, user statistics, and
    other analytical data. Only SELECT queries are allowed.
    """
    # Enforce read-only access
    if not query.strip().upper().startswith("SELECT"):
        raise ValueError("Only SELECT queries are permitted")
    
    logger.info(f"Executing query: {query[:50]}...")
    
    # Simulated query result
    return f"Query executed successfully. Returned 42 rows.\nColumns: id, name, value, timestamp"

@tool(args_schema=NotificationInput)
def send_notification(
    channel: str,
    recipient: str,
    message: str,
    priority: str = "normal"
) -> str:
    """
    Send notifications via email, Slack, or Microsoft Teams.
    
    Use this to alert users, send reports, or notify teams about
    important events. Respect user notification preferences.
    """
    logger.info(f"Sending {priority} {channel} to {recipient}")
    
    # Simulated send
    return f"Notification sent via {channel} to {recipient} with {priority} priority"

# Register the tools
registry.register(enterprise_search, "read_only")
registry.register(readonly_database_query, "standard")
registry.register(send_notification, "privileged")

# ==================== LangGraph integration ====================

from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_openai import ChatOpenAI

class AgentState(TypedDict):
    messages: Annotated[list, operator.add]
    user_permission: str
    session_id: str

def create_enterprise_agent(user_permission: str = "standard"):
    """Create an enterprise agent with permission-aware tools."""
    
    # Get the tools this user may access
    tools = registry.get_tools(user_permission)
    allowed_tools = {t.name: t for t in tools}
    
    # Bind the tools to the model
    model = ChatOpenAI(model="gpt-4-turbo-preview").bind_tools(tools)
    
    # Create the graph
    workflow = StateGraph(AgentState)
    
    # Agent node
    def agent_node(state: AgentState):
        messages = state["messages"]
        response = model.invoke(messages)
        return {"messages": [response]}
    
    # Tool node
    def tool_node(state: AgentState):
        last_message = state["messages"][-1]
        tool_outputs = []
        
        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            
            # Look up the tool among the permitted ones only, then run it
            tool_func = allowed_tools.get(tool_name)
            if tool_func:
                try:
                    output = tool_func.invoke(tool_args)
                except Exception as e:
                    output = f"Error: {str(e)}"
            else:
                output = f"Tool {tool_name} not found or not permitted"
            
            tool_outputs.append(
                ToolMessage(
                    content=str(output),
                    name=tool_name,
                    tool_call_id=tool_call["id"]
                )
            )
        
        return {"messages": tool_outputs}
    
    # Conditional edge
    def should_continue(state: AgentState):
        last_message = state["messages"][-1]
        if last_message.tool_calls:
            return "tools"
        return END
    
    # Build the graph
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)
    
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"tools": "tools", END: END}
    )
    workflow.add_edge("tools", "agent")
    
    return workflow.compile()

# ==================== Usage example ====================

async def main():
    """Demonstrate enterprise agent usage."""
    
    # Create agents with different permission levels
    print("=== Read-only User ===")
    read_only_agent = create_enterprise_agent("read_only")
    
    result = await read_only_agent.ainvoke({
        "messages": [HumanMessage(content="Search for Q4 financial reports")],
        "user_permission": "read_only",
        "session_id": "session_001"
    })
    print(result["messages"][-1].content)
    
    print("\n=== Privileged User ===")
    privileged_agent = create_enterprise_agent("privileged")
    
    result = await privileged_agent.ainvoke({
        "messages": [HumanMessage(content="Send a Slack notification to the team about the deployment")],
        "user_permission": "privileged",
        "session_id": "session_002"
    })
    print(result["messages"][-1].content)

if __name__ == "__main__":
    asyncio.run(main())

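7. Best Practices Summary

Tip

Golden rules of tool design

  1. Single responsibility: each tool does one thing and does it well
  2. Self-describing: the name and description are clear enough that no extra documentation is needed
  3. Defensive programming: validate every input and assume every input is malicious
  4. Least privilege: request only the permissions you need and use tiered permission levels
  5. Observability: log every call to support debugging and auditing
  6. Graceful failure: return clear error messages that help the LLM correct itself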
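
The last rule, graceful failure, is the one most often skipped, so here is a small sketch of what it can look like. run_tool_safely and get_forecast are hypothetical names; the idea is to turn exceptions into messages that tell the model what to change rather than letting a stack trace end the run.

```python
def run_tool_safely(func, **kwargs) -> str:
    """Run a tool and convert failures into feedback the LLM can act on."""
    try:
        return str(func(**kwargs))
    except TypeError as exc:
        # Usually a wrong or missing parameter name
        return f"Error: invalid arguments ({exc}). Check parameter names and types, then retry."
    except ValueError as exc:
        return f"Error: {exc}. Adjust the input and retry."
    except Exception as exc:
        return f"Error: {type(exc).__name__}: {exc}"

def get_forecast(days: int) -> str:
    """Return a placeholder forecast for the next `days` days."""
    if not 1 <= days <= 14:
        raise ValueError(f"days must be between 1 and 14, got {days}")
    return f"Forecast for {days} day(s)"

print(run_tool_safely(get_forecast, days=3))
print(run_tool_safely(get_forecast, days=30))
print(run_tool_safely(get_forecast, day=3))
```

Each error string states the constraint that was violated, which is usually enough for the model to repair the call on the next turn.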

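Quick checklist

  • Tool names start with a verb and use snake_case
  • Descriptions cover function, input, output, and usage scenario
  • Parameters use Pydantic models with validation rules
  • Sensitive operations enforce permission checks
  • Every tool call is recorded in the audit log
  • User input is sanitized and validated
  • Dangerous operations run in a sandbox
  • Error handling returns actionable feedback

Conclusion

Tool design and registration are central to building production-grade AI agents. By following the design principles in this article, using the flexible mechanisms of LangChain and LangGraph, adopting open standards such as MCP, and enforcing strict security controls, you can build agent systems that are both capable and reliable.

As the AI agent ecosystem evolves rapidly, tool standardization and security will matter more and more. Keep an eye on emerging standards such as MCP, and put solid permission and audit mechanisms in place early in a project.


A complete version of this article's example code is available on GitHub, including more hands-on cases and test cases.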