export SoftMaskedBert to ONNX #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: master
147 changes: 147 additions & 0 deletions bbcm/modeling/csc/modeling_soft_masked_bert_onnx.py
@@ -0,0 +1,147 @@
"""
@Time : 2025-03-24 20:24:00
@File : modeling_soft_masked_bert_onnx.py
@Author : Abtion, Zhang Chen
@Email : abtion{at}outlook.com, zhangchen.shaanxi{at}gmail.com
"""
from collections import OrderedDict
import transformers as tfs
import torch
from torch import nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertOnlyMLMHead
from transformers.modeling_utils import ModuleUtilsMixin
from bbcm.engine.csc_trainer import CscTrainingModel


class DetectionNetwork(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.gru = nn.GRU(
self.config.hidden_size,
self.config.hidden_size // 2,
num_layers=2,
batch_first=True,
dropout=self.config.hidden_dropout_prob,
bidirectional=True,
)
self.sigmoid = nn.Sigmoid()
self.linear = nn.Linear(self.config.hidden_size, 1)

    def forward(self, hidden_states):
        # The bidirectional GRU preserves the hidden size (hidden_size // 2 per direction).
        out, _ = self.gru(hidden_states)
        # Per-token error probability, shape (batch, seq_len, 1).
        prob = self.linear(out)
        prob = self.sigmoid(prob)
        return prob


class BertCorrectionModel(torch.nn.Module, ModuleUtilsMixin):
def __init__(self, config, tokenizer, device):
super().__init__()
self.config = config
self.tokenizer = tokenizer
self.embeddings = BertEmbeddings(self.config)
self.corrector = BertEncoder(self.config)
self.mask_token_id = self.tokenizer.mask_token_id
self.cls = BertOnlyMLMHead(self.config)
self._device = device

    def forward(self, prob, embed=None, attention_mask=None, residual_connection=False):
        # This deviates slightly from the original implementation: the change
        # keeps the full token_type and position embeddings for the [MASK] embedding.
        mask_embed = self.embeddings(torch.ones_like(prob.squeeze(-1)).long() * self.mask_token_id).detach()
        # Original implementation:
        # mask_embed = self.embeddings(torch.tensor([[self.mask_token_id]], device=self._device)).detach()
        # Soft masking: interpolate between the [MASK] embedding and the input
        # embedding according to the detector's error probability.
        cor_embed = prob * mask_embed + (1 - prob) * embed

input_shape = embed.size()
device = embed.device

        # get_extended_attention_mask broadcasts (batch, seq_len) to
        # (batch, 1, 1, seq_len) with large negative values at padded positions.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask,
                                                                                 input_shape, device)
        head_mask = self.get_head_mask(None, self.config.num_hidden_layers)
encoder_outputs = self.corrector(
cor_embed,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
return_dict=False,
)
        sequence_output = encoder_outputs[0]

        sequence_output = sequence_output + embed if residual_connection else sequence_output
        prediction_scores = self.cls(sequence_output)
        return prediction_scores, sequence_output

    def load_from_transformers_state_dict(self, gen_fp):
        # Remap HuggingFace BERT parameter names onto this model's modules.
        state_dict = OrderedDict()
        gen_state_dict = tfs.AutoModelForMaskedLM.from_pretrained(gen_fp).state_dict()
        for k, v in gen_state_dict.items():
            name = k
            if name.startswith('bert'):
                name = name[5:]  # strip the 'bert.' prefix
            if name.startswith('encoder'):
                name = f'corrector.{name[8:]}'  # 'encoder.*' lives under self.corrector
            # Older checkpoints name LayerNorm parameters gamma/beta.
            if 'gamma' in name:
                name = name.replace('gamma', 'weight')
            if 'beta' in name:
                name = name.replace('beta', 'bias')
            state_dict[name] = v
        self.load_state_dict(state_dict, strict=False)


class SoftMaskedBertModel(CscTrainingModel):
def __init__(self, cfg, tokenizer):
super().__init__(cfg)
self.cfg = cfg
self.config = tfs.AutoConfig.from_pretrained(cfg.MODEL.BERT_CKPT)
self.detector = DetectionNetwork(self.config)
self.tokenizer = tokenizer
self.corrector = BertCorrectionModel(self.config, tokenizer, cfg.MODEL.DEVICE)
self.corrector.load_from_transformers_state_dict(self.cfg.MODEL.BERT_CKPT)
self._device = cfg.MODEL.DEVICE

    def forward(self, input_ids, attention_mask, token_type_ids):
        # For ONNX export the model consumes pre-tokenized tensors rather than
        # raw texts, so tracing sees only tensor operations.
        embed = self.corrector.embeddings(input_ids=input_ids,
                                          token_type_ids=token_type_ids)
        prob = self.detector(embed)
        cor_out = self.corrector(prob, embed, attention_mask,
                                 residual_connection=True)
        outputs = (prob.squeeze(-1),) + cor_out
        return outputs

    def load_from_transformers_state_dict(self, gen_fp):
        """
        Load pretrained weights from a transformers checkpoint.
        :param gen_fp: path or name of the pretrained model
        :return:
        """
        self.corrector.load_from_transformers_state_dict(gen_fp)
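
The heart of this file is the soft-masking step in `BertCorrectionModel.forward`: each token embedding is interpolated with the [MASK] embedding in proportion to the detector's error probability. A minimal sketch of that blend with toy tensors (shapes and values are illustrative only, not taken from the PR):

```python
import torch

batch, seq_len, hidden = 1, 4, 8
embed = torch.randn(batch, seq_len, hidden)       # input token embeddings
mask_embed = torch.randn(batch, seq_len, hidden)  # [MASK] token embeddings
# Detector output, shape (batch, seq_len, 1): one error probability per token.
prob = torch.tensor([[[0.0], [0.9], [0.1], [0.0]]])

# Soft masking: tokens likely to be wrong are pushed toward [MASK].
cor_embed = prob * mask_embed + (1 - prob) * embed

# prob = 0.0 keeps the original embedding; prob = 0.9 is mostly [MASK].
assert torch.allclose(cor_embed[0, 0], embed[0, 0])
```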
91 changes: 91 additions & 0 deletions tools/SoftMaskedBert_to_onnx.py
@@ -0,0 +1,91 @@
"""
@Time : 2025-03-24 20:24:00
@File : SoftMaskedBert_to_onnx.py
@Author : Zhang Chen
@Email : zhangchen.shaanxi{at}gmail.com
"""

from transformers import BertTokenizer
from tools.inference import load_model_directly_onnx
import torch
from pathlib import Path
import onnxruntime as ort
import numpy as np

if __name__ == '__main__':
ckpt_file = './checkpoints/SoftMaskedBert/epoch=09-val_loss=0.03032.ckpt'
config_file = 'csc/train_SoftMaskedBert.yml'

tokenizer: BertTokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')
onnx_model = load_model_directly_onnx(ckpt_file, config_file)

print(f"onnx_model.device: {onnx_model.device}")

inputs = tokenizer(
["你好帮我生成一个荷花店的课件好吗行吧巴拉巴拉小魔哈哈哈哈红红火火恍恍惚惚"],
padding=True,
return_tensors="pt",
        truncation=True,
)
device = onnx_model.device
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
token_type_ids = inputs["token_type_ids"].to(device)

    # Define dynamic axes (to handle variable batch size and sequence length)
dynamic_axes = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"token_type_ids": {0: "batch_size", 1: "sequence_length"},
"prob": {0: "batch_size", 1: "sequence_length"},
"output": {0: "batch_size", 1: "sequence_length"},
"sequence_output": {0: "batch_size", 1: "sequence_length"}
}

    result_file = str(Path(ckpt_file).with_suffix('.onnx'))
    # Export to ONNX
    torch.onnx.export(
        onnx_model,
        (input_ids, attention_mask, token_type_ids),
        result_file,
        input_names=["input_ids", "attention_mask", "token_type_ids"],
        output_names=["prob", "output", "sequence_output"],
        dynamic_axes=dynamic_axes,
        opset_version=12,  # opset 11 or higher is recommended
    )
print(f"result_file: {result_file}")
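
    # Optional sanity check of the exported graph (a sketch; assumes the
    # `onnx` package is installed, which this script does not otherwise need):
    # import onnx
    # onnx.checker.check_model(onnx.load(result_file))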

    # Load the ONNX model
ort_session = ort.InferenceSession(result_file)

    # Prepare the input feed
inputs_onnx = {
"input_ids": input_ids.cpu().numpy(),
"attention_mask": attention_mask.cpu().numpy(),
"token_type_ids": token_type_ids.cpu().numpy()
}

    # Run inference
outputs = ort_session.run(None, inputs_onnx)

    # Compare the original model's outputs with the ONNX model's
with torch.no_grad():
original_outputs = onnx_model(input_ids, attention_mask, token_type_ids)

print(f"original_outputs[1].shape: {original_outputs[1].shape}")
print(f"len(outputs): {len(outputs)}")
print(f"outputs[1].shape: {outputs[1].shape}")
    # Check that the outputs match within tolerance
    atol = 1e-5
    ref_value = original_outputs[1].cpu().numpy()
    ort_value = outputs[1]
    all_close = np.allclose(ref_value,
                            ort_value,
                            atol=atol)
if not all_close:
        max_diff = np.amax(np.abs(ref_value - ort_value))
print(f"\t\t-[x] values not close enough, max diff: {max_diff} (atol: {atol})")
else:
print(f"\t\t-[✓] all values close (atol: {atol})")
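
Once exported, the model can be served without PyTorch. A minimal standalone sketch, reusing the checkpoint and tokenizer paths from the script above (the example sentence is arbitrary, and the decode step simply takes the argmax over the MLM logits):

```python
import onnxruntime as ort
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("./bert-base-chinese")
session = ort.InferenceSession("./checkpoints/SoftMaskedBert/epoch=09-val_loss=0.03032.onnx")

enc = tokenizer(["今天天气很好"], padding=True, truncation=True, return_tensors="np")
prob, logits, _ = session.run(None, {
    "input_ids": enc["input_ids"],
    "attention_mask": enc["attention_mask"],
    "token_type_ids": enc["token_type_ids"],
})
# Decode the corrected sequence from the MLM logits.
pred_ids = logits.argmax(axis=-1)[0]
print(tokenizer.decode(pred_ids, skip_special_tokens=True))
```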
25 changes: 25 additions & 0 deletions tools/inference.py
@@ -54,6 +54,31 @@ def load_model_directly(ckpt_file, config_file):
model.to(cfg.MODEL.DEVICE)
return model

def load_model_directly_onnx(ckpt_file, config_file):
    """Load a trained checkpoint for ONNX export.

    Example:
        ckpt_file = 'SoftMaskedBert/epoch=02-val_loss=0.02904.ckpt' (find in checkpoints)
        config_file = 'csc/train_SoftMaskedBert.yml' (find in configs)
    """
    from bbcm.config import cfg
    # Unlike load_model_directly, ckpt_file is used as a full path here.
    cp = ckpt_file
    cfg.merge_from_file(get_abs_path('configs', config_file))
tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT,
model_max_length=512)

if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']:
model = BertForCsc.load_from_checkpoint(cp,
cfg=cfg,
tokenizer=tokenizer)
else:
from bbcm.modeling.csc.modeling_soft_masked_bert_onnx import SoftMaskedBertModel
model = SoftMaskedBertModel.load_from_checkpoint(cp,
# strict=False,
cfg=cfg,
tokenizer=tokenizer)
model.eval()
model.to(cfg.MODEL.DEVICE)
return model

def load_model(args):
from bbcm.config import cfg
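
Because the export declares dynamic axes, the resulting session should accept batch sizes and sequence lengths other than the ones used for tracing. A quick check of that property (a sketch, not part of the PR; the .onnx path follows the export script above):

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("./checkpoints/SoftMaskedBert/epoch=09-val_loss=0.03032.onnx")

# Shapes deliberately different from the tracing example.
batch, seq_len = 2, 16
feed = {
    "input_ids": np.random.randint(1, 1000, size=(batch, seq_len), dtype=np.int64),
    "attention_mask": np.ones((batch, seq_len), dtype=np.int64),
    "token_type_ids": np.zeros((batch, seq_len), dtype=np.int64),
}
prob, logits, seq_out = session.run(None, feed)
assert prob.shape == (batch, seq_len)        # per-token error probabilities
assert logits.shape[:2] == (batch, seq_len)  # MLM scores over the vocabulary
```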