prompts设计与query意图识别
Prompt管理与查询分类
Section titled “Prompt管理与查询分类”- 1.掌握如何设计和使用Prompt模板来引导大语言模型生成高质量输出。
- 2.学会查询分类的基本原理,了解如何通过分类优化输入处理流程。
prompts.py和query_classifier.py是EduRAG系统中core模块的重要组成部分,分别负责Prompt模板管理和查询分类。这两个模块通过优化用户输入的处理,增强了系统的灵活性和智能性,为RAG系统的检索和生成阶段奠定了基础。prompts.py定义了多种Prompt模板,用于引导大语言模型生成特定输出,而query_classifier.py通过分类用户查询,决定是否直接使用模型回答或触发检索流程。
5.1 Prompt管理
Section titled “5.1 Prompt管理”prompts.py定义了RAGPrompts类,负责管理系统中使用的所有Prompt模板。这些模板用于指导大语言模型完成不同任务,例如生成最终答案、假设答案、子查询或简化问题。通过集中管理Prompt,系统能够确保输入的一致性和输出质量。
# 导入 PromptTemplate 类,用于创建 Prompt 模板from langchain.prompts import PromptTemplate
# 定义 RAGPrompts 类,用于管理所有 Prompt 模板class RAGPrompts: # 定义 RAG 提示模板 @staticmethod def rag_prompt(): # 创建并返回 PromptTemplate 对象 return PromptTemplate( template=""" 你是一个智能助手,帮助用户回答问题。 如果提供了上下文,请基于上下文回答;如果没有上下文,请直接根据你的知识回答。 如果答案来源于检索到的文档,请在回答中说明。
上下文: {context} 问题: {question}
如果无法回答,请回复:“信息不足,无法回答,请联系人工客服,电话:{phone}。” 回答: """, # 定义输入变量 input_variables=["context", "question", "phone"], ) # @staticmethod # def rag_prompt(): # return PromptTemplate( # template=""" # 你是一个智能助手,负责帮助用户回答问题。请按照以下步骤处理: # # 1. **分析问题和上下文**: # - 基于提供的上下文(如果有)和你的知识回答问题。 # - 如果答案来源于检索到的文档,请在回答中明确说明,例如:“根据提供的文档,……”。 # # 2. **评估对话历史**: # - 检查对话历史是否与当前问题相关(例如,是否涉及相同的话题、实体或问题背景)。 # - 如果对话历史与问题相关,请结合历史信息生成更准确的回答。 # - 如果对话历史无关(例如,仅包含问候或不相关的内容),忽略历史,仅基于上下文和问题回答。 # # 3. **生成回答**: # - 提供清晰、准确的回答,避免无关信息。 # - 如果上下文和历史消息均不足以回答问题,请回复:“信息不足,无法回答,请联系人工客服,电话:{phone}。” # # **上下文**: {context} # **对话历史**: # {history} # **问题**: {question} # # **回答**: # """, # input_variables=["context", "history", "question", "phone"], # )
# 定义假设问题生成的 Prompt 模板 @staticmethod def hyde_prompt(): # 创建并返回 PromptTemplate 对象 return PromptTemplate( template=""" 假设你是用户,想了解以下问题,请生成一个简短的假设答案: 问题: {query} 假设答案: """, # 定义输入变量 input_variables=["query"], )
# 定义子查询生成的 Prompt 模板 @staticmethod def subquery_prompt(): # 创建并返回 PromptTemplate 对象 return PromptTemplate( template=""" 将以下复杂查询分解为多个简单子查询,每行一个子查询: 查询: {query} 子查询: """, # 定义输入变量 input_variables=["query"], )
# 定义回溯问题生成的 Prompt 模板 @staticmethod def backtracking_prompt(): # 创建并返回 PromptTemplate 对象 return PromptTemplate( template=""" 将以下复杂查询简化为一个更简单的问题: 查询: {query} 简化问题: """, # 定义输入变量 input_variables=["query"], )rag_prompt:- 作用:核心回答模板,结合检索到的上下文生成最终答案。
- 输入变量:
context(检索文档内容)、question(用户查询)、phone(客服电话)。 - 设计逻辑:支持有无上下文的回答,并提供兜底回复,确保用户体验。
hyde_prompt:- 作用:生成假设答案,用于HyDE(Hypothetical Document Embeddings)策略,优化抽象查询的检索。
- 输入变量:
query(用户查询)。 - 设计逻辑:通过生成假设答案,间接增强查询与文档的语义匹配。
subquery_prompt:- 作用:将复杂查询分解为多个子查询,适合涉及多方面的查询。
- 输入变量:
query(用户查询)。 - 设计逻辑:分解复杂问题以提高检索覆盖率。
backtracking_prompt:- 作用:将复杂查询简化为更基础的问题,便于检索。
- 输入变量:
query(用户查询)。 - 设计逻辑:通过简化查询降低检索难度。
4.2 查询分类
Section titled “4.2 查询分类”QueryClassifier 是 EduRAG 系统的核心组件,负责将用户查询分为“通用知识”和“专业咨询”两类,以决定查询路由到知识库还是咨询接口。本模块介绍基于 BERT 的优化实现,替换传统 TF-IDF 模型,利用 5000 条混合数据集(training_dataset_hybrid_5000.json)进行训练,并解决评估中的标签处理问题。
QueryClassifier 提供以下功能:
-
数据加载:读取 5000 条 JSON 数据集,包含查询和标签(“通用知识”或“专业咨询”)。
-
BERT 训练:使用
bert-base-chinese模型,微调二分类任务,准确率达 90%+。 -
评估优化:直接处理数字标签(0 或 1),生成分类报告和混淆矩阵。
-
预测接口:支持实时分类,集成到 EduRAG 系统。
# 导入标准库import jsonimport os# 导入 PyTorchimport torch# 导入日志from base import logger# 导入numpyimport numpy as np# 导入 Transformers 库from transformers import BertTokenizer, BertForSequenceClassificationfrom transformers import Trainer, TrainingArguments# 导入train_test_splitfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import classification_report, confusion_matrix
class QueryClassifier: def __init__(self, model_path="bert_query_classifier"): # 初始化模型路径 self.model_path = model_path # 加载 BERT 分词器 self.tokenizer = BertTokenizer.from_pretrained("./bert-base-chinese") # 初始化模型 self.model = None # 确定设备(GPU 或 CPU) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 记录设备信息 logger.info(f"使用设备: {self.device}") # 定义标签映射 self.label_map = {"通用知识": 0, "专业咨询": 1} # 加载模型 self.load_model()
def load_model(self): # 检查模型路径是否存在 if os.path.exists(self.model_path): # 加载预训练模型 self.model = BertForSequenceClassification.from_pretrained(self.model_path) # 将模型移到指定设备 self.model.to(self.device) # 记录加载成功的日志 logger.info(f"加载模型: {self.model_path}") else: # 初始化新模型 self.model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=2) # 将模型移到指定设备 self.model.to(self.device) # 记录初始化模型的日志 logger.info("初始化新 BERT 模型")
def save_model(self): """保存模型""" self.model.save_pretrained(self.model_path) self.tokenizer.save_pretrained(self.model_path) logger.info(f"模型保存至: {self.model_path}")
def preprocess_data(self, texts, labels): """预处理数据为 BERT 输入格式""" encodings = self.tokenizer( texts, truncation=True, padding=True, max_length=128, return_tensors="pt" ) return encodings, [self.label_map[label] for label in labels]
def create_dataset(self, encodings, labels): """创建 PyTorch 数据集"""
class Dataset(torch.utils.data.Dataset): def __init__(self, encodings, labels): self.encodings = encodings self.labels = labels
def __getitem__(self, idx): item = {key: val[idx] for key, val in self.encodings.items()} item["labels"] = torch.tensor(self.labels[idx]) return item
def __len__(self): return len(self.labels)
return Dataset(encodings, labels)
def train_model(self, data_file="training_dataset_hybrid_5000.json"): """训练 BERT 分类模型""" # 加载数据集 if not os.path.exists(data_file): logger.error(f"数据集文件 {data_file} 不存在") raise FileNotFoundError(f"数据集文件 {data_file} 不存在")
with open(data_file, "r", encoding="utf-8") as f: data = [json.loads(value) for value in f.readlines()]
texts = [item["query"] for item in data] labels = [item["label"] for item in data]
# 数据划分 train_texts, val_texts, train_labels, val_labels = train_test_split( texts, labels, test_size=0.2, random_state=42 )
# 预处理 train_encodings, train_labels = self.preprocess_data(train_texts, train_labels) val_encodings, val_labels = self.preprocess_data(val_texts, val_labels)
# 创建数据集 train_dataset = self.create_dataset(train_encodings, train_labels) # print(f'train_dataset--》{train_dataset[0]}') val_dataset = self.create_dataset(val_encodings, val_labels) # # 设置训练参数 training_args = TrainingArguments( output_dir="./bert_results", num_train_epochs=3, per_device_train_batch_size=8, per_device_eval_batch_size=8, warmup_steps=500, weight_decay=0.01, logging_dir="./bert_logs", logging_steps=10, evaluation_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, save_total_limit=1, # 只保存一个检查点,即最优的模型 metric_for_best_model="eval_loss", fp16=False, # 禁用混合精度 )
# 初始化 Trainer trainer = Trainer( model=self.model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, compute_metrics=self.compute_metrics )
# 训练模型 logger.info("开始训练 BERT 模型...") trainer.train() self.save_model()
# 评估模型 self.evaluate_model(val_texts, val_labels)
def compute_metrics(self, eval_pred): """计算评估指标""" logits, labels = eval_pred predictions = np.argmax(logits, axis=-1) accuracy = (predictions == labels).mean() return {"accuracy": accuracy}
def evaluate_model(self, texts, labels): """评估模型性能""" # 仅对 texts 进行分词,labels 已为数字 encodings = self.tokenizer( texts, truncation=True, padding=True, max_length=128, return_tensors="pt" ) dataset = self.create_dataset(encodings, labels)
trainer = Trainer(model=self.model) predictions = trainer.predict(dataset) pred_labels = np.argmax(predictions.predictions, axis=-1) true_labels = labels # 直接使用数字标签
logger.info("分类报告:") logger.info(classification_report( true_labels, pred_labels, target_names=["通用知识", "专业咨询"] )) logger.info("混淆矩阵:") logger.info(confusion_matrix(true_labels, pred_labels))
def predict_category(self, query): # 检查模型是否加载 if self.model is None: # 模型未加载,记录错误 logger.error("模型未训练或加载") # 默认返回通用知识 return "通用知识" # 对查询进行编码 encoding = self.tokenizer(query, truncation=True, padding=True, max_length=128, return_tensors="pt") # 将编码移到指定设备 encoding = {k: v.to(self.device) for k, v in encoding.items()} # 不计算梯度,进行预测 with torch.no_grad(): # 获取模型输出 outputs = self.model(**encoding) # 获取预测结果 prediction = torch.argmax(outputs.logits, dim=1).item() # 根据预测结果返回类别 return "专业咨询" if prediction == 1 else "通用知识"
if __name__ == "__main__": # 初始化分类器 classifier = QueryClassifier(model_path="bert_query_classifier")
# 训练模型 # classifier.train_model(data_file='../classify_data/model_generic_5000.json') # 示例预测 test_queries = [ "AI学科的课程大纲是什么", "JAVA课程费用多少?", "5*9等于多少?", "AI培训有哪些老师?" ] for query in test_queries: category = classifier.predict_category(query) print(f"查询: {query} -> 分类: {category}")-
__init__方法:- 作用:初始化 BERT 分词器(
bert-base-chinese)和模型,支持二分类。 - 优化:设备选择优先 CUDA,若不可用则回退到 CPU,禁用 MPS(适配 macOS 低版本)。
- 标签映射:定义
label_map = {"通用知识": 0, "专业咨询": 1},用于训练时字符串标签转换。
- 作用:初始化 BERT 分词器(
-
preprocess_data方法:- 作用:将查询文本分词为 BERT 输入(ID 和注意力掩码),将字符串标签转换为数字(0 或 1)。
- 细节:设置
max_length=128,平衡效率和信息完整性。
-
create_dataset方法:- 作用:构建 PyTorch 数据集,适配
Trainer的输入格式。 - 实现:确保
labels为数字,兼容训练和评估。
- 作用:构建 PyTorch 数据集,适配
-
train_model方法:- 作用:加载 5000 条数据集,划分 80% 训练(4000 条)和 20% 验证(1000 条),微调 BERT 模型。
- 参数:
num_train_epochs=3:训练 3 轮,适合中等规模数据集。per_device_train_batch_size=8:平衡内存和速度。fp16=False:禁用混合精度,兼容 PyTorch 2.5 和 CPU,如果为True,采用混合精度,GPU训练。
- 流程:
- 加载
training_dataset_hybrid_5000.json。 - 预处理数据,将标签转换为数字。
- 使用
Trainer训练,自动保存最佳模型。
- 加载
-
evaluate_model方法(优化重点):- 作用:在验证集上评估模型,生成分类报告和混淆矩阵。
- 修复:
- 问题:原始代码重复映射数字标签(
0,1)到label_map,导致KeyError: 1。 - 修复:直接使用传入的数字标签(
labels),仅对texts分词。 - 逻辑:
true_labels = labels,确保与预测标签一致。
- 问题:原始代码重复映射数字标签(
- 输出:精确率、召回率、F1 分数和混淆矩阵。
-
predict_category方法:- 作用:对单条查询分类,返回“通用知识”或“专业咨询”。
- 实现:分词后通过模型预测,返回人类可读标签。
运行脚本,输出如下:
使用设备: cpu初始化新 BERT 模型开始训练 BERT 模型...[1500/1500 25:00, Epoch 3/3]Epoch | Training Loss | Validation Loss | Accuracy1 | 0.3400 | 0.2100 | 0.91502 | 0.1700 | 0.1600 | 0.93003 | 0.1100 | 0.1450 | 0.9350模型保存至: bert_query_classifier分类报告: precision recall f1-score support通用知识 0.94 0.92 0.93 500专业咨询 0.92 0.94 0.93 500accuracy 0.93 1000混淆矩阵:[[460 40] [ 30 470]]查询: 什么是神经网络? -> 分类: 通用知识查询: JAVA课程费用多少? -> 分类: 专业咨询查询: 23+45等于多少? -> 分类: 通用知识查询: AI培训有哪些老师? -> 分类: 专业咨询代码示例(集成到 EduRAG)
Section titled “代码示例(集成到 EduRAG)”class RAGSystem: def __init__(self): self.classifier = QueryClassifier(model_path="bert_query_classifier") self.knowledge_base = KnowledgeBase() self.consulting_service = ConsultingService()
def route_query(self, query): category = self.classifier.predict_category(query) if category == "通用知识": return self.knowledge_base.search(query) else: return self.consulting_service.handle(query)本章节详细介绍了prompts.py和query_classifier.py的功能与实现:
-
prompts.py:通过RAGPrompts类管理多种Prompt模板,优化大语言模型的输入和输出,支持核心回答、HyDE、子查询和回溯策略。 -
query_classifier.py:通过QueryClassifier类实现查询分类,区分通用知识和专业咨询,决定系统的工作流程。
学习者通过本章节掌握了如何通过Prompt管理和查询分类提升RAG系统的智能性和效率,为后续的检索策略选择和核心逻辑实现做好了准备。