
fix: pattern match gelu from contrib and onnx ops 🐛 #2364


Merged: 8 commits, Jun 14, 2025
43 changes: 37 additions & 6 deletions onnxscript/rewriter/ort_fusions/bias_gelu.py
@@ -6,17 +6,48 @@


class BiasGeluFusion(pattern.RewriteRuleClassBase):
    """Fuses a Bias-Gelu pattern into a single BiasGelu operator.

    Attributes:
        contrib_op (bool): If True, matches the Gelu operator from the 'com.microsoft' domain.
            If False, matches the standard ONNX Gelu operator.
    """

    def __init__(
        self,
        name: str,
        *,
        contrib_op: bool,
    ):
        super().__init__(name)
        self._contrib_op = contrib_op

    def pattern(self, op, x, y):
        gelu_add = op.Add(x, y)
        if self._contrib_op:
            return op.Gelu(gelu_add, _domain="com.microsoft", _outputs=["gelu"])
        else:
            return op.Gelu(gelu_add, _outputs=["gelu"])

    def check(self, op, gelu, **_) -> pattern.MatchResult:
        check_result = pattern.MatchResult()
        approximate = gelu.producer().attributes.get_string("approximate")
        if approximate is not None and approximate == "tanh":
            return check_result.fail(
                "Gelu operator with 'approximate' set to 'tanh' is not supported."
            )
        return check_result

    def rewrite(self, op, x, y, **_):
        return op.BiasGelu(x, y, _domain="com.microsoft")


bias_gelu_rules = pattern.RewriteRuleSet(
    [
        BiasGeluFusion.rule("gelu_onnx_op", contrib_op=False),
        BiasGeluFusion.rule("gelu_contrib_op", contrib_op=True),
    ]
)


fuse_bias_gelu = _fusion_utils.apply_fusion_rules(bias_gelu_rules)
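
For orientation, here is a minimal sketch of how these rules might be applied end to end, mirroring the flow exercised in the test file below. The input path "model.onnx" and the final serialization step are illustrative assumptions, not part of this PR.

# Sketch: run the BiasGelu fusion over a model (assumed usage;
# "model.onnx" is a hypothetical path).
import onnx

import onnxscript.ir as ir
from onnxscript.optimizer import optimize, remove_unused_nodes
from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu

model_proto = onnx.load("model.onnx")
model = ir.serde.deserialize_model(model_proto)
optimize(model)              # simplify first so the Add + Gelu pattern is visible
fuse_bias_gelu(model)        # rewrites Add + Gelu into com.microsoft BiasGelu
remove_unused_nodes(model)   # drop nodes orphaned by the rewrite
model_proto = ir.serde.serialize_model(model)
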
83 changes: 67 additions & 16 deletions onnxscript/rewriter/ort_fusions/bias_gelu_test.py
@@ -4,31 +4,52 @@
import unittest

import numpy as np
import parameterized

import onnxscript
import onnxscript.ir as ir
import onnxscript.rewriter.ort_fusions._test_utils as test_utils
from onnxscript import FLOAT, OnnxFunction, script
from onnxscript import opset20 as op
from onnxscript.optimizer import optimize, remove_unused_nodes
from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu

msft_op = onnxscript.values.Opset("com.microsoft", 1)


@script()
def _test_script_onnx_default(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    gelu_add = op.Add(x, y)
    return op.Gelu(gelu_add)


@script()
def _test_script_onnx_none(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    gelu_add = op.Add(x, y)
    return op.Gelu(gelu_add, approximate="none")


@script()
def _test_script_onnx_unsupported(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    gelu_add = op.Add(x, y)
    return op.Gelu(gelu_add, approximate="tanh")


@script()
def _test_script_msft_op(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    gelu_add = op.Add(x, y)
    return msft_op.Gelu(gelu_add)


class BiasGeluFusionTest(unittest.TestCase):
    def _check(
        self,
        test_data_constructor: OnnxFunction,
        expected_graph_len: int,
        expected_op_type: str,
    ):
        """Helper method to run a fusion test scenario."""
        model_proto = test_data_constructor.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        optimize(model)

@@ -41,12 +62,42 @@
        fuse_bias_gelu(model)
        remove_unused_nodes(model)

        self.assertEqual(len(model.graph), expected_graph_len)
        self.assertEqual(model.graph.node(0).op_type, expected_op_type)

        optimized_output = test_utils.ort_run("Optimized", model, input)
        test_utils.assert_allclose(original_output, optimized_output)

    @parameterized.parameterized.expand(
        [
            ("with_onnx_op_default", _test_script_onnx_default, 1, "BiasGelu"),
            ("with_onnx_op_none", _test_script_onnx_none, 1, "BiasGelu"),
            ("with_contrib_op", _test_script_msft_op, 1, "BiasGelu"),
        ]
    )
    def test_bias_gelu_fusion(
        self,
        _,
        test_data_constructor: OnnxFunction,
        expected_graph_len: int,
        expected_op_type: str,
    ):
        self._check(test_data_constructor, expected_graph_len, expected_op_type)

    @parameterized.parameterized.expand(
        [
            ("approximate_tanh", _test_script_onnx_unsupported, 2, "Add"),
        ]
    )
    def test_bias_gelu_fusion_unsupported_attr(
        self,
        _,
        test_data_constructor: OnnxFunction,
        expected_graph_len: int,
        expected_op_type: str,
    ):
        self._check(test_data_constructor, expected_graph_len, expected_op_type)


if __name__ == "__main__":
    unittest.main()
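
For reference, the fused contrib operator computes Gelu(x + bias) in its exact erf-based form, gelu(z) = 0.5 * z * (1 + erf(z / sqrt(2))), which is why the check above rejects approximate="tanh": fusing the tanh approximation into BiasGelu would change the numerics. A small NumPy reference, written for this note rather than taken from the PR:

import math

import numpy as np


def bias_gelu_reference(x: np.ndarray, bias: np.ndarray) -> np.ndarray:
    # BiasGelu(x, bias) == Gelu(x + bias) with the exact erf formulation,
    # matching ONNX Gelu's default approximate="none".
    z = x + bias
    erf = np.vectorize(math.erf)
    return 0.5 * z * (1.0 + erf(z / math.sqrt(2.0)))


x = np.random.rand(10).astype(np.float32)
y = np.random.rand(10).astype(np.float32)
print(bias_gelu_reference(x, y))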