Skip to content

Commit 3475eb5

Browse files
committed
fix: pattern match gelu from contrib and onnx ops
1 parent 4e526f7 commit 3475eb5

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

onnxscript/rewriter/ort_fusions/bias_gelu.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,33 @@
66

77

88
class BiasGeluFusion(pattern.RewriteRuleClassBase):
9+
def __init__(
10+
self,
11+
name,
12+
*,
13+
has_contrib_op: bool,
14+
):
15+
super().__init__(name)
16+
self._has_contrib_op = has_contrib_op
17+
918
def pattern(self, op, x, y):
1019
gelu_add = op.Add(x, y)
11-
return op.Gelu(gelu_add, _domain="com.microsoft")
20+
# see: gh-2362. Match against Gelu from onnx op or contrib ops.
21+
if self._has_contrib_op:
22+
return op.Gelu(gelu_add, _domain="com.microsoft")
23+
else:
24+
return op.Gelu(gelu_add)
1225

1326
def rewrite(self, op, x, y):
1427
return op.BiasGelu(x, y, _domain="com.microsoft")
1528

1629

17-
_rule = BiasGeluFusion.rule()
18-
19-
bias_gelu_rules = pattern.RewriteRuleSet([_rule])
30+
bias_gelu_rules = pattern.RewriteRuleSet(
31+
[
32+
BiasGeluFusion.rule("gelu_onnx_op", has_contrib_op=False),
33+
BiasGeluFusion.rule("gelu_contrib_op", has_contrib_op=True),
34+
]
35+
)
2036

2137

2238
fuse_bias_gelu = _fusion_utils.apply_fusion_rules(bias_gelu_rules)

onnxscript/rewriter/ort_fusions/bias_gelu_test.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,41 @@
44
import unittest
55

66
import numpy as np
7+
import parameterized
78

89
import onnxscript
910
import onnxscript.ir as ir
1011
import onnxscript.rewriter.ort_fusions._test_utils as test_utils
1112
from onnxscript import FLOAT, script
12-
from onnxscript import opset18 as op
13+
from onnxscript import opset20 as op
1314
from onnxscript.optimizer import optimize, remove_unused_nodes
1415
from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu
1516

1617
msft_op = onnxscript.values.Opset("com.microsoft", 1)
1718

1819

1920
class BiasGeluFusionTest(unittest.TestCase):
20-
def test_bias_gelu_fusion(self):
21-
@script()
22-
def bias_gelu_model(x, y):
23-
gelu_add = op.Add(x, y)
24-
gelu = msft_op.Gelu(gelu_add)
25-
return gelu
21+
@parameterized.parameterized.expand(
22+
[
23+
("with_onnx_op", False),
24+
("with_contrib_op", True),
25+
]
26+
)
27+
def test_bias_gelu_fusion(self, _: str, has_contrib_op: bool):
28+
if has_contrib_op:
29+
30+
@script()
31+
def bias_gelu_model(x, y):
32+
gelu_add = op.Add(x, y)
33+
gelu = msft_op.Gelu(gelu_add)
34+
return gelu
35+
else:
36+
37+
@script()
38+
def bias_gelu_model(x, y):
39+
gelu_add = op.Add(x, y)
40+
gelu = op.Gelu(gelu_add)
41+
return gelu
2642

2743
model_proto = bias_gelu_model.to_model_proto(
2844
input_types=[FLOAT[10], FLOAT[10]],

0 commit comments

Comments
 (0)