引言
通用 Embedding 模型在特定领域(法律、医疗、金融、代码)的表现往往不尽人意。微调(Fine-tuning) 可以让 Embedding 模型更好地理解领域特有的术语和语义关系。
通用模型 → 在你的领域数据上微调 → 领域专用模型
│ │
▼ ▼
"诉讼保全"向量 "诉讼保全"向量更接近
接近"法律诉讼" "财产保全""冻结资产"
本文将从头到尾讲解微调 Embedding 模型的全流程。
一、为什么需要微调?
1.1 通用模型的局限
| 问题 | 表现 | 例子 |
|---|---|---|
| 专业术语混淆 | 相似度不准确 | ”心肌梗死” vs “心脏骤停”区分不清 |
| 领域语境缺失 | 检索不相关结果 | ”苹果”在科技领域 vs 农业领域 |
| 简写识别差 | 无法匹配缩写 | ”RAG” = “检索增强生成” |
| 长尾问题 | 罕见概念检索失败 | 特定产品型号、法规条款 |
1.2 微调收益
微调前:[通用模型] 微调后:[领域专用模型]
检索召回率@10: 78% 检索召回率@10: 91%
准确率: 72% 准确率: 87%
二、微调方法
2.1 对比学习(Contrastive Learning)
最常用的 Embedding 微调方法。核心思想:让相似文本对的向量距离更近,不相似对的距离更远。
anchor: "什么是RAG系统?"
│
├── positive: "RAG(检索增强生成)是一种..." ← 拉近
│
└── negative: "今天天气真好" ← 推远
损失函数(InfoNCE Loss):
import torch.nn.functional as F
def contrastive_loss(
anchor_emb: torch.Tensor, # [batch, dim]
positive_emb: torch.Tensor, # [batch, dim]
negatives_emb: torch.Tensor, # [batch, n_neg, dim]
temperature: float = 0.05
) -> torch.Tensor:
# 正样本相似度
pos_sim = F.cosine_similarity(
anchor_emb, positive_emb, dim=-1
) / temperature
# 负样本相似度
neg_sim = torch.einsum(
'bd,bnd->bn',
anchor_emb, negatives_emb
) / temperature
# InfoNCE loss
logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)
labels = torch.zeros(logits.size(0), dtype=torch.long)
return F.cross_entropy(logits, labels)
2.2 全量微调 vs LoRA
| 方法 | 可训练参数 | GPU内存需求 | 效果 | 速度 |
|---|---|---|---|---|
| 全量微调 | 100% | ~24GB (large模型) | 最佳 | 慢 |
| LoRA | 0.1-1% | ~8GB | 接近全量 | 快 |
| 前缀微调 | 少量参数 | ~6GB | 一般 | 最快 |
推荐: 数据量 < 10万条时用 LoRA,数据量大且有充足 GPU 时用全量微调。
三、数据准备
3.1 数据格式
每条训练数据是一个三元组 (查询, 相关文档, 不相关文档):
[
{
"query": "RAG系统的检索流程是怎样的?",
"positive": "RAG检索流程包括:查询编码、向量检索、重排序...",
"negative": "Python是一种广泛使用的编程语言..."
},
{
"query": "什么是余弦相似度?",
"positive": "余弦相似度通过计算两个向量夹角的余弦值来衡量相似度...",
"negative": "今天CPU的温度是45度..."
}
]
3.2 数据构建工具
from sentence_transformers import InputExample
from torch.utils.data import DataLoader
def prepare_training_data(
queries: list[str],
positive_docs: list[str],
negative_docs: list[list[str]]
) -> list[InputExample]:
"""
准备 Sentence Transformers 训练数据
"""
examples = []
for q, pos, negs in zip(queries, positive_docs, negative_docs):
for neg in negs:
examples.append(
InputExample(
texts=[q, pos, neg],
label=1.0 # 正样本标记
)
)
return examples
# 硬负样本挖掘可以提升效果
def mine_hard_negatives(
model, queries, corpus, top_k=100
):
"""检索与查询相似但不相关的文档作为硬负样本"""
model.eval()
q_emb = model.encode(queries)
c_emb = model.encode(corpus)
hard_negatives = []
for q_emb_single in q_emb:
scores = cosine_similarity([q_emb_single], c_emb)[0]
# 排除正样本后取 top_k
hard_negatives.append(
[corpus[i] for i in scores.argsort()[-top_k:]]
)
return hard_negatives
3.3 数据量建议
| 场景 | 最少数据量 | 推荐数据量 |
|---|---|---|
| 新领域从零开始 | 500 条 | 5000+ 条 |
| 领域间迁移 | 200 条 | 1000+ 条 |
| 特定任务优化 | 100 条 | 500+ 条 |
四、训练实战
4.1 使用 Sentence Transformers 微调
from sentence_transformers import (
SentenceTransformer,
losses,
models
)
from sentence_transformers.evaluation import (
TripletEvaluator
)
# 1. 加载预训练模型
base_model = SentenceTransformer(
'BAAI/bge-large-zh-v1.5'
)
# 2. 准备数据
train_examples = prepare_training_data(
queries, positive_docs, negative_docs
)
train_dataloader = DataLoader(
train_examples,
shuffle=True,
batch_size=32
)
# 3. 配置损失函数
train_loss = losses.TripletLoss(
model=base_model,
distance_metric=losses.TripletDistanceMetric.COSINE,
triplet_margin=0.5
)
# 4. 配置评估器
evaluator = TripletEvaluator(
anchors=eval_queries,
positives=eval_positives,
negatives=eval_negatives,
name='domain-eval'
)
# 5. 开始训练
base_model.fit(
train_objectives=[(train_dataloader, train_loss)],
evaluator=evaluator,
epochs=5,
warmup_steps=100,
evaluation_steps=500,
output_path='./fine-tuned-bge-zh',
save_best_model=True,
use_amp=True # 混合精度训练
)
4.2 使用 LoRA 微调(低资源场景)
from peft import LoraConfig, get_peft_model
from transformers import AutoModel
# 配置 LoRA
lora_config = LoraConfig(
r=16, # LoRA 秩
lora_alpha=32, # 缩放参数
target_modules=["query", "value"], # 目标模块
lora_dropout=0.1,
bias="none",
)
# 应用 LoRA
base_transformer = AutoModel.from_pretrained(
'BAAI/bge-large-zh-v1.5'
)
lora_model = get_peft_model(
base_transformer, lora_config
)
五、评估与验证
5.1 微调前后效果对比
def compare_models(
base_model_name: str,
fine_tuned_model_name: str,
test_queries: list[str],
test_corpus: list[str],
relevant_indices: list[list[int]]
):
base = SentenceTransformer(base_model_name)
tuned = SentenceTransformer(fine_tuned_model_name)
for name, model in [("Base", base), ("Fine-tuned", tuned)]:
q_emb = model.encode(test_queries)
c_emb = model.encode(test_corpus)
recall = evaluate_recall(
q_emb, c_emb, relevant_indices, k=10
)
print(f"{name} Recall@10: {recall:.2%}")
5.2 关键监控指标
| 指标 | 说明 | 期望 |
|---|---|---|
| 训练 Loss | 对比损失下降趋势 | 持续下降 |
| 评估 Recall | 检索召回率 | 提升 > 5% |
| 正负样本间隔 | 正负样本相似度差值 | 增大 |
| 过拟合检测 | 训练 Loss vs 验证 Loss | 差距 < 20% |
六、常见陷阱与解决方案
| 陷阱 | 表现 | 解决方法 |
|---|---|---|
| 过拟合 | 训练集效果好但测试集差 | 增加数据量、降低训练轮次 |
| 灾难性遗忘 | 模型丢失通用能力 | 混入通用数据训练 |
| 批次负样本不足 | 对比学习效果差 | 增大 batch_size 或使用缓存负样本 |
| 标签噪音 | 正负样本标记错误 | 数据质量审核 + 降噪处理 |
七、总结
微调 Embedding 模型的关键步骤:
- 收集领域数据 — 至少 500 条高质量三元组
- 挖掘硬负样本 — 提升训练效果的关键技巧
- 选择微调策略 — 数据充足用全量,否则用 LoRA
- 监控过拟合 — 定期在验证集上评估
- 对比基线 — 始终与微调前的模型对比
什么时候值得微调:
- 领域术语密度高(法律、医疗、金融)
- 通用模型检索效果 < 80% Recall
- 有足够的标注数据或预算进行标注
相关资源: