Skip to content

Commit ccaefc6

Browse files
KarelZe and justinchuby authored
fix: pattern match gelu from contrib and onnx ops🐛 (#2364)
Previously the domain for Gelu in the [rules implementation](https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/ort_fusions/bias_gelu.py#L11) was restricted to the [contributor ops implementation](https://github.com/microsoft/onnxruntime/blob/rel-1.20.0/docs/ContribOperators.md#com.microsoft.Gelu) and does not fuse Gelu from onnx ops ([introduced with opset 20](https://onnx.ai/onnx/operators/onnx__Gelu.html#l-onnx-doc-gelu)). This pr introduces pattern matching + tests for both variants. closes #2362 . @shubhambhokare1 @justinchuby Could you please review? Any feedback is greatly appreciated. --------- Co-authored-by: Justin Chu <[email protected]>
1 parent 6e6f521 commit ccaefc6

File tree

2 files changed

+104
-22
lines changed

2 files changed

+104
-22
lines changed

onnxscript/rewriter/ort_fusions/bias_gelu.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,48 @@
66

77

88
class BiasGeluFusion(pattern.RewriteRuleClassBase):
    """Fuses a Bias-Gelu pattern into a single BiasGelu operator.

    Matches ``Add(x, y)`` followed by ``Gelu`` and rewrites the pair to the
    fused ``com.microsoft.BiasGelu`` contrib op.

    Attributes:
        contrib_op (bool): If True, matches the Gelu operator from the
            'com.microsoft' domain. If False, matches the standard ONNX Gelu
            operator (available since opset 20).
    """

    def __init__(
        self,
        name: str,
        *,
        contrib_op: bool,
    ):
        super().__init__(name)
        # Selects which Gelu flavor (contrib vs. standard ONNX) this rule
        # instance matches; two instances cover both domains.
        self._contrib_op = contrib_op

    def pattern(self, op, x, y):
        """Pattern: an Add feeding a Gelu in the configured domain."""
        gelu_add = op.Add(x, y)
        if self._contrib_op:
            return op.Gelu(gelu_add, _domain="com.microsoft", _outputs=["gelu"])
        # No else needed: the contrib branch returns above.
        return op.Gelu(gelu_add, _outputs=["gelu"])

    def check(self, op, gelu, **_) -> pattern.MatchResult:
        """Reject matches whose Gelu uses the 'tanh' approximation.

        The fused BiasGelu op does not support ``approximate="tanh"``, so
        such Gelu nodes must be left unfused.
        """
        check_result = pattern.MatchResult()
        approximate = gelu.producer().attributes.get_string("approximate")
        # get_string yields None when the attribute is absent, and
        # None == "tanh" is False, so a single equality test covers both the
        # missing-attribute and explicit-"none" cases.
        if approximate == "tanh":
            return check_result.fail(
                "Gelu operator with 'approximate' set to 'tanh' is not supported."
            )
        return check_result

    def rewrite(self, op, x, y, **_):
        """Replace the matched Add + Gelu with the fused contrib BiasGelu."""
        return op.BiasGelu(x, y, _domain="com.microsoft")
1543

1644

17-
# Rule set covering both Gelu variants: the standard ONNX op and the
# com.microsoft contrib op.
bias_gelu_rules = pattern.RewriteRuleSet(
    [
        BiasGeluFusion.rule("gelu_onnx_op", contrib_op=False),
        BiasGeluFusion.rule("gelu_contrib_op", contrib_op=True),
    ]
)


fuse_bias_gelu = _fusion_utils.apply_fusion_rules(bias_gelu_rules)

onnxscript/rewriter/ort_fusions/bias_gelu_test.py

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,52 @@
44
import unittest
55

66
import numpy as np
7+
import parameterized
78

89
import onnxscript
910
import onnxscript.ir as ir
1011
import onnxscript.rewriter.ort_fusions._test_utils as test_utils
11-
from onnxscript import FLOAT, script
12-
from onnxscript import opset18 as op
12+
from onnxscript import FLOAT, OnnxFunction, script
13+
from onnxscript import opset20 as op
1314
from onnxscript.optimizer import optimize, remove_unused_nodes
1415
from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu
1516

1617
msft_op = onnxscript.values.Opset("com.microsoft", 1)
1718

1819

20+
# Standard ONNX Gelu with default attributes (no 'approximate' set) — fusable.
@script()
def _test_script_onnx_default(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    gelu_add = op.Add(x, y)
    return op.Gelu(gelu_add)
24+
25+
26+
# Standard ONNX Gelu with approximate explicitly set to "none" — fusable.
@script()
def _test_script_onnx_none(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    gelu_add = op.Add(x, y)
    return op.Gelu(gelu_add, approximate="none")
30+
31+
32+
# Gelu with approximate="tanh" is not supported by BiasGelu; fusion must be
# skipped, leaving the Add node in the graph.
@script()
def _test_script_onnx_unsupported(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    gelu_add = op.Add(x, y)
    return op.Gelu(gelu_add, approximate="tanh")
36+
37+
38+
# Gelu from the com.microsoft contrib domain (via the msft_op opset handle
# declared at module level) — fusable.
@script()
def _test_script_msft_op(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    gelu_add = op.Add(x, y)
    return msft_op.Gelu(gelu_add)
42+
43+
1944
class BiasGeluFusionTest(unittest.TestCase):
20-
def test_bias_gelu_fusion(self):
21-
@script()
22-
def bias_gelu_model(x, y):
23-
gelu_add = op.Add(x, y)
24-
gelu = msft_op.Gelu(gelu_add)
25-
return gelu
26-
27-
model_proto = bias_gelu_model.to_model_proto(
28-
input_types=[FLOAT[10], FLOAT[10]],
29-
output_types=[FLOAT[10]],
30-
ir_version=10,
31-
)
45+
def _check(
46+
self,
47+
test_data_constructor: OnnxFunction,
48+
expected_graph_len: int,
49+
expected_op_type: str,
50+
):
51+
"""Helper method to run a fusion test scenario."""
52+
model_proto = test_data_constructor.to_model_proto()
3253
model = ir.serde.deserialize_model(model_proto)
3354
optimize(model)
3455

@@ -41,12 +62,42 @@ def bias_gelu_model(x, y):
4162
fuse_bias_gelu(model)
4263
remove_unused_nodes(model)
4364

44-
self.assertEqual(len(model.graph), 1)
45-
self.assertEqual(model.graph.node(0).op_type, "BiasGelu")
65+
self.assertEqual(len(model.graph), expected_graph_len)
66+
self.assertEqual(model.graph.node(0).op_type, expected_op_type)
4667

4768
optimized_output = test_utils.ort_run("Optimized", model, input)
4869
test_utils.assert_allclose(original_output, optimized_output)
4970

71+
    @parameterized.parameterized.expand(
        [
            ("with_onnx_op_default", _test_script_onnx_default, 1, "BiasGelu"),
            ("with_onnx_op_none", _test_script_onnx_none, 1, "BiasGelu"),
            ("with_contrib_op", _test_script_msft_op, 1, "BiasGelu"),
        ]
    )
    def test_bias_gelu_fusion(
        self,
        _,
        test_data_constructor: OnnxFunction,
        expected_graph_len: int,
        expected_op_type: str,
    ):
        """Supported Gelu variants fuse into a single BiasGelu node."""
        # First tuple element is the test-case name consumed by parameterized.
        self._check(test_data_constructor, expected_graph_len, expected_op_type)
86+
87+
    @parameterized.parameterized.expand(
        [
            ("approximate_tanh", _test_script_onnx_unsupported, 2, "Add"),
        ]
    )
    def test_bias_gelu_fusion_unsupported_attr(
        self,
        _,
        test_data_constructor: OnnxFunction,
        expected_graph_len: int,
        expected_op_type: str,
    ):
        """Gelu with approximate='tanh' is left unfused; the Add node remains."""
        # Expect 2 nodes (Add + Gelu) because the rewrite rule's check rejects
        # the tanh approximation.
        self._check(test_data_constructor, expected_graph_len, expected_op_type)
100+
50101

51102
if __name__ == "__main__":
52103
unittest.main()

0 commit comments

Comments
 (0)