
fix: pattern match gelu from contrib and onnx ops #2364

Merged · 8 commits · Jun 14, 2025
Changes from 1 commit
onnxscript/rewriter/ort_fusions/bias_gelu.py (24 changes: 20 additions & 4 deletions)
@@ -6,17 +6,33 @@


 class BiasGeluFusion(pattern.RewriteRuleClassBase):
+    def __init__(
+        self,
+        name,
+        *,
+        has_contrib_op: bool,
+    ):
+        super().__init__(name)
+        self._has_contrib_op = has_contrib_op
+
     def pattern(self, op, x, y):
         gelu_add = op.Add(x, y)
-        return op.Gelu(gelu_add, _domain="com.microsoft")
+        # see: gh-2362. Match against Gelu from onnx op or contrib ops.
+        if self._has_contrib_op:
+            return op.Gelu(gelu_add, _domain="com.microsoft")
+        else:
+            return op.Gelu(gelu_add)

     def rewrite(self, op, x, y):
         return op.BiasGelu(x, y, _domain="com.microsoft")


-_rule = BiasGeluFusion.rule()
-
-bias_gelu_rules = pattern.RewriteRuleSet([_rule])
+bias_gelu_rules = pattern.RewriteRuleSet(
+    [
+        BiasGeluFusion.rule("gelu_onnx_op", has_contrib_op=False),
+        BiasGeluFusion.rule("gelu_contrib_op", has_contrib_op=True),
+    ]
+)


 fuse_bias_gelu = _fusion_utils.apply_fusion_rules(bias_gelu_rules)
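
For context, the exported fuse_bias_gelu entry point applies this rule set to an in-memory IR model. A minimal usage sketch, assuming the public ir.serde round-trip helpers and a hypothetical model path (not part of this PR):

    import onnx
    import onnxscript.ir as ir
    from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu

    # Deserialize the ONNX protobuf into the IR form the rewriter operates on.
    model = ir.serde.deserialize_model(onnx.load("model.onnx"))  # hypothetical path

    # With this PR, matching Add -> Gelu pairs are fused into a single
    # com.microsoft BiasGelu node whether Gelu came from the standard onnx
    # domain or from the com.microsoft contrib domain.
    fuse_bias_gelu(model)

    onnx.save(ir.serde.serialize_model(model), "model_fused.onnx")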
onnxscript/rewriter/ort_fusions/bias_gelu_test.py (30 changes: 23 additions & 7 deletions)
@@ -4,25 +4,41 @@
 import unittest

 import numpy as np
+import parameterized

 import onnxscript
 import onnxscript.ir as ir
 import onnxscript.rewriter.ort_fusions._test_utils as test_utils
 from onnxscript import FLOAT, script
-from onnxscript import opset18 as op
+from onnxscript import opset20 as op
 from onnxscript.optimizer import optimize, remove_unused_nodes
 from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu

 msft_op = onnxscript.values.Opset("com.microsoft", 1)


 class BiasGeluFusionTest(unittest.TestCase):
-    def test_bias_gelu_fusion(self):
-        @script()
-        def bias_gelu_model(x, y):
-            gelu_add = op.Add(x, y)
-            gelu = msft_op.Gelu(gelu_add)
-            return gelu
+    @parameterized.parameterized.expand(
+        [
+            ("with_onnx_op", False),
+            ("with_contrib_op", True),
+        ]
+    )
+    def test_bias_gelu_fusion(self, _: str, has_contrib_op: bool):
+        if has_contrib_op:
+
+            @script()
+            def bias_gelu_model(x, y):
+                gelu_add = op.Add(x, y)
+                gelu = msft_op.Gelu(gelu_add)
+                return gelu
+
+        else:
+
+            @script()
+            def bias_gelu_model(x, y):
+                gelu_add = op.Add(x, y)
+                gelu = op.Gelu(gelu_add)
+                return gelu

         model_proto = bias_gelu_model.to_model_proto(
             input_types=[FLOAT[10], FLOAT[10]],

Codecov / codecov/patch: added lines #L32-L34 and #L39-L41 (the two @script() model bodies) were not covered by tests.
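
Note the import bump from opset18 to opset20: the standard ONNX Gelu operator only exists from opset 20 onward, which the new onnx-op test path requires. For reference, a standalone sketch of that path might look like the following; the output_types argument and the post-fusion check are assumptions here, since the remainder of the real test is collapsed above:

    import onnxscript.ir as ir
    from onnxscript import FLOAT, script
    from onnxscript import opset20 as op
    from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu

    @script()
    def bias_gelu_model(x, y):
        gelu_add = op.Add(x, y)
        return op.Gelu(gelu_add)  # standard onnx-domain Gelu, opset 20+

    model = ir.serde.deserialize_model(
        bias_gelu_model.to_model_proto(
            input_types=[FLOAT[10], FLOAT[10]],
            output_types=[FLOAT[10]],  # assumed; the real test's signature is collapsed
        )
    )
    fuse_bias_gelu(model)

    # Expect the Add/Gelu pair to have been replaced by one BiasGelu node.
    print([node.op_type for node in model.graph])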