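AI Agent Tool Design and Registration: From Principles to Practice
When building an AI agent, tools are the bridge between the large language model (LLM) and the outside world. A well-designed tool system extends what the agent can do while keeping the system secure and maintainable. This article covers the core principles of tool design, the registration mechanisms of mainstream frameworks, and emerging industry standards.
1. Tool Design Principles
1.1 Naming conventions
The tool name is the first thing an LLM reads when working out what a tool does, so a good name makes it much easier for the model to pick the right tool.
Naming best practices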
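| Principle | Description | Example |
|---|---|---|
| Verb first | Start with a verb that describes the action | search_documents, send_email |
| snake_case | Use lowercase words separated by underscores | get_weather_data ✓ vs getWeatherData ✗ |
| Specific | Avoid vague words; describe the function precisely | calculate_mortgage_payment vs do_math |
| No collisions | Make sure each name is unique within the tool set | Distinguish search_web from search_database |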
Tip
Avoid overly generic names such as process, handle, or do. They give the model too little to go on when choosing between tools.
Naming examples compared
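The naming rules above can be checked mechanically before a tool is registered. The following is a minimal sketch, not part of any framework: the function name `check_tool_name`, the small verb list, and the set of generic names are assumptions chosen for illustration, and a real project would likely tune them.

```python
import re

# Hypothetical rule data for illustration; extend to match your own conventions.
GENERIC_NAMES = {"process", "handle", "do"}
KNOWN_VERBS = {"get", "search", "send", "calculate", "create", "update", "delete", "list"}
SNAKE_CASE = re.compile(r"^[a-z][a-z0-9]*(_[a-z0-9]+)*$")


def check_tool_name(name: str, existing: set) -> list:
    """Return a list of naming problems; an empty list means the name passes."""
    problems = []
    if not SNAKE_CASE.fullmatch(name):
        problems.append("not snake_case")
    if name in GENERIC_NAMES:
        problems.append("too generic")
    if name.split("_")[0] not in KNOWN_VERBS:
        problems.append("does not start with a known verb")
    if name in existing:
        problems.append("duplicate name")
    return problems


if __name__ == "__main__":
    registered = {"search_web"}
    for candidate in ["get_current_weather", "getData", "process", "search_web"]:
        print(candidate, "->", check_tool_name(candidate, registered) or "ok")
```

Running such a check in a registration hook or a unit test catches naming drift early, before the model ever sees an ambiguous tool list.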
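# ❌ Poor names
"tool1"          # meaningless
"process"        # too generic
"getData"        # camelCase, breaks the convention
"weather-stuff"  # vague and non-standard
# ✅ Good names
"get_current_weather"          # clear and conventional
"search_semantic_documents"    # describes the function specifically
"send_slack_notification"      # action + target + object
"calculate_currency_exchange"  # calculation + business domain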
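1.2 Writing descriptions
The tool description is what the LLM relies on to decide whether to call the tool. A description should:
- State the function: what does the tool do?
- State the inputs: which parameters does it need?
- State the outputs: what does it return?
- State the scenario: when should it be used, and when not?
# Example: a good tool description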
tool_description = """
Search for documents in the vector database using semantic similarity.
Use this tool when you need to find relevant information from the company's
knowledge base based on a natural language query. Do NOT use this for real-time
information like current weather or stock prices.
Args:
    query: The search query in natural language
    top_k: Number of results to return (default: 5)
    filter: Optional metadata filter (e.g., {"department": "engineering"})
Returns:
    List of documents with content and relevance scores
"""
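1.3 Designing parameters
Parameter design directly determines whether the LLM can call the tool correctly. Follow the JSON Schema specification, and keep in mind what the LLM can realistically infer from names and descriptions.
Parameter design principles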
from typing import Literal, Optional

from pydantic import BaseModel, Field


class WeatherInput(BaseModel):
    """Input schema for weather tool."""

    location: str = Field(
        description="The city and state/country, e.g., 'San Francisco, CA' or 'Paris, France'",
        examples=["Beijing, China", "New York, NY"],
    )
    unit: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="Temperature unit to use for the forecast",
    )
    days: int = Field(
        default=1,
        ge=1,
        le=14,
        description="Number of days to forecast (1-14)",
    )
    include_humidity: Optional[bool] = Field(
        default=True,
        description="Whether to include humidity data in the response",
    )
Parameter type best practices
| Type | Use for | Example |
|---|---|---|
| string | Text input, enumerated values | "search_query", "user_email" |
| integer | Counts, IDs, quantities | top_k, page_number |
| number | Floating-point values | temperature, confidence_score |
| boolean | On/off options | include_metadata, async_mode |
| array | List input | tags, document_ids |
| object | Structured data | filters, config |
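Tip
For complex object parameters, provide an examples field that shows the expected input shape. Concrete examples make it much more likely that the LLM constructs the arguments correctly.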
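To make the tip about examples concrete, here is a small sketch of a JSON Schema fragment for an object parameter that carries an `examples` annotation. The property names (`department`, `year`) and the helper `examples_match_schema` are assumptions made up for illustration; the helper is only a cheap sanity check, not a full schema validator.

```python
import json

# Hypothetical schema for a "filter" object parameter of a search tool.
filter_param_schema = {
    "type": "object",
    "description": "Optional metadata filter applied before ranking.",
    "properties": {
        "department": {"type": "string"},
        "year": {"type": "integer", "minimum": 2000},
    },
    "additionalProperties": False,
    # Concrete examples show the model the exact shape it should produce.
    "examples": [
        {"department": "engineering"},
        {"department": "finance", "year": 2024},
    ],
}


def examples_match_schema(schema: dict) -> bool:
    """Check that every example only uses property names the schema declares."""
    allowed = set(schema["properties"])
    return all(set(example) <= allowed for example in schema["examples"])


if __name__ == "__main__":
    print(json.dumps(filter_param_schema, indent=2))
    print("examples consistent:", examples_match_schema(filter_param_schema))
```

Keeping the examples in the schema itself, rather than only in prose, means they travel with the tool definition wherever it is serialized.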
2. Registering Tools in LangChain
2.1 Registering with the decorator
LangChain provides a concise decorator syntax for registering tools:
from typing import Annotated

import requests
from langchain.tools import tool


@tool
def get_weather(
    location: Annotated[str, "The city name, e.g., 'Beijing' or 'New York'"],
    unit: Annotated[str, "Temperature unit: 'celsius' or 'fahrenheit'"] = "celsius",
) -> str:
    """Get current weather information for a specified location.

    Use this tool when the user asks about weather conditions.
    Returns temperature, humidity, and weather description.
    """
    # Placeholder endpoint and response fields, shown for illustration only.
    response = requests.get(
        "https://api.weather.com/v1/current",
        params={"city": location, "unit": unit},  # let requests URL-encode the values
        timeout=10,  # never block the agent loop on a hung connection
    )
    response.raise_for_status()
    data = response.json()
    symbol = "C" if unit == "celsius" else "F"
    return f"Current weather in {location}: {data['temperature']}°{symbol}, {data['description']}"


# Inspect the generated tool metadata
print(get_weather.name)
print(get_weather.description)
print(get_weather.args)
2.2 Subclassing BaseTool
For more complex tools, subclass BaseTool:
from typing import Any

from langchain.tools import BaseTool
from pydantic import BaseModel, Field


class SearchInput(BaseModel):
    query: str = Field(description="Search query string")
    max_results: int = Field(default=10, ge=1, le=100, description="Maximum results to return")


class DocumentSearchTool(BaseTool):
    name: str = "semantic_document_search"
    description: str = """
    Search documents using semantic similarity in the vector database.
    Use this for finding relevant information from internal knowledge base.
    """
    args_schema: type[BaseModel] = SearchInput
    # Injected dependency, passed as a keyword argument and excluded from serialization
    vector_store: Any = Field(default=None, exclude=True)

    def _run(self, query: str, max_results: int = 10) -> str:
        """Synchronous execution."""
        results = self.vector_store.similarity_search(query, k=max_results)
        return self._format_results(results)

    async def _arun(self, query: str, max_results: int = 10) -> str:
        """Asynchronous execution."""
        results = await self.vector_store.asimilarity_search(query, k=max_results)
        return self._format_results(results)

    def _format_results(self, results) -> str:
        formatted = []
        for i, doc in enumerate(results, 1):
            # 'score' is only present if the vector store writes it into the metadata
            formatted.append(f"{i}. {doc.page_content[:200]}... (Score: {doc.metadata.get('score', 'N/A')})")
        return "\n".join(formatted)


# Usage (your_vector_store is assumed to be created elsewhere)
search_tool = DocumentSearchTool(vector_store=your_vector_store)
2.3 Managing a tool set
from langchain_core.tools import Tool
from langchain_core.utils.function_calling import convert_to_openai_function

# Build the tool list (search_engine and vs are assumed to be defined elsewhere)
tools = [
    Tool(
        name="web_search",
        func=search_engine.run,
        description="Search the internet for current information",
    ),
    get_weather,
    DocumentSearchTool(vector_store=vs),
]
# Convert to the OpenAI function format
functions = [convert_to_openai_function(t) for t in tools]
3. Registering Tools in LangGraph
LangGraph offers a more flexible tool-management mechanism that is particularly well suited to building complex agent workflows.
3.1 Basic tool node
import operator
from typing import Annotated, Sequence, TypedDict

from langchain.tools import tool
from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode


class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    next_step: str


# Define the tools
@tool
def calculator(expression: str) -> str:
    """Evaluate a mathematical expression safely."""
    try:
        # safe_eval is a placeholder for a restricted evaluator; never call eval() here
        result = safe_eval(expression)
        return f"Result: {result}"
    except Exception as e:
        return f"Error: {str(e)}"


@tool
def code_executor(code: str, language: str = "python") -> str:
    """Execute code in a sandboxed environment."""
    # sandbox is a placeholder for your sandboxed execution backend
    return sandbox.execute(code, language=language)


# Create the tool node
tools = [calculator, code_executor]
tool_node = ToolNode(tools)
# Build the graph
workflow = StateGraph(AgentState)
# Add the tool node; the "agent" node (the LLM call) must also be added
# and set as the entry point before the graph is compiled
workflow.add_node("tools", tool_node)


# Conditional edge: decide whether a tool call is needed
def should_call_tools(state: AgentState):
    last_message = state["messages"][-1]
    if getattr(last_message, "tool_calls", None):
        return "tools"
    return END


workflow.add_conditional_edges(
    "agent",
    should_call_tools,
    {"tools": "tools", END: END},
)
workflow.add_edge("tools", "agent")
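The calculator tool above calls `safe_eval`, which the snippet does not define. One possible way to fill that gap is to parse the expression with the standard `ast` module and evaluate only an allow-list of arithmetic nodes. The sketch below is an assumption about what such a helper could look like, including the hypothetical `max_power` guard; it is not the article's original implementation.

```python
import ast
import operator

# Arithmetic operators the evaluator is willing to apply.
_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def safe_eval(expression: str, max_power: int = 100):
    """Evaluate a purely arithmetic expression; reject every other construct."""

    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant):
            # Accept plain numbers only (bool is a subclass of int, so exclude it).
            if isinstance(node.value, (int, float)) and not isinstance(node.value, bool):
                return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPS:
            left, right = _eval(node.left), _eval(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > max_power:
                raise ValueError("Exponent too large")
            return _BINARY_OPS[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](_eval(node.operand))
        # Names, calls, attributes, subscripts, etc. all end up here.
        raise ValueError(f"Unsupported expression element: {type(node).__name__}")

    return _eval(ast.parse(expression, mode="eval"))


if __name__ == "__main__":
    print(safe_eval("2 + 3 * 4"))
    print(safe_eval("-(10 / 4)"))
```

Because function calls and attribute access are never evaluated, input such as `__import__('os').system('ls')` is rejected with a `ValueError` instead of being executed.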
3.2 Custom tool node
For tools that need special handling, you can write a custom node:
import logging

from langchain_core.messages import ToolMessage


class CustomToolNode:
    """Custom tool node with logging and error handling."""

    def __init__(self, tools: list, logger=None):
        self.tools_by_name = {tool.name: tool for tool in tools}
        self.logger = logger or logging.getLogger(__name__)

    def __call__(self, state: AgentState):
        outputs = []
        for tool_call in state["messages"][-1].tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool_call_id = tool_call["id"]
            self.logger.info(f"Executing tool: {tool_name} with args: {tool_args}")
            try:
                tool = self.tools_by_name.get(tool_name)
                if not tool:
                    raise ValueError(f"Tool '{tool_name}' not found")
                # Permission check before execution
                if not self._check_permissions(tool_name, state):
                    raise PermissionError(f"No permission to use tool: {tool_name}")
                # Run the tool
                observation = tool.invoke(tool_args)
                # Record the outcome
                self.logger.info(f"Tool {tool_name} completed successfully")
                outputs.append(
                    ToolMessage(
                        content=str(observation),
                        name=tool_name,
                        tool_call_id=tool_call_id,
                    )
                )
            except Exception as e:
                # Return the error to the model so it can correct the call
                self.logger.error(f"Tool {tool_name} failed: {str(e)}")
                outputs.append(
                    ToolMessage(
                        content=f"Error: {str(e)}",
                        name=tool_name,
                        tool_call_id=tool_call_id,
                    )
                )
        return {"messages": outputs}

    def _check_permissions(self, tool_name: str, state: AgentState) -> bool:
        # An empty or missing allow-list means every tool is permitted
        allowed_tools = state.get("allowed_tools", [])
        return tool_name in allowed_tools or not allowed_tools
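3.3 Composing a tool chain
from typing import Any, TypedDict

from langgraph.graph import StateGraph, START, END


class ChainState(TypedDict):
    data: Any  # output of the previous tool, input of the next one


# Build a chain of tools that run one after another
def create_tool_chain(tools: list):
    """Create a sequential tool execution chain."""
    builder = StateGraph(ChainState)

    def make_node(tool):
        # Wrap each tool so that it reads from and writes to the shared state
        def node(state: ChainState):
            return {"data": tool.invoke(state["data"])}
        return node

    # Add each tool as a node and link it to its predecessor
    for i, tool in enumerate(tools):
        node_name = f"tool_{i}"
        builder.add_node(node_name, make_node(tool))
        if i > 0:
            builder.add_edge(f"tool_{i-1}", node_name)
    builder.add_edge(START, "tool_0")
    builder.add_edge(f"tool_{len(tools)-1}", END)
    return builder.compile()


# Usage (the three tools are assumed to be single-input tools defined elsewhere)
tool_chain = create_tool_chain([fetch_data_tool, process_data_tool, save_result_tool])
result = tool_chain.invoke({"data": "initial_input"})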
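4. The MCP Tool Standard
Model Context Protocol (MCP) is an open standard introduced by Anthropic that aims to standardize how AI models interact with external tools.
4.1 MCP architecture overview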
┌─────────────────────────────────────────────────────────┐
│                     AI Application                      │
│                      (Host/Client)                      │
└────────────────────┬────────────────────────────────────┘
                     │ MCP Protocol
                     ▼
┌─────────────────────────────────────────────────────────┐
│                       MCP Server                        │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────┐  │
│  │    Tools    │  │  Resources  │  │     Prompts     │  │
│  │ (Functions) │  │  (Context)  │  │   (Templates)   │  │
│  └─────────────┘  └─────────────┘  └─────────────────┘  │
└─────────────────────────────────────────────────────────┘
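On the wire, the "MCP Protocol" arrow in the diagram carries JSON-RPC 2.0 messages: a client discovers tools with the `tools/list` method and invokes one with `tools/call`. The sketch below only builds those two request payloads with the standard library so the shape is visible. The helper name `make_request` and the example file path are assumptions for illustration; in practice an MCP SDK constructs and sends these messages for you.

```python
import json
from typing import Optional


def make_request(request_id: int, method: str, params: Optional[dict] = None) -> str:
    """Serialize a JSON-RPC 2.0 request of the kind MCP clients send."""
    message = {"jsonrpc": "2.0", "id": request_id, "method": method}
    if params is not None:
        message["params"] = params
    return json.dumps(message)


# Discovery: ask the server which tools it exposes.
list_request = make_request(1, "tools/list")

# Invocation: call one tool by name with JSON arguments.
call_request = make_request(
    2,
    "tools/call",
    {"name": "file_reader", "arguments": {"path": "/tmp/notes.txt"}},  # hypothetical path
)

if __name__ == "__main__":
    print(list_request)
    print(call_request)
```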
4.2 Implementing an MCP server
import asyncio

from mcp.server import Server
from mcp.types import Tool, TextContent

# Create the MCP server
app = Server("my-ai-tools")


# Declare the available tools
@app.list_tools()
async def list_tools() -> list[Tool]:
    return [
        Tool(
            name="file_reader",
            description="Read contents of a file",
            inputSchema={
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "Absolute path to the file",
                    },
                    "encoding": {
                        "type": "string",
                        "enum": ["utf-8", "ascii", "latin-1"],
                        "default": "utf-8",
                    },
                },
                "required": ["path"],
            },
        ),
        Tool(
            name="database_query",
            description="Execute a read-only SQL query",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query",
                    },
                    "limit": {
                        "type": "integer",
                        "maximum": 1000,
                        "default": 100,
                    },
                },
                "required": ["query"],
            },
        ),
    ]


# Handle tool calls. read_file_safe, execute_readonly_query and format_results
# are placeholders for your own implementations.
@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    if name == "file_reader":
        content = await read_file_safe(
            arguments["path"],
            encoding=arguments.get("encoding", "utf-8"),
        )
        return [TextContent(type="text", text=content)]
    elif name == "database_query":
        results = await execute_readonly_query(
            arguments["query"],
            limit=arguments.get("limit", 100),
        )
        return [TextContent(type="text", text=format_results(results))]
    else:
        raise ValueError(f"Unknown tool: {name}")


# Start the server over stdio
async def main():
    from mcp.server.stdio import stdio_server

    async with stdio_server() as (read_stream, write_stream):
        await app.run(
            read_stream,
            write_stream,
            app.create_initialization_options(),
        )


if __name__ == "__main__":
    asyncio.run(main())
4.3 MCP client integration
from contextlib import AsyncExitStack
from typing import Optional

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


class MCPToolManager:
    """Manage MCP tool connections and execution."""

    def __init__(self):
        # The exit stack keeps transports and sessions open until close() is called
        self.exit_stack = AsyncExitStack()
        self.sessions: dict[str, ClientSession] = {}
        self.tools_cache: dict[str, list] = {}

    async def connect_server(self, server_name: str, command: str, args: Optional[list] = None):
        """Connect to an MCP server."""
        server_params = StdioServerParameters(
            command=command,
            args=args or [],
            env=None,
        )
        read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
        session = await self.exit_stack.enter_async_context(ClientSession(read, write))
        await session.initialize()
        self.sessions[server_name] = session
        # list_tools() returns a result object; the tool list is in its .tools field
        self.tools_cache[server_name] = (await session.list_tools()).tools

    async def execute_tool(self, server_name: str, tool_name: str, arguments: dict):
        """Execute a tool on a specific server."""
        session = self.sessions.get(server_name)
        if not session:
            raise ValueError(f"Server '{server_name}' not connected")
        result = await session.call_tool(tool_name, arguments)
        return result

    def get_all_tools(self) -> list[dict]:
        """Get all available tools from all connected servers."""
        all_tools = []
        for server_name, tools in self.tools_cache.items():
            for tool in tools:
                all_tools.append({
                    "server": server_name,
                    "name": tool.name,
                    "description": tool.description,
                    "schema": tool.inputSchema,
                })
        return all_tools

    async def close(self):
        """Close all sessions and server processes."""
        await self.exit_stack.aclose()


# Usage (the server module names and paths below are illustrative)
async def setup_mcp_tools():
    manager = MCPToolManager()
    # Connect to a filesystem MCP server
    await manager.connect_server(
        "filesystem",
        "python",
        ["-m", "mcp_server_filesystem", "/home/user/documents"],
    )
    # Connect to a database MCP server
    await manager.connect_server(
        "sqlite",
        "python",
        ["-m", "mcp_server_sqlite", "--db-path", "/path/to/db.sqlite"],
    )
    # Collect all available tools
    tools = manager.get_all_tools()
    print(f"Available MCP tools: {len(tools)}")
    return manager
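4.4 MCP vs. traditional tools
| Aspect | Traditional tools (LangChain) | MCP standard |
|---|---|---|
| Protocol | Framework-specific | Standardized protocol |
| Discovery | Registered in code | Dynamic discovery |
| Transport | In-process / HTTP | stdio / HTTP (SSE, Streamable HTTP) |
| Security | Application-level control | Consent and authorization defined by the spec, enforced by the host |
| Ecosystem | Tied to one framework | Works across frameworks |
| Scope | Tools only | Tools + resources + prompts |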
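The cross-framework row in the comparison follows from the fact that an MCP tool descriptor is just a name, a description, and a JSON Schema for its input. Adapting it to another function-calling format is then mostly a matter of re-keying. The sketch below converts a plain-dict tool descriptor into the tool format used by OpenAI-style chat APIs; the function name `mcp_tool_to_openai` is an assumption for illustration, and real MCP SDK objects expose the same fields as attributes rather than dict keys.

```python
def mcp_tool_to_openai(tool: dict) -> dict:
    """Re-key an MCP tool descriptor into the OpenAI-style tool format."""
    return {
        "type": "function",
        "function": {
            "name": tool["name"],
            "description": tool.get("description", ""),
            # MCP's inputSchema is already JSON Schema, so it can be reused as-is.
            "parameters": tool.get("inputSchema", {"type": "object", "properties": {}}),
        },
    }


if __name__ == "__main__":
    mcp_tool = {
        "name": "file_reader",
        "description": "Read contents of a file",
        "inputSchema": {
            "type": "object",
            "properties": {"path": {"type": "string"}},
            "required": ["path"],
        },
    }
    print(mcp_tool_to_openai(mcp_tool))
```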
5. Tool Security and Permission Control
5.1 Input validation and sanitization
import re
from pathlib import Path

from pydantic import BaseModel, Field, field_validator

# Root directory that file tools are allowed to touch (example value)
ALLOWED_BASE_DIR = Path("/allowed/path")


class SafeSearchInput(BaseModel):
    query: str = Field(..., max_length=500)
    max_results: int = Field(default=10, ge=1, le=100)

    @field_validator("query")
    @classmethod
    def sanitize_query(cls, v: str) -> str:
        # Strip potentially dangerous characters
        v = re.sub(r'[<>\"\']', '', v)
        # Reject obvious command-injection patterns. This deny-list is defense
        # in depth only; never pass the query to a shell in the first place.
        dangerous_patterns = [
            r';\s*rm\s+-rf',
            r'\|\s*sh',
            r'`.*?`',
            r'\$\(.*?\)',
        ]
        for pattern in dangerous_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError("Potentially dangerous pattern detected")
        return v.strip()


class FileOperationInput(BaseModel):
    path: str = Field(...)

    @field_validator("path")
    @classmethod
    def validate_path(cls, v: str) -> str:
        # Resolve the path ('..' segments and symlinks) relative to the base directory
        base_dir = ALLOWED_BASE_DIR.resolve()
        full_path = (base_dir / v).resolve()
        # Make sure the resolved path is still inside the allowed directory.
        # A plain startswith() check would also accept '/allowed/path_evil'.
        if not full_path.is_relative_to(base_dir):
            raise ValueError("Access denied: path outside allowed directory")
        return str(full_path)
5.2 Permission control system
import shlex
import subprocess
from enum import Enum
from functools import wraps
from typing import Callable


class PermissionLevel(Enum):
    # Declared from least to most privileged; the order is used for comparison
    READ_ONLY = "read_only"
    STANDARD = "standard"
    PRIVILEGED = "privileged"
    ADMIN = "admin"


class ToolPermission:
    """Tool permission decorator."""

    def __init__(self, required_level: PermissionLevel):
        self.required_level = required_level
        self.levels = list(PermissionLevel)

    def __call__(self, func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Read the caller's permission from the call context; default to the lowest level
            current_level = kwargs.get('_user_permission', PermissionLevel.READ_ONLY)
            if self.levels.index(current_level) < self.levels.index(self.required_level):
                raise PermissionError(
                    f"Tool '{func.__name__}' requires {self.required_level.value} permission, "
                    f"but user has {current_level.value}"
                )
            return func(*args, **kwargs)
        return wrapper


# Usage (document_store, notification_service and db are placeholders)
class SecureTools:
    @ToolPermission(PermissionLevel.READ_ONLY)
    def search_documents(self, query: str, **kwargs):
        """Search documents - available to all users."""
        return document_store.search(query)

    @ToolPermission(PermissionLevel.STANDARD)
    def send_notification(self, message: str, channel: str, **kwargs):
        """Send notifications - requires standard permission."""
        return notification_service.send(channel, message)

    @ToolPermission(PermissionLevel.PRIVILEGED)
    def modify_database(self, operation: str, data: dict, **kwargs):
        """Modify database - requires privileged permission."""
        return db.execute(operation, data)

    @ToolPermission(PermissionLevel.ADMIN)
    def system_command(self, command: str, **kwargs):
        """Execute system commands - admin only."""
        # Run without a shell so that metacharacters in the command are not interpreted
        return subprocess.run(shlex.split(command), capture_output=True, text=True)
5.3 Execution sandbox
import ast
import subprocess
import tempfile
from pathlib import Path


class SandboxExecutor:
    """Execute code in an isolated sandbox environment."""

    def __init__(self, timeout: int = 30, memory_limit: str = "512m"):
        self.timeout = timeout
        self.memory_limit = memory_limit
        self.allowed_imports = {
            'math', 'random', 'datetime', 'json', 're',
            'collections', 'itertools', 'statistics'
        }

    def execute_python(self, code: str) -> dict:
        """Execute Python code in sandbox."""
        # Static analysis first; raises ValueError on disallowed constructs
        self._validate_code(code)
        with tempfile.TemporaryDirectory() as tmpdir:
            # Write the code to a file
            code_file = Path(tmpdir) / "script.py"
            code_file.write_text(self._wrap_code(code))
            # Run it in Docker (or another restricted environment)
            try:
                result = subprocess.run(
                    [
                        "docker", "run", "--rm",
                        "--network", "none",  # no network access inside the sandbox
                        "-v", f"{tmpdir}:/code:ro",
                        "--memory", self.memory_limit,
                        "--cpus", "1.0",
                        "python:3.11-slim",
                        "python", "/code/script.py"
                    ],
                    capture_output=True,
                    text=True,
                    timeout=self.timeout
                )
                return {
                    "success": result.returncode == 0,
                    "stdout": result.stdout,
                    "stderr": result.stderr,
                    "returncode": result.returncode
                }
            except subprocess.TimeoutExpired:
                return {
                    "success": False,
                    "error": f"Execution timed out after {self.timeout} seconds"
                }

    def _validate_code(self, code: str):
        """Validate code for dangerous patterns."""
        try:
            tree = ast.parse(code)
        except SyntaxError as e:
            raise ValueError(f"Invalid Python syntax: {e}")
        for node in ast.walk(tree):
            # Allow-list imports, covering both `import x` and `from x import y`
            if isinstance(node, ast.Import):
                modules = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom):
                modules = [node.module or ""]
            else:
                modules = []
            for module in modules:
                if module.split(".")[0] not in self.allowed_imports:
                    raise ValueError(f"Import '{module}' not allowed")
            if isinstance(node, ast.Call):
                # Block exec/eval, dynamic imports and the built-in open()
                if isinstance(node.func, ast.Name):
                    if node.func.id in ('exec', 'eval', '__import__', 'open'):
                        raise ValueError(f"Function '{node.func.id}' not allowed")
                # Block file-style method calls
                if isinstance(node.func, ast.Attribute):
                    if node.func.attr in ('open', 'read', 'write'):
                        raise ValueError("File operations not allowed")

    def _wrap_code(self, code: str) -> str:
        """Hook for adding a prelude; isolation itself comes from the container."""
        # Clearing sys.path here would also break the allow-listed stdlib imports
        return code
5.4 Audit logging
import hashlib
import json
import time
from datetime import datetime, timezone
from typing import Optional

from langchain.tools import tool


class ToolAuditLogger:
    """Comprehensive audit logging for tool execution."""

    def __init__(self, log_file: str = "tool_audit.log"):
        self.log_file = log_file

    def log_execution(
        self,
        tool_name: str,
        user_id: str,
        session_id: str,
        inputs: dict,
        outputs: Optional[str],
        duration_ms: float,
        success: bool,
        error: Optional[Exception] = None
    ):
        """Log a tool execution event."""
        # Mask sensitive input fields before they are written to disk
        sanitized_inputs = self._sanitize_inputs(inputs)
        log_entry = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "event_type": "tool_execution",
            "tool_name": tool_name,
            "user_id": hashlib.sha256(user_id.encode()).hexdigest()[:16],
            "session_id": session_id,
            "inputs_hash": hashlib.sha256(
                json.dumps(inputs, sort_keys=True, default=str).encode()
            ).hexdigest(),
            "inputs": sanitized_inputs,
            "output_preview": outputs[:500] if outputs else None,
            "duration_ms": duration_ms,
            "success": success,
            "error_type": type(error).__name__ if error else None
        }
        with open(self.log_file, "a") as f:
            f.write(json.dumps(log_entry, default=str) + "\n")

    def _sanitize_inputs(self, inputs: dict) -> dict:
        """Remove or mask sensitive fields."""
        sensitive_keys = {'password', 'token', 'secret', 'key', 'api_key', 'credential'}
        sanitized = {}
        for key, value in inputs.items():
            if any(s in key.lower() for s in sensitive_keys):
                sanitized[key] = "***REDACTED***"
            elif isinstance(value, dict):
                sanitized[key] = self._sanitize_inputs(value)
            else:
                sanitized[key] = value
        return sanitized


# Usage (get_current_user_id, get_current_session and make_api_request are placeholders)
audit_logger = ToolAuditLogger()


@tool
def secure_api_call(endpoint: str, api_key: str, data: dict) -> str:
    """Make secure API call with full audit logging."""
    start_time = time.time()
    user_id = get_current_user_id()
    session_id = get_current_session()
    try:
        result = make_api_request(endpoint, api_key, data)
        audit_logger.log_execution(
            tool_name="secure_api_call",
            user_id=user_id,
            session_id=session_id,
            inputs={"endpoint": endpoint, "data": data},  # the API key is never logged
            outputs=result,
            duration_ms=(time.time() - start_time) * 1000,
            success=True
        )
        return result
    except Exception as e:
        audit_logger.log_execution(
            tool_name="secure_api_call",
            user_id=user_id,
            session_id=session_id,
            inputs={"endpoint": endpoint, "data": data},
            outputs=None,
            duration_ms=(time.time() - start_time) * 1000,
            success=False,
            error=e
        )
        raise
6. Complete Example: An Enterprise-Grade Tool System
# enterprise_tools.py
import asyncio
import logging
import operator
from typing import Annotated, Literal, Optional, TypedDict

from langchain.tools import tool
from langgraph.graph import StateGraph, END
from pydantic import BaseModel, Field

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ==================== Input models ====================
class SearchInput(BaseModel):
    query: str = Field(description="Search query", max_length=500)
    filters: Optional[dict] = Field(default=None, description="Metadata filters")
    top_k: int = Field(default=5, ge=1, le=20, description="Number of results")


class DatabaseQueryInput(BaseModel):
    query: str = Field(description="SQL query (SELECT only)", max_length=1000)
    params: Optional[dict] = Field(default=None, description="Query parameters")


class NotificationInput(BaseModel):
    channel: Literal["email", "slack", "teams"] = Field(description="Notification channel")
    recipient: str = Field(description="Recipient address/user")
    message: str = Field(description="Message content", max_length=2000)
    priority: Literal["low", "normal", "high"] = Field(default="normal")


# ==================== Tool implementations ====================
class EnterpriseToolRegistry:
    """Enterprise-grade tool registry with security and audit features."""

    def __init__(self):
        self.tools = {}
        self.permissions = {}
        self.audit_log = []

    def register(self, tool_func, permission_level="standard"):
        """Register a tool with permission level."""
        self.tools[tool_func.name] = tool_func
        self.permissions[tool_func.name] = permission_level
        logger.info(f"Registered tool: {tool_func.name} ({permission_level})")
        return tool_func

    def get_tools(self, user_permission: str = "standard"):
        """Get tools accessible to user permission level."""
        permission_hierarchy = ["read_only", "standard", "privileged", "admin"]
        user_level = permission_hierarchy.index(user_permission)
        accessible = []
        for name, tool in self.tools.items():
            tool_level = permission_hierarchy.index(self.permissions[name])
            if tool_level <= user_level:
                accessible.append(tool)
        return accessible


# Initialize the registry
registry = EnterpriseToolRegistry()


@tool(args_schema=SearchInput)
def enterprise_search(
    query: str,
    filters: Optional[dict] = None,
    top_k: int = 5
) -> str:
    """
    Search enterprise knowledge base using semantic similarity.
    Use this tool to find relevant documents, policies, and information
    from the company's internal knowledge repository.
    """
    logger.info(f"Executing search: {query}")
    # Simulated search implementation
    results = [
        {"title": f"Result {i}", "content": f"Content for {query}", "score": 0.95 - i*0.05}
        for i in range(top_k)
    ]
    formatted = "\n\n".join([
        f"{i+1}. {r['title']} (relevance: {r['score']:.2f})\n{r['content'][:200]}"
        for i, r in enumerate(results)
    ])
    return f"Found {len(results)} results:\n\n{formatted}"


@tool(args_schema=DatabaseQueryInput)
def readonly_database_query(
    query: str,
    params: Optional[dict] = None
) -> str:
    """
    Execute read-only database queries for analytics.
    Use this to retrieve business metrics, user statistics, and
    other analytical data. Only SELECT queries are allowed.
    """
    # Enforce read-only access: a single SELECT statement, no stacked statements.
    # In production, also connect with a database role that has read-only rights.
    normalized = query.strip().rstrip(";")
    if not normalized.upper().startswith("SELECT") or ";" in normalized:
        raise ValueError("Only single SELECT queries are permitted")
    logger.info(f"Executing query: {query[:50]}...")
    # Simulated query result
    return "Query executed successfully. Returned 42 rows.\nColumns: id, name, value, timestamp"


@tool(args_schema=NotificationInput)
def send_notification(
    channel: str,
    recipient: str,
    message: str,
    priority: str = "normal"
) -> str:
    """
    Send notifications via email, Slack, or Microsoft Teams.
    Use this to alert users, send reports, or notify teams about
    important events. Respect user notification preferences.
    """
    logger.info(f"Sending {priority} {channel} to {recipient}")
    # Simulated delivery
    return f"Notification sent via {channel} to {recipient} with {priority} priority"


# Register the tools
registry.register(enterprise_search, "read_only")
registry.register(readonly_database_query, "standard")
registry.register(send_notification, "privileged")

# ==================== LangGraph integration ====================
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_openai import ChatOpenAI


class AgentState(TypedDict):
    messages: Annotated[list, operator.add]
    user_permission: str
    session_id: str


def create_enterprise_agent(user_permission: str = "standard"):
    """Create an enterprise agent with permission-aware tools."""
    # Fetch the tools this user may access
    tools = registry.get_tools(user_permission)
    allowed_tools = {t.name: t for t in tools}
    # Bind the tools to the model
    model = ChatOpenAI(model="gpt-4-turbo-preview").bind_tools(tools)
    # Create the graph
    workflow = StateGraph(AgentState)

    # Agent node
    def agent_node(state: AgentState):
        messages = state["messages"]
        response = model.invoke(messages)
        return {"messages": [response]}

    # Tool node
    def tool_node(state: AgentState):
        last_message = state["messages"][-1]
        tool_outputs = []
        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            # Look the tool up among the permitted tools only, so a tool the
            # user may not access cannot be executed even if the model names it
            tool_func = allowed_tools.get(tool_name)
            if tool_func:
                try:
                    output = tool_func.invoke(tool_args)
                except Exception as e:
                    output = f"Error: {str(e)}"
            else:
                output = f"Tool {tool_name} not found or not permitted"
            tool_outputs.append(
                ToolMessage(
                    content=str(output),
                    name=tool_name,
                    tool_call_id=tool_call["id"]
                )
            )
        return {"messages": tool_outputs}

    # Conditional edge
    def should_continue(state: AgentState):
        last_message = state["messages"][-1]
        if last_message.tool_calls:
            return "tools"
        return END

    # Assemble the graph
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"tools": "tools", END: END}
    )
    workflow.add_edge("tools", "agent")
    return workflow.compile()


# ==================== Usage ====================
async def main():
    """Demonstrate enterprise agent usage."""
    # Create agents with different permission levels
    print("=== Read-only User ===")
    read_only_agent = create_enterprise_agent("read_only")
    result = await read_only_agent.ainvoke({
        "messages": [HumanMessage(content="Search for Q4 financial reports")],
        "user_permission": "read_only",
        "session_id": "session_001"
    })
    print(result["messages"][-1].content)
    print("\n=== Privileged User ===")
    privileged_agent = create_enterprise_agent("privileged")
    result = await privileged_agent.ainvoke({
        "messages": [HumanMessage(content="Send a Slack notification to the team about the deployment")],
        "user_permission": "privileged",
        "session_id": "session_002"
    })
    print(result["messages"][-1].content)


if __name__ == "__main__":
    asyncio.run(main())
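7. Best Practices Summary
Tip
Golden rules of tool design
- Single responsibility: each tool does one thing and does it well
- Self-describing: the name and description are clear enough that no extra documentation is needed
- Defensive programming: validate every input and assume all input is malicious
- Least privilege: request only the permissions you need, and use tiered permission levels
- Observability: log every call so it can be debugged and audited
- Graceful failure: return clear error messages that help the LLM correct its call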
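The graceful-failure rule can be sketched as a small decorator that turns expected exceptions into a structured message the model can act on, instead of letting a stack trace end the run. The decorator name `llm_friendly_errors`, the JSON field names, and the `convert_temperature` example tool are all assumptions for illustration, not part of any framework.

```python
import json
from functools import wraps


def llm_friendly_errors(func):
    """Return a structured error message instead of raising on bad arguments."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except (ValueError, TypeError, KeyError) as exc:
            return json.dumps({
                "status": "error",
                "error_type": type(exc).__name__,
                "message": str(exc),
                "hint": "Check the arguments against the tool description and retry.",
            })

    return wrapper


@llm_friendly_errors
def convert_temperature(value: float, unit: str) -> str:
    """Convert a temperature given in `unit` to the other unit."""
    if unit == "celsius":
        return f"{value * 9 / 5 + 32:.1f} fahrenheit"
    if unit == "fahrenheit":
        return f"{(value - 32) * 5 / 9:.1f} celsius"
    raise ValueError(f"unit must be 'celsius' or 'fahrenheit', got '{unit}'")


if __name__ == "__main__":
    print(convert_temperature(100, "celsius"))
    print(convert_temperature(100, "kelvin"))
```

Only anticipated argument errors are caught here; unexpected exceptions still propagate so that genuine bugs remain visible in logs and audits.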
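Quick checklist
- Tool names start with a verb and use snake_case
- Descriptions cover function, inputs, outputs, and usage scenario
- Parameters are defined with Pydantic models that include validation rules
- Sensitive operations perform permission checks
- Every tool call is written to the audit log
- User input is sanitized and validated
- Dangerous operations run in a sandboxed environment
- Error handling returns actionable feedback
Conclusion
Tool design and registration are central to building production-grade AI agents. By following the design principles described here, using the flexible mechanisms in LangChain and LangGraph, adopting open standards such as MCP, and enforcing strict security controls, you can build agent systems that are both capable and reliable.
As the AI agent ecosystem evolves rapidly, tool standardization and security will only become more important. Keep an eye on emerging standards such as MCP, and put solid permission and audit mechanisms in place early in a project.
The complete version of the example code in this article is available on GitHub, including more hands-on cases and test cases.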