跳转到内容

数据集生成与优化(扩展资料)

  • 1.掌握如何结合规则模板和 Qwen-Plus 模型生成高质量查询数据集。
  • 2.通过 tqdm 进度条监控生成进度,并实现分阶段数据保存。

generate_query_dataset_hybrid.py 是 EduRAG 系统中用于生成训练数据集的核心脚本,旨在为 QueryClassifier 提供 6000 条高质量数据(“通用知识”和“专业咨询”各 3000 条)。通过规则模板和 Qwen-Plus 模型的混合生成,结合进度条和分阶段保存,本脚本确保了生成效率和数据可靠性。本章节将详细讲解其功能、实现和应用。

generate_query_dataset_hybrid.py 提供以下功能:

  • 规则生成:基于模板和同义词替换生成 3000 条数据(“通用知识”和“专业咨询”各 1500 条)。
  • 大模型生成:利用 Qwen-Plus 生成 3000 条自然语言查询(各 1500 条)。
  • 进度监控:通过 tqdm 进度条可视化每个生成阶段。
  • 分阶段保存:每生成 1500 条数据保存一次,最终合并保存完整数据集。
  • 数据集输出:生成均衡的 6000 条数据,保存为 JSON 文件。
generate_query_dataset_hybrid.py
import json
import random
import re
import os
from openai import OpenAI
from dotenv import load_dotenv
from tqdm import tqdm
import time
# 加载环境变量
load_dotenv()
# 初始化Qwen-Plus客户端
client = OpenAI(
api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url=os.getenv("DASHSCOPE_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
)
# 同义词词典
synonym_dict = {
"什么": ["什么", "啥是", "如何解释", "具体是什么", "请解释"],
"课程": ["课程", "培训课程", "课程安排", "教学内容", "课"],
"学费": ["学费", "费用", "报名费", "学习费用", "价格", "收费"],
"大纲": ["大纲", "课程内容", "教学计划", "讲义"],
"师资": ["师资", "教师团队", "讲师阵容", "师资力量", "老师", "讲师"],
"培训": ["培训", "辅导", "学习计划", "教育课程"],
"在哪里": ["在哪里", "位于何处", "设在何地", "在哪"],
"介绍": ["介绍", "说明", "讲解", "概述", "讲讲", "说说"],
"请问": ["请问", "能否告知", "是否可以告诉我", "麻烦说下"],
"原理": ["原理", "基本思想", "工作机制"],
"写一个": ["写一个", "编写一个", "生成一个", "创建一个"],
"等于多少": ["等于多少", "是多少", "结果是啥", "得多少"],
"如何": ["如何", "怎样", "咋样"],
"需要": ["需要", "要", "得有", "必须具备"],
"基础": ["基础", "前提", "背景", "基本知识"]
}
def apply_synonym_variation(text, replace_prob=0.5):
"""对文本进行同义词随机替换"""
for word, synonyms in synonym_dict.items():
pattern = r"\b" + re.escape(word) + r"\b"
def repl(match):
if random.random() < replace_prob:
return random.choice(synonyms)
return match.group(0)
text = re.sub(pattern, repl, text)
return text
# 规则生成部分
def generate_generic_query_rule():
"""规则生成通用知识查询"""
templates = [
"什么{concept}?",
"{concept}的定义是什么?",
"请解释{concept}的原理。",
"如何运用{concept}?",
"计算{num1}+{num2}等于多少?",
"写一个{lang}的{func}函数",
"为什么{thing}是{state}?"
]
concepts = ["AI", "Transformer模型", "Python", "递归", "算法复杂度", "数据结构", "机器学习"]
langs = ["Python", "Java", "C++"]
funcs = ["排序", "计算", "打印"]
things = ["太阳", "水", "风"]
states = ["热的", "流动的", "无形的"]
nums = list(range(1, 100))
t = random.choice(templates)
if "{concept}" in t:
replacements = {"concept": random.choice(concepts)}
elif "{num1}" in t:
replacements = {"num1": random.choice(nums), "num2": random.choice(nums)}
elif "{lang}" in t:
replacements = {"lang": random.choice(langs), "func": random.choice(funcs)}
elif "{thing}" in t:
replacements = {"thing": random.choice(things), "state": random.choice(states)}
else:
replacements = {}
query = t.format(**replacements)
return apply_synonym_variation(query)
def generate_professional_query_rule():
"""规则生成专业咨询查询"""
templates = [
"请问{subject}课程的学费是多少?",
"{subject}的课程大纲是什么?",
"{subject}培训的学习周期有多长?",
"请介绍一下{subject}培训的主要项目内容。",
"请问{subject}培训地点在哪里?"
]
subjects = ["JAVA", "AI", "测试", "Web前端", "Python", "大数据", "DevOps"]
t = random.choice(templates)
query = t.format(subject=random.choice(subjects))
return apply_synonym_variation(query)
# 大模型生成部分
def generate_with_qwen(prompt):
"""调用Qwen-Plus生成查询,添加超时控制"""
try:
completion = client.chat.completions.create(
model="qwen-plus",
messages=[{"role": "user", "content": prompt}],
temperature=0.9,
timeout=10
)
return completion.choices[0].message.content.strip()
except Exception as e:
print(f"Qwen-Plus调用失败: {e}")
return None
def generate_generic_query_qwen():
"""使用Qwen-Plus生成通用知识查询"""
prompt = """
你是一个用户,生成一个“通用知识”类的查询,涉及数学计算、代码生成/纠错、概念与原理或常识性问题。
示例:
- “3+5等于多少?”
- “写一个Python排序函数”
- “什么是神经网络?”
- “太阳为什么是热的?”
请生成一个类似的查询,直接返回查询文本,不要多余说明。
"""
return generate_with_qwen(prompt)
def generate_professional_query_qwen():
"""使用Qwen-Plus生成专业咨询查询"""
prompt = """
你是一个用户,生成一个“专业咨询”类的查询,涉及IT教育培训(如课程详情、师资、费用、周期、地点等)。
示例:
- “JAVA课程费用多少?”
- “AI培训有哪些老师?”
- “测试课程什么时候开课?”
请生成一个类似的查询,直接返回查询文本,不要多余说明。
"""
return generate_with_qwen(prompt)
# 保存数据集
def save_dataset(dataset, filename, stage_name):
"""保存数据集到指定文件"""
with open(filename, "w", encoding="utf-8") as f:
json.dump(dataset, f, ensure_ascii=False, indent=2)
print(f"{stage_name}:已保存 {len(dataset)} 条数据到 {filename}")
def generate_training_dataset(total_samples=6000):
"""生成6000条训练数据,规则和大模型各占一半,每1500条保存"""
num_per_category = total_samples // 2 # 3000
num_rule = num_per_category // 2 # 1500
num_qwen = num_per_category - num_rule # 1500
generic_samples = []
professional_samples = []
generic_set = set()
professional_set = set()
# 规则生成 - 通用知识
print("生成规则通用知识数据...")
with tqdm(total=num_rule, desc="Rule-based Generic") as pbar:
while len(generic_samples) < num_rule:
q = generate_generic_query_rule()
if q not in generic_set:
generic_set.add(q)
generic_samples.append({"query": q, "label": "通用知识"})
pbar.update(1)
save_dataset(generic_samples, "rule_generic_1500.json", "规则通用知识")
# 规则生成 - 专业咨询
print("生成规则专业咨询数据...")
with tqdm(total=num_rule, desc="Rule-based Professional") as pbar:
while len(professional_samples) < num_rule:
q = generate_professional_query_rule()
if q not in professional_set:
professional_set.add(q)
professional_samples.append({"query": q, "label": "专业咨询"})
pbar.update(1)
save_dataset(professional_samples, "rule_professional_1500.json", "规则专业咨询")
# Qwen-Plus生成 - 通用知识
print("生成Qwen-Plus通用知识数据...")
with tqdm(total=num_qwen, desc="Qwen-based Generic") as pbar:
while len(generic_samples) < num_per_category:
q = generate_generic_query_qwen()
if q and q not in generic_set:
generic_set.add(q)
generic_samples.append({"query": q, "label": "通用知识"})
pbar.update(1)
time.sleep(0.5) # 避免 API 限流
save_dataset(generic_samples, "generic_3000.json", "通用知识(规则+Qwen)")
# Qwen-Plus生成 - 专业咨询
print("生成Qwen-Plus专业咨询数据...")
with tqdm(total=num_qwen, desc="Qwen-based Professional") as pbar:
while len(professional_samples) < num_per_category:
q = generate_professional_query_qwen()
if q and q not in professional_set:
professional_set.add(q)
professional_samples.append({"query": q, "label": "专业咨询"})
pbar.update(1)
time.sleep(0.5) # 避免 API 限流
save_dataset(professional_samples, "professional_3000.json", "专业咨询(规则+Qwen)")
# 合并并混洗
dataset = generic_samples + professional_samples
random.shuffle(dataset)
final_filename = "training_dataset_hybrid_6000.json"
save_dataset(dataset, final_filename, "最终数据集")
return dataset
if __name__ == "__main__":
dataset = generate_training_dataset(total_samples=6000)
print(f"成功生成 {len(dataset)} 条训练数据,保存在 training_dataset_hybrid_6000.json 文件中。")
# 输出前10条作为示例
for item in dataset[:10]:
print(json.dumps(item, ensure_ascii=False))
  1. apply_synonym_variation
    • 作用:为规则生成查询增加多样性,通过同义词替换(如“什么” -> “啥是”)。
    • 逻辑:正则表达式匹配词边界,50% 概率替换。
  2. generate_generic_query_rulegenerate_professional_query_rule
    • 作用:基于模板生成初始数据。
    • 设计:覆盖数学计算、代码生成、概念问题和 IT 培训场景。
  3. generate_with_qwen
    • 作用:封装 Qwen-Plus API 调用,生成自然语言查询。
    • 参数temperature=0.9 确保多样性,timeout=10 防止卡顿。
  4. generate_generic_query_qwengenerate_professional_query_qwen
    • 作用:通过精心设计的 Prompt 指导 Qwen-Plus 生成符合类别的查询。
    • 逻辑:提供示例,输出简洁的查询文本。
  5. save_dataset
    • 作用:统一保存数据集,显示阶段名称和数据量。
    • 实现:支持 JSON 格式,保存中间和最终结果。
  6. generate_training_dataset
    • 作用:整合规则和大模型生成流程。
    • 流程
      • 规则生成 1500 条“通用知识”,保存。
      • 规则生成 1500 条“专业咨询”,保存。
      • Qwen-Plus 生成 1500 条“通用知识”,累计 3000 条保存。
      • Qwen-Plus 生成 1500 条“专业咨询”,累计 3000 条保存。
      • 合并混洗 6000 条,保存最终数据集。
    • 进度条tqdm 实时显示每个阶段的生成进度。
  • 混合生成:规则生成高效可控,Qwen-Plus 生成自然真实。
  • 分阶段保存:每 1500 条保存一次,支持断点恢复。
  • 进度监控tqdm 提供直观反馈,优化用户体验。

运行脚本时,输出类似:

生成规则通用知识数据...
Rule-based Generic: 100%|██████████| 1500/1500 [00:02<00:00, 750it/s]
规则通用知识:已保存 1500 条数据到 rule_generic_1500.json
生成规则专业咨询数据...
Rule-based Professional: 100%|██████████| 1500/1500 [00:02<00:00, 700it/s]
规则专业咨询:已保存 1500 条数据到 rule_professional_1500.json
生成Qwen-Plus通用知识数据...
Qwen-based Generic: 100%|██████████| 1500/1500 [05:00<00:00, 5it/s]
通用知识(规则+Qwen):已保存 3000 条数据到 generic_3000.json
生成Qwen-Plus专业咨询数据...
Qwen-based Professional: 100%|██████████| 1500/1500 [05:05<00:00, 4.9it/s]
专业咨询(规则+Qwen):已保存 3000 条数据到 professional_3000.json
最终数据集:已保存 6000 条数据到 training_dataset_hybrid_6000.json
成功生成 6000 条训练数据,保存在 training_dataset_hybrid_6000.json 文件中。

  1. 规则生成(3000 条)
    • 生成 1500 条“通用知识”数据,保存为 rule_generic_1500.json
    • 生成 1500 条“专业咨询”数据,保存为 rule_professional_1500.json.
    • 使用模板和同义词替换,确保多样性。
  2. Qwen-Plus 生成(3000 条)
    • 生成 1500 条“通用知识”数据,累计 3000 条保存为 generic_3000.json
    • 生成 1500 条“专业咨询”数据,累计 3000 条保存为 professional_3000.json.
    • 通过 Prompt 控制生成质量。
  3. 数据整合
    • 合并 6000 条数据,随机混洗。
    • 保存为 training_dataset_hybrid_6000.json.
core/query_classifier.py
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
import joblib
class QueryClassifier:
def train_model(self):
with open("training_dataset_hybrid_6000.json", "r", encoding="utf-8") as f:
data = json.load(f)
texts = [item["query"] for item in data]
labels = [item["label"] for item in data]
self.model = Pipeline([
("tfidf", TfidfVectorizer()),
("classifier", MultinomialNB()),
])
self.model.fit(texts, labels)
joblib.dump(self.model, "query_classifier_model.pkl")
print("模型训练完成并保存")

  • 总数:6000 条(“通用知识” 3000 条,“专业咨询” 3000 条)。
  • 来源:规则生成 50%(3000 条),Qwen-Plus 生成 50%(3000 条)。
  • 多样性
    • 规则生成:模板和同义词替换覆盖多种场景。
    • Qwen-Plus 生成:自然语言查询贴近真实用户输入。
  • 保存机制:每 1500 条保存,支持断点续传。
  • 可视化tqdm 进度条提供实时反馈。
  • 高效性:简化生成流程,规则生成瞬时完成。
  • 可靠性:分阶段保存确保数据安全。
  • 分类性能:均衡数据集提升 QueryClassifier 准确性。
  • 用户体验:进度条和保存日志增强交互性。

本章节详细介绍了优化后的 generate_query_dataset_hybrid.py: - 功能:混合生成 6000 条数据,规则和 Qwen-Plus 各占一半。 - 实现:通过模板生成 3000 条,Qwen-Plus 生成 3000 条,每 1500 条保存,进度条监控。 - 作用:为 QueryClassifier 提供高质量数据,优化 EduRAG 系统分类能力。

学习者掌握了高效的数据生成和保存方法,能够为 RAG 系统提供可靠支持。