数据集生成与优化(扩展资料)
发布于 2025-12-31
数据集生成与优化
Section titled “数据集生成与优化”- 1.掌握如何结合规则模板和 Qwen-Plus 模型生成高质量查询数据集。
- 2.通过 tqdm 进度条监控生成进度,并实现分阶段数据保存。
generate_query_dataset_hybrid.py 是 EduRAG 系统中用于生成训练数据集的核心脚本,旨在为 QueryClassifier 提供 6000 条高质量数据(“通用知识”和“专业咨询”各 3000 条)。通过规则模板和 Qwen-Plus 模型的混合生成,结合进度条和分阶段保存,本脚本确保了生成效率和数据可靠性。本章节将详细讲解其功能、实现和应用。
1.1 数据生成与优化
Section titled “1.1 数据生成与优化”generate_query_dataset_hybrid.py 提供以下功能:
- 规则生成:基于模板和同义词替换生成 3000 条数据(“通用知识”和“专业咨询”各 1500 条)。
- 大模型生成:利用 Qwen-Plus 生成 3000 条自然语言查询(各 1500 条)。
- 进度监控:通过
tqdm进度条可视化每个生成阶段。 - 分阶段保存:每生成 1500 条数据保存一次,最终合并保存完整数据集。
- 数据集输出:生成均衡的 6000 条数据,保存为 JSON 文件。
import jsonimport randomimport reimport osfrom openai import OpenAIfrom dotenv import load_dotenvfrom tqdm import tqdmimport time
# 加载环境变量load_dotenv()
# 初始化Qwen-Plus客户端client = OpenAI( api_key=os.getenv("DASHSCOPE_API_KEY"), base_url=os.getenv("DASHSCOPE_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),)
# 同义词词典synonym_dict = { "什么": ["什么", "啥是", "如何解释", "具体是什么", "请解释"], "课程": ["课程", "培训课程", "课程安排", "教学内容", "课"], "学费": ["学费", "费用", "报名费", "学习费用", "价格", "收费"], "大纲": ["大纲", "课程内容", "教学计划", "讲义"], "师资": ["师资", "教师团队", "讲师阵容", "师资力量", "老师", "讲师"], "培训": ["培训", "辅导", "学习计划", "教育课程"], "在哪里": ["在哪里", "位于何处", "设在何地", "在哪"], "介绍": ["介绍", "说明", "讲解", "概述", "讲讲", "说说"], "请问": ["请问", "能否告知", "是否可以告诉我", "麻烦说下"], "原理": ["原理", "基本思想", "工作机制"], "写一个": ["写一个", "编写一个", "生成一个", "创建一个"], "等于多少": ["等于多少", "是多少", "结果是啥", "得多少"], "如何": ["如何", "怎样", "咋样"], "需要": ["需要", "要", "得有", "必须具备"], "基础": ["基础", "前提", "背景", "基本知识"]}
def apply_synonym_variation(text, replace_prob=0.5): """对文本进行同义词随机替换""" for word, synonyms in synonym_dict.items(): pattern = r"\b" + re.escape(word) + r"\b" def repl(match): if random.random() < replace_prob: return random.choice(synonyms) return match.group(0) text = re.sub(pattern, repl, text) return text
# 规则生成部分def generate_generic_query_rule(): """规则生成通用知识查询""" templates = [ "什么{concept}?", "{concept}的定义是什么?", "请解释{concept}的原理。", "如何运用{concept}?", "计算{num1}+{num2}等于多少?", "写一个{lang}的{func}函数", "为什么{thing}是{state}?" ] concepts = ["AI", "Transformer模型", "Python", "递归", "算法复杂度", "数据结构", "机器学习"] langs = ["Python", "Java", "C++"] funcs = ["排序", "计算", "打印"] things = ["太阳", "水", "风"] states = ["热的", "流动的", "无形的"] nums = list(range(1, 100))
t = random.choice(templates) if "{concept}" in t: replacements = {"concept": random.choice(concepts)} elif "{num1}" in t: replacements = {"num1": random.choice(nums), "num2": random.choice(nums)} elif "{lang}" in t: replacements = {"lang": random.choice(langs), "func": random.choice(funcs)} elif "{thing}" in t: replacements = {"thing": random.choice(things), "state": random.choice(states)} else: replacements = {}
query = t.format(**replacements) return apply_synonym_variation(query)
def generate_professional_query_rule(): """规则生成专业咨询查询""" templates = [ "请问{subject}课程的学费是多少?", "{subject}的课程大纲是什么?", "{subject}培训的学习周期有多长?", "请介绍一下{subject}培训的主要项目内容。", "请问{subject}培训地点在哪里?" ] subjects = ["JAVA", "AI", "测试", "Web前端", "Python", "大数据", "DevOps"] t = random.choice(templates) query = t.format(subject=random.choice(subjects)) return apply_synonym_variation(query)
# 大模型生成部分def generate_with_qwen(prompt): """调用Qwen-Plus生成查询,添加超时控制""" try: completion = client.chat.completions.create( model="qwen-plus", messages=[{"role": "user", "content": prompt}], temperature=0.9, timeout=10 ) return completion.choices[0].message.content.strip() except Exception as e: print(f"Qwen-Plus调用失败: {e}") return None
def generate_generic_query_qwen(): """使用Qwen-Plus生成通用知识查询""" prompt = """ 你是一个用户,生成一个“通用知识”类的查询,涉及数学计算、代码生成/纠错、概念与原理或常识性问题。 示例: - “3+5等于多少?” - “写一个Python排序函数” - “什么是神经网络?” - “太阳为什么是热的?” 请生成一个类似的查询,直接返回查询文本,不要多余说明。 """ return generate_with_qwen(prompt)
def generate_professional_query_qwen(): """使用Qwen-Plus生成专业咨询查询""" prompt = """ 你是一个用户,生成一个“专业咨询”类的查询,涉及IT教育培训(如课程详情、师资、费用、周期、地点等)。 示例: - “JAVA课程费用多少?” - “AI培训有哪些老师?” - “测试课程什么时候开课?” 请生成一个类似的查询,直接返回查询文本,不要多余说明。 """ return generate_with_qwen(prompt)
# 保存数据集def save_dataset(dataset, filename, stage_name): """保存数据集到指定文件""" with open(filename, "w", encoding="utf-8") as f: json.dump(dataset, f, ensure_ascii=False, indent=2) print(f"{stage_name}:已保存 {len(dataset)} 条数据到 {filename}")
def generate_training_dataset(total_samples=6000): """生成6000条训练数据,规则和大模型各占一半,每1500条保存""" num_per_category = total_samples // 2 # 3000 num_rule = num_per_category // 2 # 1500 num_qwen = num_per_category - num_rule # 1500
generic_samples = [] professional_samples = [] generic_set = set() professional_set = set()
# 规则生成 - 通用知识 print("生成规则通用知识数据...") with tqdm(total=num_rule, desc="Rule-based Generic") as pbar: while len(generic_samples) < num_rule: q = generate_generic_query_rule() if q not in generic_set: generic_set.add(q) generic_samples.append({"query": q, "label": "通用知识"}) pbar.update(1) save_dataset(generic_samples, "rule_generic_1500.json", "规则通用知识")
# 规则生成 - 专业咨询 print("生成规则专业咨询数据...") with tqdm(total=num_rule, desc="Rule-based Professional") as pbar: while len(professional_samples) < num_rule: q = generate_professional_query_rule() if q not in professional_set: professional_set.add(q) professional_samples.append({"query": q, "label": "专业咨询"}) pbar.update(1) save_dataset(professional_samples, "rule_professional_1500.json", "规则专业咨询")
# Qwen-Plus生成 - 通用知识 print("生成Qwen-Plus通用知识数据...") with tqdm(total=num_qwen, desc="Qwen-based Generic") as pbar: while len(generic_samples) < num_per_category: q = generate_generic_query_qwen() if q and q not in generic_set: generic_set.add(q) generic_samples.append({"query": q, "label": "通用知识"}) pbar.update(1) time.sleep(0.5) # 避免 API 限流 save_dataset(generic_samples, "generic_3000.json", "通用知识(规则+Qwen)")
# Qwen-Plus生成 - 专业咨询 print("生成Qwen-Plus专业咨询数据...") with tqdm(total=num_qwen, desc="Qwen-based Professional") as pbar: while len(professional_samples) < num_per_category: q = generate_professional_query_qwen() if q and q not in professional_set: professional_set.add(q) professional_samples.append({"query": q, "label": "专业咨询"}) pbar.update(1) time.sleep(0.5) # 避免 API 限流 save_dataset(professional_samples, "professional_3000.json", "专业咨询(规则+Qwen)")
# 合并并混洗 dataset = generic_samples + professional_samples random.shuffle(dataset) final_filename = "training_dataset_hybrid_6000.json" save_dataset(dataset, final_filename, "最终数据集") return dataset
if __name__ == "__main__": dataset = generate_training_dataset(total_samples=6000) print(f"成功生成 {len(dataset)} 条训练数据,保存在 training_dataset_hybrid_6000.json 文件中。") # 输出前10条作为示例 for item in dataset[:10]: print(json.dumps(item, ensure_ascii=False))apply_synonym_variation:- 作用:为规则生成查询增加多样性,通过同义词替换(如“什么” -> “啥是”)。
- 逻辑:正则表达式匹配词边界,50% 概率替换。
generate_generic_query_rule和generate_professional_query_rule:- 作用:基于模板生成初始数据。
- 设计:覆盖数学计算、代码生成、概念问题和 IT 培训场景。
generate_with_qwen:- 作用:封装 Qwen-Plus API 调用,生成自然语言查询。
- 参数:
temperature=0.9确保多样性,timeout=10防止卡顿。
generate_generic_query_qwen和generate_professional_query_qwen:- 作用:通过精心设计的 Prompt 指导 Qwen-Plus 生成符合类别的查询。
- 逻辑:提供示例,输出简洁的查询文本。
save_dataset:- 作用:统一保存数据集,显示阶段名称和数据量。
- 实现:支持 JSON 格式,保存中间和最终结果。
generate_training_dataset:- 作用:整合规则和大模型生成流程。
- 流程:
- 规则生成 1500 条“通用知识”,保存。
- 规则生成 1500 条“专业咨询”,保存。
- Qwen-Plus 生成 1500 条“通用知识”,累计 3000 条保存。
- Qwen-Plus 生成 1500 条“专业咨询”,累计 3000 条保存。
- 合并混洗 6000 条,保存最终数据集。
- 进度条:
tqdm实时显示每个阶段的生成进度。
- 混合生成:规则生成高效可控,Qwen-Plus 生成自然真实。
- 分阶段保存:每 1500 条保存一次,支持断点恢复。
- 进度监控:
tqdm提供直观反馈,优化用户体验。
运行脚本时,输出类似:
生成规则通用知识数据...Rule-based Generic: 100%|██████████| 1500/1500 [00:02<00:00, 750it/s]规则通用知识:已保存 1500 条数据到 rule_generic_1500.json生成规则专业咨询数据...Rule-based Professional: 100%|██████████| 1500/1500 [00:02<00:00, 700it/s]规则专业咨询:已保存 1500 条数据到 rule_professional_1500.json生成Qwen-Plus通用知识数据...Qwen-based Generic: 100%|██████████| 1500/1500 [05:00<00:00, 5it/s]通用知识(规则+Qwen):已保存 3000 条数据到 generic_3000.json生成Qwen-Plus专业咨询数据...Qwen-based Professional: 100%|██████████| 1500/1500 [05:05<00:00, 4.9it/s]专业咨询(规则+Qwen):已保存 3000 条数据到 professional_3000.json最终数据集:已保存 6000 条数据到 training_dataset_hybrid_6000.json成功生成 6000 条训练数据,保存在 training_dataset_hybrid_6000.json 文件中。1.2 数据集生成流程
Section titled “1.2 数据集生成流程”- 规则生成(3000 条):
- 生成 1500 条“通用知识”数据,保存为
rule_generic_1500.json。 - 生成 1500 条“专业咨询”数据,保存为
rule_professional_1500.json. - 使用模板和同义词替换,确保多样性。
- 生成 1500 条“通用知识”数据,保存为
- Qwen-Plus 生成(3000 条):
- 生成 1500 条“通用知识”数据,累计 3000 条保存为
generic_3000.json。 - 生成 1500 条“专业咨询”数据,累计 3000 条保存为
professional_3000.json. - 通过 Prompt 控制生成质量。
- 生成 1500 条“通用知识”数据,累计 3000 条保存为
- 数据整合:
- 合并 6000 条数据,随机混洗。
- 保存为
training_dataset_hybrid_6000.json.
代码示例(使用数据集)
Section titled “代码示例(使用数据集)”from sklearn.feature_extraction.text import TfidfVectorizerfrom sklearn.naive_bayes import MultinomialNBfrom sklearn.pipeline import Pipelineimport joblib
class QueryClassifier: def train_model(self): with open("training_dataset_hybrid_6000.json", "r", encoding="utf-8") as f: data = json.load(f) texts = [item["query"] for item in data] labels = [item["label"] for item in data] self.model = Pipeline([ ("tfidf", TfidfVectorizer()), ("classifier", MultinomialNB()), ]) self.model.fit(texts, labels) joblib.dump(self.model, "query_classifier_model.pkl") print("模型训练完成并保存")1.3 数据集特点与优化
Section titled “1.3 数据集特点与优化”- 总数:6000 条(“通用知识” 3000 条,“专业咨询” 3000 条)。
- 来源:规则生成 50%(3000 条),Qwen-Plus 生成 50%(3000 条)。
- 多样性:
- 规则生成:模板和同义词替换覆盖多种场景。
- Qwen-Plus 生成:自然语言查询贴近真实用户输入。
- 保存机制:每 1500 条保存,支持断点续传。
- 可视化:
tqdm进度条提供实时反馈。
- 高效性:简化生成流程,规则生成瞬时完成。
- 可靠性:分阶段保存确保数据安全。
- 分类性能:均衡数据集提升
QueryClassifier准确性。 - 用户体验:进度条和保存日志增强交互性。
本章节详细介绍了优化后的 generate_query_dataset_hybrid.py: - 功能:混合生成 6000 条数据,规则和 Qwen-Plus 各占一半。 - 实现:通过模板生成 3000 条,Qwen-Plus 生成 3000 条,每 1500 条保存,进度条监控。 - 作用:为 QueryClassifier 提供高质量数据,优化 EduRAG 系统分类能力。
学习者掌握了高效的数据生成和保存方法,能够为 RAG 系统提供可靠支持。
发布于 2025-12-31