RAG系统 进阶 RAG 文档解析 PDF 数据加载

文档加载与解析:RAG 系统的数据基础

AIEng Hub
阅读约 30 分钟

文档处理在 RAG 中的重要性

文档加载与解析是 RAG 系统的第一步,直接决定了后续检索和生成的质量。

┌─────────────────────────────────────────────────────────────┐
│              RAG 文档处理流程                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   原始文档                                                   │
│      │                                                      │
│      ▼                                                      │
│   ┌─────────────────┐                                        │
│   │   文档加载      │                                        │
│   │  (多格式支持)   │                                        │
│   └────────┬────────┘                                        │
│            │                                                │
│            ▼                                                │
│   ┌─────────────────┐                                        │
│   │   内容解析      │                                        │
│   │  (提取文本/元数据)│                                       │
│   └────────┬────────┘                                        │
│            │                                                │
│            ▼                                                │
│   ┌─────────────────┐                                        │
│   │   文本预处理    │                                        │
│   │  (清洗/标准化)  │                                        │
│   └────────┬────────┘                                        │
│            │                                                │
│            ▼                                                │
│   ┌─────────────────┐                                        │
│   │   文本分块      │                                        │
│   │  (Chunking)     │                                        │
│   └────────┬────────┘                                        │
│            │                                                │
│            ▼                                                │
│   ┌─────────────────┐                                        │
│   │   向量化        │                                        │
│   │  (Embedding)    │                                        │
│   └────────┬────────┘                                        │
│            │                                                │
│            ▼                                                │
│   向量数据库                                                 │
│                                                             │
└─────────────────────────────────────────────────────────────┘

文档加载器

1. 使用 LangChain 加载器

# document_loaders.py
from langchain.document_loaders import (
    PyPDFLoader,
    Docx2txtLoader,
    TextLoader,
    UnstructuredHTMLLoader,
    CSVLoader,
    JSONLoader,
    UnstructuredMarkdownLoader,
    UnstructuredPowerPointLoader,
    UnstructuredExcelLoader
)
from langchain.document_loaders.directory import DirectoryLoader
import os

class DocumentLoader:
    """Unified document loader that selects a loader class by file extension."""

    # Maps a lower-cased file extension to the loader class that handles it.
    LOADER_MAPPING = {
        ".pdf": PyPDFLoader,
        ".docx": Docx2txtLoader,
        ".txt": TextLoader,
        ".html": UnstructuredHTMLLoader,
        ".htm": UnstructuredHTMLLoader,
        ".csv": CSVLoader,
        ".json": JSONLoader,
        ".md": UnstructuredMarkdownLoader,
        ".pptx": UnstructuredPowerPointLoader,
        ".xlsx": UnstructuredExcelLoader,
    }

    # Extra constructor arguments per extension. JSONLoader cannot be built
    # from a path alone: it requires a jq_schema. "." selects the whole
    # document, and text_content=False lets non-string JSON be serialized.
    LOADER_KWARGS = {
        ".json": {"jq_schema": ".", "text_content": False},
    }

    @staticmethod
    def load_file(file_path: str) -> list:
        """Load a single file with the loader registered for its extension.

        Args:
            file_path: Path of the file to load.

        Returns:
            Whatever the selected loader's ``load()`` returns.

        Raises:
            ValueError: If the extension is not in ``LOADER_MAPPING``.
        """
        ext = os.path.splitext(file_path)[1].lower()

        if ext not in DocumentLoader.LOADER_MAPPING:
            raise ValueError(f"不支持的文件格式: {ext}")

        loader_class = DocumentLoader.LOADER_MAPPING[ext]
        loader_kwargs = DocumentLoader.LOADER_KWARGS.get(ext, {})

        return loader_class(file_path, **loader_kwargs).load()

    @staticmethod
    def load_directory(
        directory: str,
        glob_pattern: str = "**/*",
        recursive: bool = True
    ) -> list:
        """Load every supported document found under a directory.

        One ``DirectoryLoader`` pass is made per supported extension. A
        failure for one extension is reported and does not stop the others.

        Args:
            directory: Root directory to scan.
            glob_pattern: Glob prefix; each supported extension is appended
                to it, so the default ``"**/*"`` yields ``"**/*.pdf"`` etc.
            recursive: Forwarded to ``DirectoryLoader``.

        Returns:
            A flat list with the documents of all extensions.
        """
        documents = []

        for ext, loader_class in DocumentLoader.LOADER_MAPPING.items():
            try:
                loader = DirectoryLoader(
                    directory,
                    # Honor the caller's pattern instead of a hard-coded one.
                    glob=f"{glob_pattern}{ext}",
                    loader_cls=loader_class,
                    loader_kwargs=DocumentLoader.LOADER_KWARGS.get(ext, {}),
                    recursive=recursive
                )
                docs = loader.load()
                documents.extend(docs)
                print(f"加载 {len(docs)} 个 {ext} 文件")
            except Exception as e:
                # Best effort: one broken format must not abort the whole scan.
                print(f"加载 {ext} 文件时出错: {e}")

        return documents

# 使用示例
# docs = DocumentLoader.load_file("./document.pdf")
# all_docs = DocumentLoader.load_directory("./data")

2. PDF 文档解析

# pdf_parsing.py
from pypdf import PdfReader
import fitz  # PyMuPDF
from pdf2image import convert_from_path
import pytesseract
from PIL import Image

class PDFParser:
    """PDF parser offering pypdf, PyMuPDF and OCR based extraction."""

    def __init__(self, use_ocr: bool = False):
        """Store the OCR preference.

        Args:
            use_ocr: Kept on the instance as ``self.use_ocr``. No method of
                this class reads it at present.
        """
        self.use_ocr = use_ocr

    def parse_with_pypdf(self, file_path: str) -> dict:
        """Parse with pypdf (suited to text-based PDFs).

        Returns:
            A dict with ``metadata``, ``total_pages`` and ``pages``; each page
            entry holds ``page_number`` (1-based), ``text`` and ``links``.
        """
        reader = PdfReader(file_path)

        pages = []
        for number, page in enumerate(reader.pages, start=1):
            pages.append({
                "page_number": number,
                # Guard against a None result so callers can always take
                # len() of the text (parse() relies on that).
                "text": page.extract_text() or "",
                "links": getattr(page, "links", []),
            })

        return {
            "metadata": reader.metadata,
            "pages": pages,
            "total_pages": len(reader.pages),
        }

    def parse_with_pymupdf(self, file_path: str) -> dict:
        """Parse with PyMuPDF, also collecting embedded images and tables.

        Returns:
            A dict with ``metadata``, ``total_pages`` and ``pages``; each page
            entry holds ``page_number`` (1-based), ``text``, ``images`` (a
            list of ``index``/``bytes``/``ext`` dicts) and ``tables``.
        """
        doc = fitz.open(file_path)

        # try/finally guarantees the document handle is released even when
        # extraction fails half-way through.
        try:
            content = {
                "metadata": doc.metadata,
                "pages": [],
                "total_pages": len(doc)
            }

            for page_num in range(len(doc)):
                page = doc[page_num]

                # Plain text of the page.
                text = page.get_text()

                # Raw bytes of every image referenced by the page.
                images = []
                for img_index, img in enumerate(page.get_images(), start=1):
                    xref = img[0]
                    base_image = doc.extract_image(xref)
                    images.append({
                        "index": img_index,
                        "bytes": base_image["image"],
                        "ext": base_image["ext"]
                    })

                # NOTE(review): this stores the object returned by
                # find_tables() as-is; whether it stays usable after the
                # document is closed should be confirmed.
                tables = page.find_tables()

                content["pages"].append({
                    "page_number": page_num + 1,
                    "text": text,
                    "images": images,
                    "tables": tables
                })

            return content
        finally:
            doc.close()

    def parse_with_ocr(self, file_path: str, lang: str = "chi_sim+eng") -> dict:
        """Parse a scanned PDF by rendering pages to images and running OCR.

        Args:
            file_path: Path of the PDF.
            lang: Language argument passed to ``pytesseract.image_to_string``.

        Returns:
            A dict with ``total_pages`` and ``pages``; each page entry holds
            ``page_number`` (1-based), the recognized ``text`` and the
            rendered ``image``.
        """
        # Render every page to an image first.
        images = convert_from_path(file_path)

        content = {
            "pages": [],
            "total_pages": len(images)
        }

        for number, image in enumerate(images, start=1):
            content["pages"].append({
                "page_number": number,
                "text": pytesseract.image_to_string(image, lang=lang),
                "image": image
            })

        return content

    def parse(self, file_path: str, method: str = "auto") -> dict:
        """Parse a PDF with the requested strategy.

        Args:
            file_path: Path of the PDF.
            method: ``"auto"``, ``"pypdf"``, ``"pymupdf"`` or ``"ocr"``.
                ``"auto"`` tries pypdf first and falls back to OCR when the
                extracted text is shorter than 100 characters in total or
                when pypdf raises.

        Raises:
            ValueError: If ``method`` is not one of the values above.
        """
        if method == "auto":
            try:
                content = self.parse_with_pypdf(file_path)
            except Exception as e:
                print(f"文本解析失败,尝试 OCR: {e}")
                return self.parse_with_ocr(file_path)

            # Very little text usually means a scanned document.
            total_text = sum(len(p["text"] or "") for p in content["pages"])
            if total_text < 100:
                print("文本较少,尝试 OCR...")
                content = self.parse_with_ocr(file_path)
            return content

        if method == "pypdf":
            return self.parse_with_pypdf(file_path)
        if method == "pymupdf":
            return self.parse_with_pymupdf(file_path)
        if method == "ocr":
            return self.parse_with_ocr(file_path)

        raise ValueError(f"未知的解析方法: {method}")

# 使用示例
# parser = PDFParser()
# content = parser.parse("./document.pdf", method="auto")

3. Word 文档解析

# word_parsing.py
from docx import Document
import docx2txt

class WordParser:
    """Word (.docx) document parser."""

    @staticmethod
    def _heading_level(style_name: str) -> int:
        """Return the trailing digit of a heading style name, or 0 if none."""
        last = style_name[-1:]
        return int(last) if last.isdigit() else 0

    @staticmethod
    def parse_with_docx(file_path: str) -> dict:
        """Parse a document with python-docx.

        Returns:
            A dict with ``paragraphs`` (texts of non-heading paragraphs),
            ``headings`` (``level``/``text`` dicts) and ``tables`` (one list
            of rows per table, each row a list of cell texts). Blank
            paragraphs are skipped.
        """
        doc = Document(file_path)

        content = {
            "paragraphs": [],
            "tables": [],
            "headings": []
        }

        for para in doc.paragraphs:
            if not para.text.strip():
                continue
            style_name = para.style.name
            if style_name.startswith('Heading'):
                content["headings"].append({
                    "level": WordParser._heading_level(style_name),
                    "text": para.text
                })
            else:
                content["paragraphs"].append(para.text)

        for table in doc.tables:
            content["tables"].append(
                [[cell.text for cell in row.cells] for row in table.rows]
            )

        return content

    @staticmethod
    def parse_with_docx2txt(file_path: str) -> str:
        """Extract the plain text of a document with docx2txt."""
        return docx2txt.process(file_path)

    @staticmethod
    def extract_structure(file_path: str) -> dict:
        """Extract a two-level outline of the document.

        Returns:
            A dict with ``title`` and ``sections``. Each section has
            ``title``, ``content`` and ``subsections``; each subsection has
            ``title`` and ``content``. Body text is attached to the innermost
            open heading: the current subsection if there is one, otherwise
            the current section. Text before the first level-1 heading is not
            recorded, and headings that are neither level 1 nor level 2 do
            not open a new node.
        """
        doc = Document(file_path)

        structure = {
            "title": "",
            "sections": []
        }

        current_section = None
        current_subsection = None

        for para in doc.paragraphs:
            style_name = para.style.name

            if style_name.startswith('Heading'):
                level = WordParser._heading_level(style_name)

                if level == 1:
                    current_section = {
                        "title": para.text,
                        "content": [],
                        "subsections": []
                    }
                    structure["sections"].append(current_section)
                    # A new section closes any open subsection.
                    current_subsection = None
                elif level == 2 and current_section:
                    current_subsection = {
                        "title": para.text,
                        "content": []
                    }
                    current_section["subsections"].append(current_subsection)
                continue

            if not para.text.strip():
                continue

            # The first paragraph using the built-in "Title" style supplies
            # the document title; previously the field was never filled.
            if style_name == 'Title' and not structure["title"]:
                structure["title"] = para.text
                continue

            if current_subsection is not None:
                current_subsection["content"].append(para.text)
            elif current_section:
                current_section["content"].append(para.text)

        return structure

# 使用示例
# content = WordParser.parse_with_docx("./document.docx")

4. 网页解析

# web_parsing.py
import requests
from bs4 import BeautifulSoup
from urllib.parse import urljoin, urlparse
import trafilatura

class WebParser:
    """Web page parser built on requests, BeautifulSoup and trafilatura."""

    @staticmethod
    def fetch_url(url: str, headers: dict = None) -> str:
        """Download a page and return its body as text.

        Args:
            url: Address to fetch.
            headers: Optional extra request headers; they override the
                default ``User-Agent`` on conflict.

        Raises:
            requests.HTTPError: Via ``raise_for_status()`` on an error status.
        """
        request_headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        if headers:
            request_headers.update(headers)

        response = requests.get(url, headers=request_headers, timeout=30)
        response.raise_for_status()

        return response.text

    @staticmethod
    def parse_with_bs4(html: str, base_url: str = "") -> dict:
        """Parse HTML with BeautifulSoup.

        Args:
            html: Markup to parse.
            base_url: Base used to resolve relative link and image URLs.

        Returns:
            A dict with ``title``, ``text``, ``links`` (``url``/``text``
            dicts) and ``images`` (``url``/``alt`` dicts).
        """
        soup = BeautifulSoup(html, 'html.parser')

        # Drop script and style elements so their source is not in the text.
        for node in soup(["script", "style"]):
            node.decompose()

        # .string is None when <title> is empty or has nested markup; fall
        # back to "" so the title is always a string.
        title = (soup.title.string or "") if soup.title else ""

        links = [
            {
                "url": urljoin(base_url, anchor['href']),
                "text": anchor.get_text(strip=True)
            }
            for anchor in soup.find_all('a', href=True)
        ]

        images = [
            {
                "url": urljoin(base_url, img['src']),
                "alt": img.get('alt', '')
            }
            for img in soup.find_all('img', src=True)
        ]

        return {
            "title": title,
            "text": soup.get_text(separator='\n', strip=True),
            "links": links,
            "images": images
        }

    @staticmethod
    def parse_with_trafilatura(url: str, html: str = None) -> dict:
        """Extract the main content of a page with trafilatura.

        Args:
            url: Page address; downloaded with trafilatura when ``html`` is
                not supplied.
            html: Optional already-downloaded markup, to avoid a second
                request.

        Returns:
            The decoded JSON result, or ``{}`` when nothing could be
            downloaded or extracted.
        """
        import json

        downloaded = html if html is not None else trafilatura.fetch_url(url)
        if not downloaded:
            return {}

        result = trafilatura.extract(
            downloaded,
            output_format="json",
            with_metadata=True,
            include_images=True,
            include_tables=True,
            include_links=True
        )

        return json.loads(result) if result else {}

    @staticmethod
    def parse_article(url: str) -> dict:
        """Parse an article page, preferring trafilatura over BeautifulSoup.

        The page is downloaded once; the same markup feeds both parsers.
        """
        html = WebParser.fetch_url(url)

        # Preferred path: trafilatura on the markup we already have.
        try:
            content = WebParser.parse_with_trafilatura(url, html=html)
            if content and content.get('text'):
                return content
        except Exception as e:
            # Best effort: report the failure and fall back instead of
            # silently swallowing everything (including KeyboardInterrupt).
            print(f"trafilatura 解析失败,改用 BeautifulSoup: {e}")

        # Fallback: BeautifulSoup.
        return WebParser.parse_with_bs4(html, url)

# 使用示例
# article = WebParser.parse_article("https://example.com/article")

文本预处理

1. 文本清洗

# text_cleaning.py
import re
import unicodedata

class TextCleaner:
    """Collection of text clean-up steps plus a configurable pipeline."""

    @staticmethod
    def remove_extra_whitespace(text: str) -> str:
        """Collapse every run of whitespace to one space and trim the ends."""
        return re.sub(r'\s+', ' ', text).strip()

    @staticmethod
    def remove_special_chars(text: str, keep_chars: str = "") -> str:
        """Remove every character outside the allowed set.

        Allowed by default: CJK ideographs (U+4E00-U+9FFF), CJK punctuation
        (U+3000-U+303F), full-width forms (U+FF00-U+FFEF), ASCII letters and
        digits, and whitespace. ASCII punctuation is removed unless it is
        listed in ``keep_chars``.

        Args:
            text: Input text.
            keep_chars: Additional characters to keep.
        """
        # The pieces are character-class members; re.escape keeps characters
        # such as "-" or "]" in keep_chars from changing the class meaning.
        allowed = (
            r'\u4e00-\u9fff\u3000-\u303f\uff00-\uffef'
            r'a-zA-Z0-9\s'
            + re.escape(keep_chars)
        )
        return re.sub(f'[^{allowed}]', '', text)

    @staticmethod
    def normalize_unicode(text: str) -> str:
        """Return the NFC-normalized form of the text."""
        return unicodedata.normalize('NFC', text)

    @staticmethod
    def remove_html_tags(text: str) -> str:
        """Remove anything that looks like an HTML tag (``<...>``)."""
        return re.sub(r'<.*?>', '', text)

    @staticmethod
    def remove_urls(text: str) -> str:
        """Remove http/https URLs."""
        url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
        return re.sub(url_pattern, '', text)

    @staticmethod
    def remove_email(text: str) -> str:
        """Remove e-mail addresses."""
        # The TLD class is letters only; "|" inside [] would be a literal.
        email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
        return re.sub(email_pattern, '', text)

    @staticmethod
    def fix_common_errors(text: str) -> str:
        """Replace typographic characters and literal escape sequences.

        Curly quotes become straight quotes, the ellipsis character becomes
        three dots, the em dash becomes a hyphen, and the two-character
        sequences backslash-t / backslash-n become a space / a newline.
        """
        # Keys are written as escapes so they survive any re-encoding of this
        # source file; an empty key would make str.replace insert the
        # replacement between every character.
        replacements = {
            '\u201c': '"',    # left double quotation mark
            '\u201d': '"',    # right double quotation mark
            '\u2018': "'",    # left single quotation mark
            '\u2019': "'",    # right single quotation mark
            '\u2026': '...',  # horizontal ellipsis
            '\u2014': '-',    # em dash
            '\\t': ' ',
            '\\n': '\n',
        }

        for old, new in replacements.items():
            text = text.replace(old, new)

        return text

    @classmethod
    def clean(cls, text: str, config: dict = None) -> str:
        """Run the cleaning steps enabled in ``config``, in a fixed order.

        Args:
            text: Input text.
            config: Maps step names to booleans; missing keys count as
                disabled. When omitted, Unicode normalization, HTML removal,
                error fixing and whitespace collapsing are enabled.

        Returns:
            The cleaned text.
        """
        if config is None:
            config = {
                "normalize_unicode": True,
                "remove_html": True,
                "remove_urls": False,
                "remove_email": False,
                "remove_special_chars": False,
                "fix_errors": True,
                "remove_extra_whitespace": True
            }

        # Order matters: errors are fixed before special characters are
        # stripped, and whitespace is collapsed last.
        steps = (
            ("normalize_unicode", cls.normalize_unicode),
            ("remove_html", cls.remove_html_tags),
            ("remove_urls", cls.remove_urls),
            ("remove_email", cls.remove_email),
            ("fix_errors", cls.fix_common_errors),
            ("remove_special_chars", cls.remove_special_chars),
            ("remove_extra_whitespace", cls.remove_extra_whitespace),
        )

        for key, step in steps:
            if config.get(key):
                text = step(text)

        return text

# 使用示例
# cleaned_text = TextCleaner.clean(raw_text)

2. 文本分块策略

# text_chunking.py
from typing import List
import re

class TextChunker:
    """Text chunking strategies: fixed-size, recursive, semantic, Markdown."""

    @staticmethod
    def fixed_size_chunk(
        text: str,
        chunk_size: int = 500,
        overlap: int = 50
    ) -> List[str]:
        """Split text into windows of ``chunk_size`` characters.

        Consecutive windows share ``overlap`` characters. Splitting stops as
        soon as a window reaches the end of the text, so no trailing chunk
        made only of already-covered characters is produced.

        Raises:
            ValueError: If ``chunk_size`` is not positive or ``overlap`` is
                not smaller than ``chunk_size`` (the window could never
                advance).
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size 必须大于 0")
        if overlap >= chunk_size:
            raise ValueError("overlap 必须小于 chunk_size")

        chunks = []
        step = chunk_size - overlap

        for start in range(0, len(text), step):
            end = start + chunk_size
            chunks.append(text[start:end])
            if end >= len(text):
                break

        return chunks

    @staticmethod
    def recursive_chunk(
        text: str,
        chunk_size: int = 500,
        separators: List[str] = None
    ) -> List[str]:
        """Split text recursively, preferring coarse boundaries first.

        Each separator is tried in order. Pieces are greedily merged back
        (re-joined with the separator) while they fit in ``chunk_size``; a
        piece that is still too long is split again with the next
        separator. An empty-string separator, or running out of separators,
        falls back to a plain character split.

        Args:
            text: Input text.
            chunk_size: Maximum chunk length in characters.
            separators: Separators from coarsest to finest. Defaults to
                paragraph break, line break, Chinese full stop, ASCII full
                stop, space, then character level.
        """
        if separators is None:
            # "" must stay last: it triggers the character-level split and
            # makes every later separator unreachable.
            separators = ["\n\n", "\n", "\u3002", ".", " ", ""]

        def _hard_split(piece: str) -> List[str]:
            return [piece[i:i + chunk_size] for i in range(0, len(piece), chunk_size)]

        def _split_recursive(piece: str, separator_index: int) -> List[str]:
            if separator_index >= len(separators):
                return _hard_split(piece)

            separator = separators[separator_index]
            if separator == "":
                return _hard_split(piece)

            result = []
            current_chunk = ""

            for part in piece.split(separator):
                candidate = current_chunk + separator + part if current_chunk else part

                if len(candidate) <= chunk_size:
                    current_chunk = candidate
                    continue

                if current_chunk:
                    result.append(current_chunk)

                if len(part) > chunk_size:
                    # A single part is too long: go one separator finer.
                    result.extend(_split_recursive(part, separator_index + 1))
                    current_chunk = ""
                else:
                    current_chunk = part

            if current_chunk:
                result.append(current_chunk)

            return result

        return _split_recursive(text, 0)

    @staticmethod
    def semantic_chunk(
        text: str,
        embedding_model,
        max_chunk_size: int = 500,
        similarity_threshold: float = 0.8
    ) -> List[str]:
        """Group consecutive sentences whose embeddings are similar.

        Args:
            text: Input text; sentences end at ``。``, ``!`` or ``?``.
            embedding_model: Object whose ``embed(str)`` returns a numeric
                vector (any sequence convertible to a NumPy array).
            max_chunk_size: A sentence that would push the current chunk
                past this length starts a new chunk.
            similarity_threshold: Minimum cosine similarity for a sentence
                to be merged into the current chunk.

        Returns:
            The list of chunks; empty when the text has no sentences.
        """
        # Local import: this listing's module header does not import numpy.
        import numpy as np

        def _embed(piece: str):
            # Arrays (not lists) so that "+" below is element-wise.
            return np.asarray(embedding_model.embed(piece), dtype=float)

        def _cosine(a, b) -> float:
            denom = np.linalg.norm(a) * np.linalg.norm(b)
            return float(np.dot(a, b) / denom) if denom else 0.0

        sentences = re.split(r'(?<=[。!?])', text)
        sentences = [s.strip() for s in sentences if s.strip()]

        if not sentences:
            return []

        chunks = []
        current_chunk = sentences[0]
        current_embedding = _embed(current_chunk)

        for sentence in sentences[1:]:
            # Size limit first: no similarity needed when it cannot fit.
            if len(current_chunk) + len(sentence) > max_chunk_size:
                chunks.append(current_chunk)
                current_chunk = sentence
                current_embedding = _embed(sentence)
                continue

            sentence_embedding = _embed(sentence)
            similarity = _cosine(current_embedding, sentence_embedding)

            if similarity >= similarity_threshold:
                # Related: merge and update the running embedding with a
                # simple average.
                current_chunk += sentence
                current_embedding = (current_embedding + sentence_embedding) / 2
            else:
                # Unrelated: close the chunk and start a new one.
                chunks.append(current_chunk)
                current_chunk = sentence
                current_embedding = sentence_embedding

        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    @staticmethod
    def markdown_chunk(text: str) -> List[dict]:
        """Split Markdown into one chunk per heading.

        Every line starting with ``#`` opens a new section. Sections without
        any content line are not emitted.

        Returns:
            A list of dicts with ``title``, ``content`` (lines joined with
            newlines) and ``level`` (number of leading ``#``; 0 for text
            before the first heading).
        """
        chunks = []
        title, level, body = "", 0, []

        def _flush():
            if body:
                chunks.append({
                    "title": title,
                    "content": '\n'.join(body),
                    "level": level
                })

        for line in text.split('\n'):
            if line.startswith('#'):
                _flush()
                stripped = line.lstrip('#')
                # Count the leading "#" characters themselves, so "#Title"
                # (no space) is level 1 rather than the length of the word.
                level = len(line) - len(stripped)
                title = stripped.strip()
                body = []
            else:
                body.append(line)

        _flush()

        return chunks

# 使用示例
# chunks = TextChunker.recursive_chunk(long_text, chunk_size=500)

元数据提取

# metadata_extraction.py
from datetime import datetime
import hashlib

class MetadataExtractor:
    """Helpers that derive metadata from a piece of text."""

    @staticmethod
    def extract_basic_metadata(text: str, source: str) -> dict:
        """Return size statistics and a short fingerprint for the text.

        Returns:
            A dict with ``source``, ``timestamp`` (local time, ISO format),
            ``char_count``, ``word_count`` (whitespace-separated tokens),
            ``line_count`` and ``hash`` (first 8 hex digits of the MD5 of
            the encoded text).
        """
        # MD5 serves only as a short content fingerprint here.
        digest = hashlib.md5(text.encode()).hexdigest()

        return {
            "source": source,
            "timestamp": datetime.now().isoformat(),
            "char_count": len(text),
            "word_count": len(text.split()),
            "line_count": len(text.split('\n')),
            "hash": digest[:8]
        }

    @staticmethod
    def extract_entities(text: str, ner_model) -> list:
        """Run a NER model and normalize its output.

        Args:
            text: Input text.
            ner_model: Callable returning an iterable of dicts with the keys
                ``word``, ``entity_group``, ``start`` and ``end``.

        Returns:
            A list of dicts with ``text``, ``label``, ``start`` and ``end``.
        """
        return [
            {
                "text": ent["word"],
                "label": ent["entity_group"],
                "start": ent["start"],
                "end": ent["end"]
            }
            for ent in ner_model(text)
        ]

    @staticmethod
    def extract_keywords(text: str, top_k: int = 10) -> list:
        """Return the ``top_k`` highest-scoring TF-IDF terms of the text.

        Args:
            text: Input text.
            top_k: Maximum number of keywords to return.

        Returns:
            A list of ``keyword``/``score`` dicts, best first. Empty when
            ``top_k`` is not positive, the text is blank, or no term
            survives tokenization and stop-word removal.
        """
        # Guard first: a slice like [-0:] would select every term, and the
        # vectorizer raises on text without usable tokens.
        if top_k <= 0 or not text.strip():
            return []

        from sklearn.feature_extraction.text import TfidfVectorizer

        vectorizer = TfidfVectorizer(
            max_features=100,
            stop_words='english',
            ngram_range=(1, 2)
        )

        try:
            tfidf = vectorizer.fit_transform([text])
        except ValueError:
            # Raised for an empty vocabulary (e.g. only stop words).
            return []

        feature_names = vectorizer.get_feature_names_out()
        scores = tfidf.toarray()[0]

        # Indices of the k largest scores, highest first.
        top_indices = scores.argsort()[-top_k:][::-1]

        return [
            {"keyword": feature_names[i], "score": float(scores[i])}
            for i in top_indices
        ]

# 使用示例
# metadata = MetadataExtractor.extract_basic_metadata(text, "document.pdf")

最佳实践

1. 处理流程

# processing_pipeline.py
class DocumentProcessingPipeline:
    """End-to-end pipeline: parse a source, clean it, chunk it, tag it."""

    def __init__(self):
        """Create one instance of each parser and text helper."""
        self.pdf_parser = PDFParser()
        self.word_parser = WordParser()
        self.web_parser = WebParser()
        self.text_cleaner = TextCleaner()
        self.text_chunker = TextChunker()

    def process(self, source: str, source_type: str) -> list:
        """Turn a source into a list of chunk dicts.

        Args:
            source: File path or URL, depending on ``source_type``.
            source_type: ``"pdf"``, ``"word"`` or ``"web"``.

        Returns:
            One dict per chunk with ``content`` and ``metadata``; the
            metadata carries ``source``, ``chunk_index``, ``total_chunks``
            and the fields of ``extract_basic_metadata``.

        Raises:
            ValueError: If ``source_type`` is not supported.
        """
        # Step 1: load and parse into one raw text string.
        if source_type == "pdf":
            parsed = self.pdf_parser.parse(source)
            raw_text = "\n\n".join([page["text"] for page in parsed["pages"]])
        elif source_type == "word":
            parsed = self.word_parser.parse_with_docx(source)
            raw_text = "\n\n".join(parsed["paragraphs"])
        elif source_type == "web":
            parsed = self.web_parser.parse_article(source)
            raw_text = parsed.get("text", "")
        else:
            raise ValueError(f"不支持的类型: {source_type}")

        # Steps 2 and 3: clean, then split.
        pieces = self.text_chunker.recursive_chunk(
            self.text_cleaner.clean(raw_text)
        )
        total = len(pieces)

        # Step 4: wrap every piece together with its metadata.
        tagged = []
        for index, piece in enumerate(pieces):
            metadata = {
                "source": source,
                "chunk_index": index,
                "total_chunks": total,
            }
            metadata.update(
                MetadataExtractor.extract_basic_metadata(piece, source)
            )
            tagged.append({"content": piece, "metadata": metadata})

        return tagged

# 使用示例
# pipeline = DocumentProcessingPipeline()
# chunks = pipeline.process("./doc.pdf", "pdf")

总结

文档处理的关键要点:

  1. 多格式支持:根据文件类型选择合适的解析器
  2. 智能解析:文本型 PDF 和扫描版 PDF 使用不同方法
  3. 深度清洗:移除噪声,保留有效信息
  4. 语义分块:按意义边界分割,而非机械切割
  5. 元数据丰富:为后续检索提供更多信息

工具推荐:

  • PDF: PyMuPDF (fitz) + OCR (pytesseract)
  • Word: python-docx
  • Web: trafilatura
  • 清洗: 正则 + BeautifulSoup
  • 分块: LangChain TextSplitter

相关资源: