diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu.py b/onnxscript/rewriter/ort_fusions/bias_gelu.py
index 203223ab87..eff36e8940 100644
--- a/onnxscript/rewriter/ort_fusions/bias_gelu.py
+++ b/onnxscript/rewriter/ort_fusions/bias_gelu.py
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from onnxscript.rewriter import _fusion_utils, pattern
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
 
 
 class BiasGeluFusion(pattern.RewriteRuleClassBase):
@@ -22,30 +22,35 @@ def __init__(
         super().__init__(name)
         self._contrib_op = contrib_op
 
-    def pattern(self, op, x, y):
-        gelu_add = op.Add(x, y)
+    def pattern(self, op, input, bias):
+        gelu_add = op.Add(input, bias)
+
         if self._contrib_op:
             return op.Gelu(gelu_add, _domain="com.microsoft", _outputs=["gelu"])
         else:
             return op.Gelu(gelu_add, _outputs=["gelu"])
 
-    def check(self, op, gelu, **_) -> pattern.MatchResult:
+    def check(self, op, gelu, input, bias, **_) -> pattern.MatchResult:
         check_result = pattern.MatchResult()
         approximate = gelu.producer().attributes.get_string("approximate")
         if approximate is not None and approximate == "tanh":
             return check_result.fail(
                 "Gelu operator with 'approximate' set to 'tanh' is not supported."
             )
+
+        if not _ir_utils.has_rank(bias, 1):
+            return check_result.fail("bias is not a 1D tensor", bias)
+
         return check_result
 
-    def rewrite(self, op, x, y, **_):
-        return op.BiasGelu(x, y, _domain="com.microsoft")
+    def rewrite(self, op, input, bias, **_):
+        return op.BiasGelu(input, bias, _domain="com.microsoft")
 
 
 bias_gelu_rules = pattern.RewriteRuleSet(
     [
-        BiasGeluFusion.rule("gelu_onnx_op", contrib_op=False),
-        BiasGeluFusion.rule("gelu_contrib_op", contrib_op=True),
+        *BiasGeluFusion.rule("gelu_onnx_op", contrib_op=False).commute(),
+        *BiasGeluFusion.rule("gelu_contrib_op", contrib_op=True).commute(),
     ]
 )
diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py
index 7c6ecd8b9a..2a54eae852 100644
--- a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py
+++ b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py
@@ -18,27 +18,39 @@
 
 
 @script()
-def _test_script_onnx_default(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_onnx_default(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10, 10]:
     gelu_add = op.Add(x, y)
     return op.Gelu(gelu_add)
 
 
 @script()
-def _test_script_onnx_none(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_onnx_none(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10, 10]:
     gelu_add = op.Add(x, y)
     return op.Gelu(gelu_add, approximate="none")
 
 
 @script()
-def _test_script_onnx_unsupported(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_msft_op(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10, 10]:
     gelu_add = op.Add(x, y)
-    return op.Gelu(gelu_add, approximate="tanh")
+    return msft_op.Gelu(gelu_add)
+
+
+@script()
+def _test_script_reversed_order(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10, 10]:
+    gelu_add = op.Add(y, x)
+    return op.Gelu(gelu_add)
 
 
 @script()
-def _test_script_msft_op(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
+def _test_script_onnx_unsupported(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10, 10]:
     gelu_add = op.Add(x, y)
-    return msft_op.Gelu(gelu_add)
+    return op.Gelu(gelu_add, approximate="tanh")
+
+
+@script()
+def _test_script_shape_unsupported(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10, 10]:
+    gelu_add = op.Add(x, x)
+    return op.Gelu(gelu_add)
 
 
 class BiasGeluFusionTest(unittest.TestCase):
@@ -54,7 +66,7 @@ def _check(
         optimize(model)
 
         input = {
-            "x": np.random.randn(10).astype(np.float32),
+            "x": np.random.randn(10, 10).astype(np.float32),
             "y": np.random.randn(10).astype(np.float32),
         }
         original_output = test_utils.ort_run("Original", model, input)
@@ -73,6 +85,7 @@ def _check(
             ("with_onnx_op_default", _test_script_onnx_default, 1, "BiasGelu"),
             ("with_onnx_op_none", _test_script_onnx_none, 1, "BiasGelu"),
             ("with_contrib_op", _test_script_msft_op, 1, "BiasGelu"),
+            ("reversed_order", _test_script_reversed_order, 1, "BiasGelu"),
         ]
     )
     def test_bias_gelu_fusion(
@@ -87,6 +100,7 @@ def test_bias_gelu_fusion(
     @parameterized.parameterized.expand(
         [
             ("approximate_tanh", _test_script_onnx_unsupported, 2, "Add"),
+            ("unsupported_shape", _test_script_shape_unsupported, 2, "Add"),
         ]
    )
     def test_bias_gelu_fusion_unsupported_attr(
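
For context only (not part of the patch): a minimal sketch of how the commuted rule set defined above might be applied to a model. The file names are placeholders; it assumes the model contains an Add(input, bias) in either operand order, with a 1-D bias, feeding a Gelu.

# Hypothetical usage sketch, not part of this diff. "model.onnx" and
# "model_biasgelu.onnx" are placeholder paths.
import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions.bias_gelu import bias_gelu_rules

model = ir.load("model.onnx")  # load the model into the onnxscript IR
count = bias_gelu_rules.apply_to_model(model)  # returns the number of rewrites applied
print(f"fused {count} Add+Gelu pair(s) into com.microsoft BiasGelu")
ir.save(model, "model_biasgelu.onnx")

Because each rule is registered via .commute(), the rewriter tries both Add(input, bias) and Add(bias, input), and the rank-1 check on bias rejects matches where the "bias" operand is not a 1-D tensor.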