
fix: check for rank of bias in bias-gelu fusion 🐛 #2393

Merged · 2 commits · Jun 16, 2025
21 changes: 13 additions & 8 deletions onnxscript/rewriter/ort_fusions/bias_gelu.py
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from onnxscript.rewriter import _fusion_utils, pattern
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
 
 
 class BiasGeluFusion(pattern.RewriteRuleClassBase):
@@ -22,30 +22,35 @@ def __init__(
         super().__init__(name)
         self._contrib_op = contrib_op
 
-    def pattern(self, op, x, y):
-        gelu_add = op.Add(x, y)
+    def pattern(self, op, input, bias):
+        gelu_add = op.Add(input, bias)
 
         if self._contrib_op:
             return op.Gelu(gelu_add, _domain="com.microsoft", _outputs=["gelu"])
         else:
             return op.Gelu(gelu_add, _outputs=["gelu"])
 
-    def check(self, op, gelu, **_) -> pattern.MatchResult:
+    def check(self, op, gelu, input, bias, **_) -> pattern.MatchResult:
         check_result = pattern.MatchResult()
         approximate = gelu.producer().attributes.get_string("approximate")
         if approximate is not None and approximate == "tanh":
             return check_result.fail(
                 "Gelu operator with 'approximate' set to 'tanh' is not supported."
             )
 
+        if not _ir_utils.has_rank(bias, 1):
+            return check_result.fail("bias is not a 1D tensor", bias)
+
         return check_result
 
-    def rewrite(self, op, x, y, **_):
-        return op.BiasGelu(x, y, _domain="com.microsoft")
+    def rewrite(self, op, input, bias, **_):
+        return op.BiasGelu(input, bias, _domain="com.microsoft")
 
 
 bias_gelu_rules = pattern.RewriteRuleSet(
     [
-        BiasGeluFusion.rule("gelu_onnx_op", contrib_op=False),
-        BiasGeluFusion.rule("gelu_contrib_op", contrib_op=True),
+        *BiasGeluFusion.rule("gelu_onnx_op", contrib_op=False).commute(),
+        *BiasGeluFusion.rule("gelu_contrib_op", contrib_op=True).commute(),
     ]
 )
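The new check rejects a match unless the bias value has rank 1. As a minimal, illustrative NumPy sketch (not part of this PR), this is the broadcasting behaviour the fusion relies on: the com.microsoft BiasGelu contrib op adds a 1-D bias along the last axis before applying Gelu, so only an Add whose second operand broadcasts that way can safely be replaced.

import numpy as np
from math import erf

def gelu(v):
    # exact (non-tanh) Gelu: 0.5 * v * (1 + erf(v / sqrt(2)))
    return 0.5 * v * (1.0 + np.vectorize(erf)(v / np.sqrt(2.0)))

x = np.random.randn(10, 10).astype(np.float32)
bias = np.random.randn(10).astype(np.float32)        # rank 1: broadcasts over the last axis
fused_equivalent = gelu(x + bias)                     # what a fused BiasGelu(x, bias) should compute

other = np.random.randn(10, 10).astype(np.float32)   # rank 2: an ordinary Add operand, not a bias
still_add_then_gelu = gelu(x + other)                 # valid Add+Gelu, but the rule now skips fusing it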
28 changes: 21 additions & 7 deletions onnxscript/rewriter/ort_fusions/bias_gelu_test.py
@@ -18,27 +18,39 @@
 
 
 @script()
-def _test_script_onnx_default(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_onnx_default(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
     gelu_add = op.Add(x, y)
     return op.Gelu(gelu_add)
 
 
 @script()
-def _test_script_onnx_none(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_onnx_none(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
     gelu_add = op.Add(x, y)
     return op.Gelu(gelu_add, approximate="none")
 
 
 @script()
-def _test_script_onnx_unsupported(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_msft_op(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
     gelu_add = op.Add(x, y)
-    return op.Gelu(gelu_add, approximate="tanh")
+    return msft_op.Gelu(gelu_add)
+
+
+@script()
+def _test_script_reversed_order(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
+    gelu_add = op.Add(y, x)
+    return op.Gelu(gelu_add)
 
 
 @script()
-def _test_script_msft_op(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_onnx_unsupported(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
     gelu_add = op.Add(x, y)
-    return msft_op.Gelu(gelu_add)
+    return op.Gelu(gelu_add, approximate="tanh")
+
+
+@script()
+def _test_script_shape_unsupported(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]:
+    gelu_add = op.Add(x, x)
+    return op.Gelu(gelu_add)
 
 
 class BiasGeluFusionTest(unittest.TestCase):
@@ -54,7 +66,7 @@
         optimize(model)
 
         input = {
-            "x": np.random.randn(10).astype(np.float32),
+            "x": np.random.randn(10, 10).astype(np.float32),
            "y": np.random.randn(10).astype(np.float32),
         }
         original_output = test_utils.ort_run("Original", model, input)
@@ -73,6 +85,7 @@
             ("with_onnx_op_default", _test_script_onnx_default, 1, "BiasGelu"),
             ("with_onnx_op_none", _test_script_onnx_none, 1, "BiasGelu"),
             ("with_contrib_op", _test_script_msft_op, 1, "BiasGelu"),
+            ("reversed_order", _test_script_reversed_order, 1, "BiasGelu"),
         ]
     )
     def test_bias_gelu_fusion(
@@ -87,6 +100,7 @@
     @parameterized.parameterized.expand(
         [
             ("approximate_tanh", _test_script_onnx_unsupported, 2, "Add"),
+            ("unsupported_shape", _test_script_shape_unsupported, 2, "Add"),
         ]
     )
     def test_bias_gelu_fusion_unsupported_attr(
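For reference, a short sketch of how the rule set might be applied to a model, assuming the onnxscript IR serde helpers and RewriteRuleSet.apply_to_model behave as they do for the other ort_fusions rule sets; the wrapper function name below is illustrative, not an API introduced by this PR.

import onnx
from onnxscript import ir
from onnxscript.rewriter.ort_fusions.bias_gelu import bias_gelu_rules

def fuse_bias_gelu_sketch(model_proto: onnx.ModelProto) -> onnx.ModelProto:
    model = ir.serde.deserialize_model(model_proto)  # ModelProto -> onnxscript IR
    count = bias_gelu_rules.apply_to_model(model)    # number of Add+Gelu sites rewritten
    print(f"BiasGelu fusion applied {count} time(s)")
    return ir.serde.serialize_model(model)           # IR -> ModelProto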