Commit 59340c6

fix: check for rank of bias in bias-gelu fusion🐛 (#2393)
Follow-up to #2364. I noticed that the current `BiasGeluFusion` implementation from #2364 does not check the rank of the bias term, which can lead to runtime errors: the bias input of `BiasGelu(...)` is expected to be 1D (see [here](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftbiasgelu)).

**Minimal, complete example**

With the current implementation:

```sh
uv pip install git+https://github.com/microsoft/onnxscript.git --force-reinstall
```

```python
import os

import numpy as np
import onnx_ir as ir
import torch
from onnxscript.rewriter.ort_fusions._core import fuse_xformers
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import onnxruntime as ort

os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_name = "hf-internal-testing/tiny-random-bart"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()


class EncoderWrapper(torch.nn.Module):
    """A wrapper around the BART encoder for onnx export."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]


model = EncoderWrapper(encoder=model.model.encoder)
print(model)

text = "God bless the internet."
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

input_names = ["input_ids"]
output_names = ["encoder_output"]
onnx_path = "bart_encoder.onnx"

torch.onnx.export(
    model,
    (input_ids,),
    onnx_path,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "encoder_output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=20,
)

onnx_model = ir.load(onnx_path)
onnx_model, stats = fuse_xformers(onnx_model)
print(stats)

optimized_path = "optimized_model.onnx"
ir.save(onnx_model, optimized_path)

sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
encoder_outs_original = sess.run(["encoder_output"], {"input_ids": input_ids.numpy()})

sess_optimized = ort.InferenceSession(optimized_path, providers=["CPUExecutionProvider"])
encoder_outs_optimized = sess_optimized.run(["encoder_output"], {"input_ids": input_ids.numpy()})

abs_diff = np.amax(np.abs(encoder_outs_original[0] - encoder_outs_optimized[0]))
print("abs_difference", abs_diff)
```

```
Applied 1 of general pattern rewrite rules.
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'partial_rotary_embedding': 0, 'cos_sin_cache': 0, 'sdpa': 0, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
2025-06-15 20:52:33.994324 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/activation_fn/Gelu_output_0' source:{4} target:{-1,-1,4}. Falling back to lenient merge.
2025-06-15 20:52:33.994582 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/activation_fn/Gelu_output_0' source:{4} target:{-1,-1,4}. Falling back to lenient merge.
2025-06-15 20:52:34.007963 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/fc2/MatMul_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008178 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/fc2/MatMul_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008753 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/fc2/Add_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008944 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/fc2/Add_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.018753 [E:onnxruntime:, sequential_executor.cc:572 ExecuteKernel] Non-zero status code returned while running BiasGelu node. Name:'node_BiasGelu_26' Status Message: Input 1 is expected to have 1 dimensions, got 3
...
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running BiasGelu node. Name:'node_BiasGelu_26' Status Message: Input 1 is expected to have 1 dimensions, got 3
```

With this fix:

```sh
uv pip install git+https://github.com/karelze/onnxscript.git@fix-bias-gelu-shape --force-reinstall
```

```
Applied 1 of general pattern rewrite rules.
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'partial_rotary_embedding': 0, 'cos_sin_cache': 0, 'sdpa': 0, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
abs_difference 0.0
```

This PR adds:

- an additional check on the rank of the bias input
- additional test cases

Sorry for the inconvenience. @justinchuby @titaiwangms
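For context, the condition the fix enforces is that the `bias` value matched by the pattern has a known rank of exactly 1. Below is a minimal sketch of such a rank check against the `onnx_ir` API; it is a hypothetical stand-in for `onnxscript.rewriter._ir_utils.has_rank`, whose actual implementation may differ.

```python
import onnx_ir as ir


def has_rank(value: ir.Value | None, rank: int) -> bool:
    """Hypothetical stand-in for _ir_utils.has_rank, shown for illustration only."""
    if value is None:
        return False
    shape = value.shape  # an ir.Shape, or None when no shape information is available
    return shape is not None and len(shape) == rank
```

With this kind of check in `BiasGeluFusion.check`, a `Gelu(Add(x, y))` subgraph is only rewritten to `com.microsoft::BiasGelu` when the bias operand is a 1D tensor, matching the contract documented for the contrib op.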
1 parent b76e1b3 commit 59340c6

File tree

2 files changed (+34, -15 lines)

onnxscript/rewriter/ort_fusions/bias_gelu.py

Lines changed: 13 additions & 8 deletions
```diff
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from onnxscript.rewriter import _fusion_utils, pattern
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
 
 
 class BiasGeluFusion(pattern.RewriteRuleClassBase):
@@ -22,30 +22,35 @@ def __init__(
         super().__init__(name)
         self._contrib_op = contrib_op
 
-    def pattern(self, op, x, y):
-        gelu_add = op.Add(x, y)
+    def pattern(self, op, input, bias):
+        gelu_add = op.Add(input, bias)
+
         if self._contrib_op:
             return op.Gelu(gelu_add, _domain="com.microsoft", _outputs=["gelu"])
         else:
             return op.Gelu(gelu_add, _outputs=["gelu"])
 
-    def check(self, op, gelu, **_) -> pattern.MatchResult:
+    def check(self, op, gelu, input, bias, **_) -> pattern.MatchResult:
         check_result = pattern.MatchResult()
         approximate = gelu.producer().attributes.get_string("approximate")
         if approximate is not None and approximate == "tanh":
             return check_result.fail(
                 "Gelu operator with 'approximate' set to 'tanh' is not supported."
             )
+
+        if not _ir_utils.has_rank(bias, 1):
+            return check_result.fail("bias is not of shape 1D tensor", bias)
+
         return check_result
 
-    def rewrite(self, op, x, y, **_):
-        return op.BiasGelu(x, y, _domain="com.microsoft")
+    def rewrite(self, op, input, bias, **_):
+        return op.BiasGelu(input, bias, _domain="com.microsoft")
 
 
 bias_gelu_rules = pattern.RewriteRuleSet(
     [
-        BiasGeluFusion.rule("gelu_onnx_op", contrib_op=False),
-        BiasGeluFusion.rule("gelu_contrib_op", contrib_op=True),
+        *BiasGeluFusion.rule("gelu_onnx_op", contrib_op=False).commute(),
+        *BiasGeluFusion.rule("gelu_contrib_op", contrib_op=True).commute(),
     ]
 )
```
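As a usage sketch, the updated rule set can also be applied on its own, without going through `fuse_xformers`. This assumes `RewriteRuleSet.apply_to_model` returns the number of applied rewrites; treat it as an illustration rather than the canonical API.

```python
import onnx_ir as ir

from onnxscript.rewriter.ort_fusions.bias_gelu import bias_gelu_rules

# Load the exported encoder and apply only the bias-gelu fusion rules.
model = ir.load("bart_encoder.onnx")
count = bias_gelu_rules.apply_to_model(model)  # assumed to return the rewrite count
print("bias-gelu rewrites applied:", count)
ir.save(model, "bart_encoder_bias_gelu.onnx")
```

Because the rules are now registered with `.commute()`, both `Add(input, bias)` and `Add(bias, input)` orderings are matched, which the new `reversed_order` test below exercises.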

onnxscript/rewriter/ort_fusions/bias_gelu_test.py

Lines changed: 21 additions & 7 deletions
```diff
@@ -18,27 +18,39 @@
 
 
 @script()
-def _test_script_onnx_default(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_onnx_default(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
     gelu_add = op.Add(x, y)
     return op.Gelu(gelu_add)
 
 
 @script()
-def _test_script_onnx_none(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_onnx_none(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
     gelu_add = op.Add(x, y)
     return op.Gelu(gelu_add, approximate="none")
 
 
 @script()
-def _test_script_onnx_unsupported(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_msft_op(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
     gelu_add = op.Add(x, y)
-    return op.Gelu(gelu_add, approximate="tanh")
+    return msft_op.Gelu(gelu_add)
+
+
+@script()
+def _test_script_reversed_order(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
+    gelu_add = op.Add(y, x)
+    return op.Gelu(gelu_add)
 
 
 @script()
-def _test_script_msft_op(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_onnx_unsupported(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
     gelu_add = op.Add(x, y)
-    return msft_op.Gelu(gelu_add)
+    return op.Gelu(gelu_add, approximate="tanh")
+
+
+@script()
+def _test_script_shape_unsupported(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
+    gelu_add = op.Add(x, x)
+    return op.Gelu(gelu_add)
 
 
 class BiasGeluFusionTest(unittest.TestCase):
@@ -54,7 +66,7 @@ def _check(
         optimize(model)
 
         input = {
-            "x": np.random.randn(10).astype(np.float32),
+            "x": np.random.randn(10, 10).astype(np.float32),
            "y": np.random.randn(10).astype(np.float32),
         }
         original_output = test_utils.ort_run("Original", model, input)
@@ -73,6 +85,7 @@ def _check(
             ("with_onnx_op_default", _test_script_onnx_default, 1, "BiasGelu"),
             ("with_onnx_op_none", _test_script_onnx_none, 1, "BiasGelu"),
             ("with_contrib_op", _test_script_msft_op, 1, "BiasGelu"),
+            ("reversed_order", _test_script_reversed_order, 1, "BiasGelu"),
         ]
     )
     def test_bias_gelu_fusion(
@@ -87,6 +100,7 @@ def test_bias_gelu_fusion(
     @parameterized.parameterized.expand(
         [
             ("approximate_tanh", _test_script_onnx_unsupported, 2, "Add"),
+            ("unsupported_shape", _test_script_shape_unsupported, 2, "Add"),
         ]
     )
     def test_bias_gelu_fusion_unsupported_attr(
```
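To see the new shape guard outside the test harness, here is a rough sketch that builds a `Gelu(Add(x, b))` graph whose second `Add` operand is 2D and applies the rule set; with the rank check in place it should report zero rewrites. `ir.from_proto` and `apply_to_model` are assumed APIs here, used purely for illustration.

```python
import onnx_ir as ir
from onnx import TensorProto, helper

from onnxscript.rewriter.ort_fusions.bias_gelu import bias_gelu_rules

# Gelu(Add(x, b)) where b is 2D: BiasGelu requires a 1D bias, so no rewrite should fire.
graph = helper.make_graph(
    nodes=[
        helper.make_node("Add", ["x", "b"], ["added"]),
        helper.make_node("Gelu", ["added"], ["out"]),
    ],
    name="gelu_add_2d_bias",
    inputs=[
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [10, 10]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [10, 10]),  # not 1D
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [10, 10])],
)
model_proto = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
model = ir.from_proto(model_proto)

applied = bias_gelu_rules.apply_to_model(model)
print("rewrites applied:", applied)  # expected: 0, because neither Add operand is 1D
```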
