diff --git a/bbcm/modeling/csc/modeling_soft_masked_bert_onnx.py b/bbcm/modeling/csc/modeling_soft_masked_bert_onnx.py new file mode 100644 index 0000000..a93cdc5 --- /dev/null +++ b/bbcm/modeling/csc/modeling_soft_masked_bert_onnx.py @@ -0,0 +1,147 @@ +""" +@Time : 2025-03-24 20:24:00 +@File : modeling_soft_masked_bert_onnx.py +@Author : Abtion, Zhang Chen +@Email : abtion{at}outlook.com, zhangchen.shaanxi{at}gmail.com +""" +import operator +import os +from collections import OrderedDict +import transformers as tfs +import torch +from torch import nn +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from transformers import BertConfig +from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler, BertOnlyMLMHead +from transformers.modeling_utils import ModuleUtilsMixin +from bbcm.engine.csc_trainer import CscTrainingModel +import numpy as np + + +class DetectionNetwork(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.gru = nn.GRU( + self.config.hidden_size, + self.config.hidden_size // 2, + num_layers=2, + batch_first=True, + dropout=self.config.hidden_dropout_prob, + bidirectional=True, + ) + self.sigmoid = nn.Sigmoid() + self.linear = nn.Linear(self.config.hidden_size, 1) + + def forward(self, hidden_states): + out, _ = self.gru(hidden_states) + prob = self.linear(out) + prob = self.sigmoid(prob) + return prob + + +class BertCorrectionModel(torch.nn.Module, ModuleUtilsMixin): + def __init__(self, config, tokenizer, device): + super().__init__() + self.config = config + self.tokenizer = tokenizer + self.embeddings = BertEmbeddings(self.config) + self.corrector = BertEncoder(self.config) + self.mask_token_id = self.tokenizer.mask_token_id + self.cls = BertOnlyMLMHead(self.config) + self._device = device + + def forward(self, prob, embed=None, attention_mask=None,residual_connection=False): + + + # 此处较原文有一定改动,做此改动意在完整保留type_ids及position_ids的embedding。 + mask_embed = self.embeddings(torch.ones_like(prob.squeeze(-1)).long() * self.mask_token_id).detach() + # 此处为原文实现 + # mask_embed = self.embeddings(torch.tensor([[self.mask_token_id]], device=self._device)).detach() + cor_embed = prob * mask_embed + (1 - prob) * embed + + input_shape = embed.size() + device = embed.device + + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, + input_shape, device) + head_mask = self.get_head_mask(None, self.config.num_hidden_layers) + # print(f"cor_embed.shape: {cor_embed.shape}") + encoder_outputs = self.corrector( + cor_embed, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + return_dict=False, + ) + sequence_output = encoder_outputs[0] + + sequence_output = sequence_output + embed if residual_connection else sequence_output + prediction_scores = self.cls(sequence_output) + out = (prediction_scores, sequence_output) + + + return out + + def load_from_transformers_state_dict(self, gen_fp): + state_dict = OrderedDict() + gen_state_dict = tfs.AutoModelForMaskedLM.from_pretrained(gen_fp).state_dict() + for k, v in gen_state_dict.items(): + name = k + if name.startswith('bert'): + name = name[5:] + if name.startswith('encoder'): + name = f'corrector.{name[8:]}' + if 'gamma' in name: + name = name.replace('gamma', 'weight') + if 'beta' in name: + name = name.replace('beta', 'bias') + state_dict[name] = v + self.load_state_dict(state_dict, strict=False) + + +class SoftMaskedBertModel(CscTrainingModel): + def __init__(self, cfg, tokenizer): + super().__init__(cfg) + self.cfg = cfg + self.config = tfs.AutoConfig.from_pretrained(cfg.MODEL.BERT_CKPT) + self.detector = DetectionNetwork(self.config) + self.tokenizer = tokenizer + self.corrector = BertCorrectionModel(self.config, tokenizer, cfg.MODEL.DEVICE) + self.corrector.load_from_transformers_state_dict(self.cfg.MODEL.BERT_CKPT) + self._device = cfg.MODEL.DEVICE + + def forward(self, input_ids, attention_mask, token_type_ids): + # print(f"texts: {len(texts)}") + # [print((x,len(x))) for x in texts] + # encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt', truncation=True,) + # encoded_texts.to(self._device) + # print(f"encoded_texts['input_ids'].shape: {encoded_texts['input_ids'].shape}") + embed = self.corrector.embeddings(input_ids=input_ids, + token_type_ids=token_type_ids, + ) + # print(f"embed.shape: {embed.shape}") + prob = self.detector(embed) + # print() + # print(f"prob.shape: {prob.shape}") + # print(f"cor_labels.shape: {(len(cor_labels),len(cor_labels[0]),'...')}") + # [print((x,len(x))) for x in cor_labels] + # cor_out = self.corrector(texts, prob, embed, cor_labels, + # residual_connection=True) + cor_out = self.corrector(prob, embed,attention_mask, + residual_connection=True) + + + outputs = (prob.squeeze(-1),) + cor_out + + return outputs + + def load_from_transformers_state_dict(self, gen_fp): + """ + 从transformers加载预训练权重 + :param gen_fp: + :return: + """ + self.corrector.load_from_transformers_state_dict(gen_fp) diff --git a/tools/SoftMaskedBert_to_onnx.py b/tools/SoftMaskedBert_to_onnx.py new file mode 100644 index 0000000..589da47 --- /dev/null +++ b/tools/SoftMaskedBert_to_onnx.py @@ -0,0 +1,91 @@ +""" +@Time : 2025-03-24 20:24:00 +@File : SoftMaskedBert_to_onnx.py +@Author : Zhang Chen +@Email : zhangchen.shaanxi{at}gmail.com +""" + +from transformers import BertModel, BertTokenizer +from tools.inference import load_model_directly_onnx +import torch +from pathlib import Path +import onnxruntime as ort +import numpy as np + +if __name__ == '__main__': + ckpt_file = './checkpoints/SoftMaskedBert/epoch=09-val_loss=0.03032.ckpt' + config_file = 'csc/train_SoftMaskedBert.yml' + + tokenizer: BertTokenizer = BertTokenizer.from_pretrained('./bert-base-chinese') + onnx_model = load_model_directly_onnx(ckpt_file, config_file) + + print(f"onnx_model.device: {onnx_model.device}") + + inputs = tokenizer( + ["你好帮我生成一个荷花店的课件好吗行吧巴拉巴拉小魔哈哈哈哈红红火火恍恍惚惚"], + padding=True, + return_tensors="pt", + truncation = True, + ) + device = onnx_model.device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + token_type_ids = inputs["token_type_ids"].to(device) + + # 定义动态轴(处理可变 batch size 和序列长度) + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "token_type_ids": {0: "batch_size", 1: "sequence_length"}, + "prob": {0: "batch_size", 1: "sequence_length"}, + "output": {0: "batch_size", 1: "sequence_length"}, + "sequence_output": {0: "batch_size", 1: "sequence_length"} + } + + result_file = f"{Path(ckpt_file).parent}/{Path(ckpt_file).name[:-5]}.onnx" + # 导出为 ONNX + torch.onnx.export( + onnx_model, + (input_ids, attention_mask, token_type_ids), + result_file, + input_names=["input_ids", "attention_mask", "token_type_ids"], + output_names=["prob","output","sequence_output"], + dynamic_axes=dynamic_axes, + opset_version=12, # 建议使用 11 或更高版本 + ) + print(f"result_file: {result_file}") + + # 加载 ONNX 模型 + ort_session = ort.InferenceSession(result_file) + + # 准备输入数据 + inputs_onnx = { + "input_ids": input_ids.cpu().numpy(), + "attention_mask": attention_mask.cpu().numpy(), + "token_type_ids": token_type_ids.cpu().numpy() + } + + # 运行推理 + outputs = ort_session.run(None, inputs_onnx) + + # 比较原始模型和 ONNX 模型的输出 + with torch.no_grad(): + original_outputs = onnx_model(input_ids, attention_mask, token_type_ids) + + print(f"original_outputs[1].shape: {original_outputs[1].shape}") + print(f"len(outputs): {len(outputs)}") + print(f"outputs[1].shape: {outputs[1].shape}") + # 检查输出是否一致 + atol = 1e-5 + ref_value =original_outputs[1].cpu().numpy() + ort_value=outputs[1] + all_close = np.allclose(ref_value, + ort_value, + atol=atol) + if not all_close: + max_diff = np.amax(np.abs(ref_value - ort_value)) + # print(ref_value) + # print(ort_value) + print(f"\t\t-[x] values not close enough, max diff: {max_diff} (atol: {atol})") + else: + print(f"\t\t-[✓] all values close (atol: {atol})") \ No newline at end of file diff --git a/tools/inference.py b/tools/inference.py index 2f4be1a..2e3003c 100644 --- a/tools/inference.py +++ b/tools/inference.py @@ -54,6 +54,31 @@ def load_model_directly(ckpt_file, config_file): model.to(cfg.MODEL.DEVICE) return model +def load_model_directly_onnx(ckpt_file, config_file): + # Example: + # ckpt_fn = 'SoftMaskedBert/epoch=02-val_loss=0.02904.ckpt' (find in checkpoints) + # config_file = 'csc/train_SoftMaskedBert.yml' (find in configs) + + from bbcm.config import cfg + # cp = get_abs_path('checkpoints', ckpt_file) + cp = ckpt_file + cfg.merge_from_file(get_abs_path('configs', config_file)) + tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT, + model_max_length=512) + + if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']: + model = BertForCsc.load_from_checkpoint(cp, + cfg=cfg, + tokenizer=tokenizer) + else: + from bbcm.modeling.csc.modeling_soft_masked_bert_onnx import SoftMaskedBertModel + model = SoftMaskedBertModel.load_from_checkpoint(cp, + # strict=False, + cfg=cfg, + tokenizer=tokenizer) + model.eval() + model.to(cfg.MODEL.DEVICE) + return model def load_model(args): from bbcm.config import cfg