Model Pruning
Published on 2025-12-31
4.4 Model Pruning
Learning objectives
- Understand what model pruning is.
- Master the basic operations of model pruning.

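1. Definition of model pruning
- Large pretrained models built on deep neural networks rely on enormous parameter counts to reach SOTA results. Biological neural networks, by contrast, carry out complex cognitive activity through a large number of sparse connections.
- Model pruning is motivated by imitating that sparsity: turn the dense connections of a large network into sparse ones while still reaching SOTA-level results.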

- PyTorch supports model pruning through the torch.nn.utils.prune module, in the following modes:
  - Pruning a specific module.
  - Pruning multiple parameters.
  - Global pruning.
  - Custom pruning.
- Note: pruning support requires PyTorch 1.4.0 or later.
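As a minimal sketch of the first and last modes in the list above, the block below prunes a single toy `nn.Linear` layer with `prune.l1_unstructured` and then applies a user-defined method. The toy layers and the class name `EveryOtherPruning` are assumptions made for illustration; the custom class follows the general pattern of subclassing `prune.BasePruningMethod` and is not taken from this section's BERT code.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# --- Pruning a specific module ---
layer = nn.Linear(8, 4)  # toy layer with 32 weights (illustrative only)
# Zero the 50% of this layer's weights with the smallest absolute value.
prune.l1_unstructured(layer, name="weight", amount=0.5)

# While pruning is active, the original values live in `weight_orig`,
# the binary mask in the buffer `weight_mask`, and `weight` is their product.
buffer_names = [name for name, _ in layer.named_buffers()]
mask_applied = torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)
local_sparsity = (layer.weight == 0).float().mean().item()


# --- Custom pruning ---
class EveryOtherPruning(prune.BasePruningMethod):
    """Hypothetical method: mask every other entry of the flattened tensor."""

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask


custom_layer = nn.Linear(8, 4)
EveryOtherPruning.apply(custom_layer, "weight")
custom_sparsity = (custom_layer.weight == 0).float().mean().item()

print(buffer_names, mask_applied, local_sparsity, custom_sparsity)
```

In both cases the module keeps computing with the masked `weight`, so the rest of the model code does not need to change.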
2. Code implementation
2.1 Configuration file: Config
```python
import datetime
import os

import torch
from transformers import BertConfig, BertModel, BertTokenizer

current_date = datetime.datetime.now().date().strftime("%Y%m%d")


class Config(object):
    def __init__(self):
        """Configuration class holding the parameters needed by the model and training."""
        self.model_name = "bert"  # model name
        self.data_path = "../../01-data"  # dataset root directory
        # os.path.join keeps the paths portable across operating systems
        self.train_path = os.path.join(self.data_path, "train.txt")  # training set
        self.dev_path = os.path.join(self.data_path, "dev3.txt")  # small validation set for quick checks
        self.test_path = os.path.join(self.data_path, "test.txt")  # test set

        self.class_path = os.path.join(self.data_path, "class.txt")  # class-name file
        with open(self.class_path, encoding="utf-8") as f:
            self.class_list = [line.strip() for line in f]  # list of class names

        # Where the trained BERT classifier is saved
        self.model_save_path = "./models_save/bert20250521.pt"
        # Where the pruned model is saved
        self.prune_model_save_path = "./models_save/prune_bertclassifer_model.pt"

        # Device for training and prediction: cuda if a GPU is available, otherwise cpu
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 2  # number of epochs
        self.batch_size = 256  # mini-batch size
        self.pad_size = 32  # fixed sequence length (pad short inputs, truncate long ones)
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = "../../04-bert/bert-base-chinese"  # path to the pretrained BERT model
        self.bert_model = BertModel.from_pretrained(self.bert_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)  # BERT tokenizer
        self.bert_config = BertConfig.from_pretrained(self.bert_path)  # BERT configuration
        self.hidden_size = 768  # BERT hidden size


if __name__ == '__main__':
    conf = Config()
    print(conf.bert_config)
    # Quick tokenizer check: map a few tokens to their vocabulary ids
    token_ids = conf.tokenizer.convert_tokens_to_ids(["你", "好", "中国", "人"])
    print(token_ids)
    print(conf.class_list)
```

2.2 Model pruning
(1) Import dependencies
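```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import BertModel
from bert_classifer_model import BertClassifier
from utils import build_dataloader
from train import model2dev
from tqdm import tqdm
from itertools import islice

# Assumption: the Config class from section 2.1 is saved as config.py;
# adjust the module name to match your project layout.
from config import Config
```

(2) Define the sparsity function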
```python
def compute_sparsity(model):
    """Compute the sparsity of the query weights across all encoder layers."""
    total_params = 0
    zero_params = 0
    for i in range(12):  # BERT-base has 12 encoder layers
        weight = model.bert.encoder.layer[i].attention.self.query.weight
        total_params += weight.numel()
        zero_params += (weight == 0).sum().item()
    # Fraction of query weights that are exactly zero
    return zero_params / total_params if total_params > 0 else 0
```

(3) Define the weight-printing function
```python
def print_weights(weight, name, rows=5, cols=5):
    """Print the top-left rows x cols block of a weight matrix."""
    print(f"\n{name} (first {rows}x{cols}):")
    print(weight[:rows, :cols])
```

(4) Main function
Global unstructured pruning of BERT: prune 30% of the self-attention query weights across all 12 encoder layers, ranked by L1 norm (absolute value).
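To see what "global" means before reading the BERT version, here is a minimal sketch on a toy stack of linear layers. The toy layers and their weight scales are assumptions chosen to make the effect visible; only `prune.global_unstructured` and `prune.L1Unstructured` come from the code in this section. All selected weights are ranked together, so the overall sparsity is close to the requested amount while individual layers can end up far above or below it.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Toy stand-in for a stack of layers (not the BERT model used in this section).
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])

# Give each layer a different weight scale so the layers compete unequally.
with torch.no_grad():
    for scale, layer in zip([0.1, 1.0, 10.0], layers):
        layer.weight.mul_(scale)

# Same call pattern as the BERT code: a list of (module, parameter_name) pairs.
parameters_to_prune = [(layer, "weight") for layer in layers]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3,  # fraction of ALL listed weights, not of each layer
)

# Per-layer sparsity versus overall sparsity.
per_layer = [(layer.weight == 0).float().mean().item() for layer in layers]
total = sum(layer.weight.numel() for layer in layers)
zeros = sum((layer.weight == 0).sum().item() for layer in layers)
overall = zeros / total

print("per-layer sparsity:", [round(s, 3) for s in per_layer])
print("overall sparsity:", round(overall, 3))
```

The layer with the smallest weights absorbs most of the pruning. A single overall sparsity figure therefore does not imply that every layer is pruned by the same fraction.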
```python
def main(conf):
    train_dataloader, test_dataloader, dev_dataloader = build_dataloader()

    # Load the trained classifier (map_location lets a GPU checkpoint load on CPU)
    model = BertClassifier().to(conf.device)
    model.load_state_dict(torch.load(conf.model_save_path, map_location=conf.device), strict=False)

    # Before pruning
    print("Model before pruning:")
    print(model.bert.encoder.layer[0].attention.self)
    print_weights(model.bert.encoder.layer[0].attention.self.query.weight,
                  "layer[0].attention.self.query.weight before pruning")
    report, f1score, accuracy, precision = model2dev(model, dev_dataloader, conf.device)
    print(f"\nAccuracy before pruning: {accuracy:.4f}, F1: {f1score:.4f}")

    # Global unstructured pruning: 30% of the query weights over all encoder layers
    parameters_to_prune = [(model.bert.encoder.layer[i].attention.self.query, 'weight') for i in range(12)]
    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.3)

    # Make the pruning permanent: drop weight_orig / weight_mask and keep the masked weight
    for module, param in parameters_to_prune:
        prune.remove(module, param)

    # After pruning
    print("\nModel after pruning:")
    print(model.bert.encoder.layer[0].attention.self)
    print_weights(model.bert.encoder.layer[0].attention.self.query.weight,
                  "layer[0].attention.self.query.weight after pruning")
    report, f1score, accuracy, precision = model2dev(model, dev_dataloader, conf.device)
    sparsity = compute_sparsity(model)
    print(f"\nAccuracy after pruning: {accuracy:.4f}, F1: {f1score:.4f}\nSparsity: {sparsity:.4f}")

    # Save the pruned model
    torch.save(model.state_dict(), conf.prune_model_save_path)


if __name__ == '__main__':
    # 1. Load the configuration
    conf = Config()
    # 2. Run the main function
    main(conf)
```

Output log:

```
XXXXX/bin/python XXXXX/TMFCode/06-model-compression/bert_prune/prune_bert_attention.py
Loading data: 180000it [00:00, 711615.96it/s]
Loading data: 10000it [00:00, 409776.08it/s]
Loading data: 50it [00:00, 187245.71it/s]

Model before pruning:
BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

layer[0].attention.self.query.weight before pruning (first 5x5):
tensor([[ 0.1152, -0.0104,  0.0063,  0.0414, -0.0410],
        [ 0.0050, -0.0232, -0.0065,  0.0219,  0.0891],
        [ 0.0138,  0.0019,  0.0359, -0.0140, -0.0088],
        [ 0.0024, -0.0525, -0.0323,  0.0530, -0.0187],
        [-0.0470,  0.0525,  0.0182, -0.0156,  0.0729]], grad_fn=<SliceBackward0>)
Bert Classifer Evaluating ......: 100%|██████████| 1/1 [00:01<00:00, 1.46s/it]

Accuracy before pruning: 0.9600, F1: 0.9407

Model after pruning:
BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

layer[0].attention.self.query.weight after pruning (first 5x5):
tensor([[ 0.1152, -0.0000,  0.0000,  0.0414, -0.0410],
        [ 0.0000, -0.0232, -0.0000,  0.0219,  0.0891],
        [ 0.0000,  0.0000,  0.0359, -0.0000, -0.0000],
        [ 0.0000, -0.0525, -0.0323,  0.0530, -0.0187],
        [-0.0470,  0.0525,  0.0182, -0.0000,  0.0729]], grad_fn=<SliceBackward0>)
Bert Classifer Evaluating ......: 100%|██████████| 1/1 [00:01<00:00, 1.30s/it]

Accuracy after pruning: 0.9400, F1: 0.9187
Sparsity: 0.3000

Process finished with exit code 0
```

Conclusion: after global pruning, the model's F1 is 91.87%, roughly 2 percentage points below the 94.07% measured before pruning, and accuracy drops from 96.00% to 94.00%. The quick validation set appears to contain only 50 examples (the third `Loading data` line in the log), so a 2-point change in accuracy corresponds to a single example and the comparison is coarse.
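The main function calls `prune.remove` before saving. A minimal sketch of why that matters for reloading follows, using a toy model and an in-memory buffer in place of `BertClassifier` and the checkpoint file (both are assumptions for illustration). While pruning is active, the state dict contains `weight_orig` and `weight_mask` entries; after `prune.remove` it has a plain `weight` again, so it loads into a freshly constructed, unpruned model and the zeros are preserved.

```python
import io

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def make_model():
    # Hypothetical toy model standing in for BertClassifier.
    return nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
model = make_model()
prune.l1_unstructured(model[0], name="weight", amount=0.5)

# While pruning is active, the pruned layer exposes weight_orig and weight_mask.
keys_while_pruning = sorted(model.state_dict().keys())

# Make pruning permanent: the layer goes back to a single dense `weight` tensor
# whose pruned entries are stored as zeros.
prune.remove(model[0], "weight")
keys_after_remove = sorted(model.state_dict().keys())

# Save to an in-memory buffer (a file path works the same way).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# A fresh, unpruned model can load the state dict directly.
fresh = make_model()
fresh.load_state_dict(torch.load(buffer))
reloaded_sparsity = (fresh[0].weight == 0).float().mean().item()

print(keys_while_pruning)
print(keys_after_remove)
print(reloaded_sparsity)
```

Note that the saved tensors are still dense: the zeros are stored explicitly, so this step by itself does not change the tensor shapes.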
3. Section summary
- This section implemented global unstructured pruning: 30% of the self-attention query weights across all encoder layers were pruned by L1 norm.