From c79e4f4858118548fd571ee571f87337f7e1ac1e Mon Sep 17 00:00:00 2001
From: AyoubMDL
Date: Sun, 22 Jun 2025 14:41:16 +0200
Subject: [PATCH 1/7] [Rewriter]: fuse successive Relu/Clip nodes

- Relu(Relu(X)) -> Relu
- Relu(Clip(X)) -> Clip
- Clip(Relu(X)) -> Clip
- Clip(Clip(X)) -> Clip
---
 onnxscript/rewriter/fuse_relus_clips.py      | 179 ++++++++++
 onnxscript/rewriter/fuse_relus_clips_test.py | 326 +++++++++++++++++++
 onnxscript/rewriter/testing.py               |  11 +-
 3 files changed, 513 insertions(+), 3 deletions(-)
 create mode 100644 onnxscript/rewriter/fuse_relus_clips.py
 create mode 100644 onnxscript/rewriter/fuse_relus_clips_test.py

diff --git a/onnxscript/rewriter/fuse_relus_clips.py b/onnxscript/rewriter/fuse_relus_clips.py
new file mode 100644
index 0000000000..ed4c2cb4e3
--- /dev/null
+++ b/onnxscript/rewriter/fuse_relus_clips.py
@@ -0,0 +1,179 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Performs the following transformations:
+- Relu(Relu(X)) -> Relu
+- Relu(Clip(X)) -> Clip
+- Clip(Relu(X)) -> Clip
+- Clip(Clip(X)) -> Clip
+"""
+
+import abc
+
+import numpy as np
+
+from onnxscript import ir
+from onnxscript.rewriter import pattern as orp
+
+
+class FuseSuccessiveRelu(orp.RewriteRuleClassBase):
+    """Replaces ``Relu(Relu(X))`` with ``Relu(X)``."""
+
+    def rewrite(self, op, x):
+        return op.Relu(x)
+
+    def pattern(self, op, x):
+        return op.Relu(op.Relu(x))
+
+
+class _FuseReluClipBase(orp.RewriteRuleClassBase, abc.ABC):
+    def rewrite(self, op, x, **kwargs):
+        first_clip_node = kwargs.get("out_first_clip").producer()
+        second_clip_node = None
+
+        if out_second_clip := kwargs.get("out_second_clip"):
+            second_clip_node = out_second_clip.producer()
+
+        min_clip, max_clip = self.compute_clip_min_max(first_clip_node, second_clip_node)
+        clip_initializers = []
+
+        if min_clip is not None:
+            clip_initializers.append(
+                op.initializer(min_clip, name=f"{first_clip_node.inputs[0].name}_min")
+            )
+
+        if max_clip is not None:
+            # ONNX Clip expects min and max inputs in order.
+            # If min is not provided, we insert None to maintain correct argument positions.
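+            # For example, rewriting to Clip(x, max) would silently bind the max
+            # tensor to Clip's "min" slot; Clip(x, None, max) keeps each bound in
+            # its correct position.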
+            if min_clip is None:
+                clip_initializers.append(None)
+
+            clip_initializers.append(
+                op.initializer(max_clip, name=f"{first_clip_node.inputs[0].name}_max")
+            )
+
+        return op.Clip(x, *clip_initializers)
+
+    @abc.abstractmethod
+    def compute_clip_min_max(
+        self, first_clip_node: ir.Node, second_clip_node: ir.Node | None = None
+    ):
+        pass
+
+    def extract_min_max(self, node: ir.Node):
+        # Infer dtype from the node's first input
+        dtype = node.inputs[0].dtype.numpy()
+        min_clip, max_clip = None, None
+
+        if len(node.inputs) > 1:
+            min_input = node.inputs[1]
+            # If only a max is provided, min is implicitly None, so we check that
+            if min_input is not None:
+                min_clip = min_input.const_value.numpy()
+
+        if len(node.inputs) > 2:
+            max_clip = node.inputs[2].const_value.numpy()
+
+        return min_clip, max_clip, dtype
+
+    def check(self, context, **kwargs):
+        del context  # Unused
+        check_result = orp.MatchResult()
+
+        # check clip min/max are initializers
+        initializers = []
+
+        first_clip_node = kwargs.get("out_first_clip").producer()
+        initializers.extend([inp for inp in first_clip_node.inputs[1:] if inp is not None])
+
+        if out_second_clip := kwargs.get("out_second_clip"):
+            second_clip_node = out_second_clip.producer()
+            initializers.extend(
+                [inp for inp in second_clip_node.inputs[1:] if inp is not None]
+            )
+
+        for initializer in initializers:
+            if initializer.is_graph_input():
+                return check_result.fail(f"{initializer.name} is a graph input.")
+
+            if not initializer.is_initializer() or initializer.const_value is None:
+                return check_result.fail(f"{initializer.name} is not a constant initializer.")
+
+        return check_result
+
+
+class FuseSuccessiveClip(_FuseReluClipBase):
+    """Replaces ``Clip(Clip(X))`` with ``Clip(X)``."""
+
+    def pattern(self, op, x):
+        return op.Clip(
+            op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]),
+            _allow_other_inputs=True,
+            _outputs=["out_second_clip"],
+        )
+
+    def compute_clip_min_max(self, first_clip_node: ir.Node, second_clip_node: ir.Node):
+        min_clip1, max_clip1, dtype = self.extract_min_max(first_clip_node)
+        min_clip2, max_clip2, _ = self.extract_min_max(second_clip_node)
+
+        if min_clip1 is not None and min_clip2 is not None:
+            min_clip = ir.tensor(np.array(np.maximum(min_clip1, min_clip2), dtype=dtype))
+        else:
+            min_clip = min_clip1 if min_clip1 is not None else min_clip2
+
+        if max_clip1 is not None and max_clip2 is not None:
+            max_clip = ir.tensor(np.array(np.minimum(max_clip1, max_clip2), dtype=dtype))
+        else:
+            max_clip = max_clip1 if max_clip1 is not None else max_clip2
+
+        return min_clip, max_clip
+
+
+class FuseSuccessiveClipRelu(_FuseReluClipBase):
+    """Replaces ``Clip(Relu(X))`` with ``Clip(X)``."""
+
+    def pattern(self, op, x):
+        return op.Clip(op.Relu(x), _allow_other_inputs=True, _outputs=["out_first_clip"])
+
+    def compute_clip_min_max(self, first_clip_node: ir.Node, _):
+        min_clip, max_clip, dtype = self.extract_min_max(first_clip_node)
+
+        if min_clip is None:
+            # The minimum clipping value is implicitly 0 (Relu clamps at 0)
+            min_clip = 0
+
+        min_clip = ir.tensor(np.array(np.maximum(0.0, min_clip), dtype=dtype))
+
+        if max_clip is not None:
+            max_clip = ir.tensor(max_clip)
+        return min_clip, max_clip
+
+
+class FuseSuccessiveReluClip(FuseSuccessiveClipRelu):
+    """Replaces ``Relu(Clip(X))`` with ``Clip(X)``."""
+
+    def pattern(self, op, x):
+        return op.Relu(op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]))
+
+
+fuse_successive_relu_rule = FuseSuccessiveRelu().rule()
+fuse_successive_clip_rule = FuseSuccessiveClip().rule()
+fuse_successive_clip_relu_rule = FuseSuccessiveClipRelu().rule()
+fuse_sucessive_relu_clip_rule = FuseSuccessiveReluClip().rule()
+
+
+def fuse_relus_clips() -> orp.RewriteRuleSet:
+    """Returns a set of rewrite rules that fuse successive Relu/Clip nodes.
+
+    Returns:
+        RewriteRuleSet
+    """
+
+    # Order is important
+    return orp.RewriteRuleSet(
+        [
+            fuse_successive_clip_relu_rule,
+            fuse_sucessive_relu_clip_rule,
+            fuse_successive_relu_rule,
+            fuse_successive_clip_rule,
+        ]
+    )
diff --git a/onnxscript/rewriter/fuse_relus_clips_test.py b/onnxscript/rewriter/fuse_relus_clips_test.py
new file mode 100644
index 0000000000..90fdf706d2
--- /dev/null
+++ b/onnxscript/rewriter/fuse_relus_clips_test.py
@@ -0,0 +1,326 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import unittest
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+import parameterized
+from onnx_ir.passes.common import onnx_checker, shape_inference
+
+from onnxscript import ir
+from onnxscript.rewriter import fuse_relus_clips, testing
+from onnxscript.rewriter import pattern as orp
+from onnxscript.rewriter.fuse_relus_clips import (
+    fuse_successive_clip_relu_rule,
+    fuse_successive_clip_rule,
+    fuse_sucessive_relu_clip_rule,
+)
+
+
+class _FuseReluClipTestBase(unittest.TestCase):
+    @property
+    def rng(self):
+        return np.random.default_rng(20250621)
+
+    def clone_model(self, model: ir.Model) -> ir.Model:
+        return ir.from_proto(ir.to_proto(model))
+
+    def run_test(
+        self,
+        base_model: onnx.ModelProto | ir.Model,
+        expected_op_type: str,
+        dtype: str = "float",
+    ):
+        base_model = ir.serde.deserialize_model(base_model)
+        onnx_checker.CheckerPass(True)(base_model)
+        base_model = shape_inference.infer_shapes(base_model)
+
+        updated_model = self.clone_model(base_model)
+        _ = fuse_relus_clips.fuse_relus_clips().apply_to_model(updated_model)
+
+        # Check Relu/Clip are fused
+        self.assertEqual(len(updated_model.graph), 1)
+        self.assertEqual(updated_model.graph[0].op_type, expected_op_type)
+
+        # Check inference
+        inputs = (self.rng.integers(low=-10, high=10, size=(2, 32, 14), dtype=np.int32),)
+        if dtype == "float":
+            inputs = (inputs[0].astype(np.float32),)
+
+        # onnxruntime has an optimization that fuses Clip(Relu), but it does not
+        # support int data, which is why onnxruntime optimizations are disabled here.
+        # see https://github.com/microsoft/onnxruntime/blob/c98a0e014b641e289ed25f42b792bca1893ccb03/onnxruntime/core/optimizer/relu_clip_fusion.cc#L60
+        testing.assert_numerically_equal(
+            base_model,
+            updated_model,
+            inputs,
+            ort_optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
+        )
+
+        # Validate serialized model
+        output_model_proto = ir.serde.serialize_model(updated_model)
+        onnx.checker.check_model(output_model_proto, full_check=True)
+
+    def run_failed_condition_test(
+        self,
+        base_model: onnx.ModelProto,
+        rewrite_rule: orp.RewriteRule,
+        expected_message: str,
+    ):
+        base_model = ir.serde.deserialize_model(base_model)
+        onnx_checker.CheckerPass(True)(base_model)
+
+        updated_model = self.clone_model(base_model)
+        tracer = orp.MatchingTracer()
+        count = rewrite_rule.apply_to_model(updated_model, tracer=tracer)
+
+        # Check that the model is unchanged
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[rewrite_rule][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, expected_message)
+
+
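+# Each test below builds a small model from ONNX text syntax, applies the fusion
+# rules, and checks that the rewritten model is numerically equivalent to the
+# original, with onnxruntime graph optimizations disabled.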
+class FuseSuccessiveReluTest(_FuseReluClipTestBase):
+    def test_fuse_succesive_relus(self):
+        model_proto = onnx.parser.parse_model("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            {
+                x1 = Relu(X)
+                x2 = Relu(x1)
+                Y = Relu(x2)
+            }
+        """)
+        self.run_test(model_proto, expected_op_type="Relu")
+
+
+class FuseSuccessiveReluClipTest(_FuseReluClipTestBase):
+    @parameterized.parameterized.expand(
+        [
+            (
+                "relu_then_clip",
+                """
+                x1 = Relu(X)
+                Y = Clip(x1, min, max)
+                """,
+                "float",
+            ),
+            (
+                "clip_then_relu",
+                """
+                x1 = Clip(X, min, max)
+                Y = Relu(x1)
+                """,
+                "float",
+            ),
+            (
+                "int_relu_then_clip",
+                """
+                x1 = Relu(X)
+                Y = Clip(x1, min, max)
+                """,
+                "int32",
+            ),
+            (
+                "int_clip_then_relu",
+                """
+                x1 = Clip(X, min, max)
+                Y = Relu(x1)
+                """,
+                "int32",
+            ),
+        ]
+    )
+    def test_fuse_successive_relu_clip(self, _, nodes, dtype):
+        model_proto = onnx.parser.parse_model(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y)
+            <{dtype} min = {{1}}, {dtype} max = {{6}}>
+            {{
+                {nodes}
+            }}
+        """)
+        self.run_test(model_proto, expected_op_type="Clip", dtype=dtype)
+
+    @parameterized.parameterized.expand(
+        [
+            (
+                "relu_then_clip",
+                """
+                x1 = Relu(X)
+                Y = Clip(x1,,max)
+                """,
+            ),
+            (
+                "clip_then_relu",
+                """
+                x1 = Clip(X,,max)
+                Y = Relu(x1)
+                """,
+            ),
+        ]
+    )
+    def test_fuse_successive_relu_clip_no_min(self, _, nodes):
+        model_proto = onnx.parser.parse_model(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            <float max = {{6}}>
+            {{
+                {nodes}
+            }}
+        """)
+        self.run_test(model_proto, expected_op_type="Clip")
+
+    @parameterized.parameterized.expand(
+        [
+            (
+                "relu_then_clip",
+                """
+                x1 = Relu(X)
+                Y = Clip(x1, min)
+                """,
+                fuse_successive_clip_relu_rule,
+            ),
+            (
+                "clip_then_relu",
+                """
+                x1 = Clip(X, min)
+                Y = Relu(x1)
+                """,
+                fuse_sucessive_relu_clip_rule,
+            ),
+        ]
+    )
+    def test_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite_rule):
+        model_proto = onnx.parser.parse_model(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            {{
+                min = ReduceMean(X)
+                {nodes}
+            }}
+        """)
+        self.run_failed_condition_test(
+            model_proto, rewrite_rule, "is not a constant initializer."
+        )
+
+    @parameterized.parameterized.expand(
+        [
+            (
+                "relu_then_clip",
+                """
+                x1 = Relu(X)
+                Y = Clip(x1, min)
+                """,
+                fuse_successive_clip_relu_rule,
+            ),
+            (
+                "clip_then_relu",
+                """
+                x1 = Clip(X, min)
+                Y = Relu(x1)
+                """,
+                fuse_sucessive_relu_clip_rule,
+            ),
+        ]
+    )
+    def test_fuse_successive_relu_clip_graph_inputs(self, _, nodes, rewrite_rule):
+        model_proto = onnx.parser.parse_model(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X, float min) => (float [N, ?, ?] Y)
+            {{
+                {nodes}
+            }}
+        """)
+        self.run_failed_condition_test(model_proto, rewrite_rule, "is a graph input.")
+
+
+class FuseSuccessiveClipTest(_FuseReluClipTestBase):
+    @parameterized.parameterized.expand(
+        [
+            ("float", "float"),
+            ("int32", "int32"),
+        ]
+    )
+    def test_fuse_succesive_clips(self, _, dtype):
+        model_proto = onnx.parser.parse_model(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y)
+            <{dtype} min1 = {{0}}, {dtype} max1 = {{4}},
+            {dtype} min2 = {{1}}, {dtype} max2 = {{11}},
+            {dtype} min3 = {{3}}, {dtype} max3 = {{7}}>
+            {{
+                x1 = Clip(X, min1, max1)
+                x2 = Clip(x1, min2, max2)
+                Y = Clip(x2, min3, max3)
+            }}
+        """)
+        self.run_test(model_proto, expected_op_type="Clip", dtype=dtype)
+
+    def test_fuse_succesive_clips_no_min(self):
+        model_proto = onnx.parser.parse_model("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            <float max1 = {4}, float max2 = {6}>
+            {
+                x1 = Clip(X,, max1)
+                Y = Clip(x1,, max2)
+            }
+        """)
+        self.run_test(model_proto, expected_op_type="Clip")
+
+    def test_fuse_successive_clips_non_initializers(self):
+        model_proto = onnx.parser.parse_model("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            {
+                min1 = ReduceMean(X)
+                min2 = ReduceMax(X)
+                x1 = Clip(X, min1)
+                Y = Clip(x1, min2)
+            }
+        """)
+        self.run_failed_condition_test(
+            model_proto, fuse_successive_clip_rule, "is not a constant initializer."
+        )
+
+    def test_fuse_successive_clips_graph_inputs(self):
+        model_proto = onnx.parser.parse_model("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X, float min1, float min2) => (float [N, ?, ?] Y)
+            {
+                x1 = Clip(X, min1)
+                Y = Clip(x1, min2)
+            }
+        """)
+        self.run_failed_condition_test(
+            model_proto, fuse_successive_clip_rule, "is a graph input."
+        )
+
+
+class FuseReluClipIntegrationTest(_FuseReluClipTestBase):
+    def test_full_chain_fusion(self):
+        model_proto = onnx.parser.parse_model("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            {
+                x1 = Relu(X)
+                x2 = Relu(x1)
+                x3 = Relu(x2)
+                x4 = Relu(x3)
+                x5 = Clip(x4)
+                x6 = Relu(x5)
+                Y = Clip(x6)
+            }
+        """)
+        self.run_test(model_proto, expected_op_type="Clip")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/rewriter/testing.py b/onnxscript/rewriter/testing.py
index 7c8c5175ee..52eb3ee894 100644
--- a/onnxscript/rewriter/testing.py
+++ b/onnxscript/rewriter/testing.py
@@ -15,6 +15,7 @@ def assert_numerically_equal(
     original_model_proto: onnx.ModelProto | ir.Model,
     rewritten_model_proto: onnx.ModelProto | ir.Model,
     args: tuple[Any, ...],
+    ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
     rtol: float = 1,
     atol: float = 1e-3,
 ):
@@ -23,6 +24,7 @@ def assert_numerically_equal(
     Args:
         original_model_proto: The original model proto or ir.Model.
         rewritten_model_proto: The rewritten by the rules model proto or ir.Model.
+        ort_optimization_level: Onnxruntime optimization level.
         rtol: Relative tolerance.
         atol: Absolute tolerance.
         args: The positional arguments to pass to the model.
@@ -37,7 +39,7 @@ def assert_numerically_equal(
             k.name: v for k, v in zip(original_model_proto.graph.input, args)
         }
         original_proto_ort_inference_session = _ort_session_initializer(
-            original_model_proto.SerializeToString()
+            original_model_proto.SerializeToString(), ort_optimization_level
         )
         run_options = ort.RunOptions()
         run_options.log_severity_level = 3  # 3: Error
@@ -49,7 +51,7 @@ def assert_numerically_equal(
             k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
         }
         the_rewritten_proto_ort_inference_session = _ort_session_initializer(
-            rewritten_model_proto.SerializeToString()
+            rewritten_model_proto.SerializeToString(), ort_optimization_level
         )
         the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run(
             None, the_rewritten_proto_ort_inputs, run_options=run_options
@@ -60,12 +62,15 @@ def assert_numerically_equal(
     )
 
 
-def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession:
+def _ort_session_initializer(
+    model: str | bytes, ort_optimization_level: ort.GraphOptimizationLevel
+) -> ort.InferenceSession:
     """Initialize an ONNX Runtime inference session with the specified model."""
     import onnxruntime as ort
 
     session_options = ort.SessionOptions()
     session_options.log_severity_level = 3  # 3: Error
+    session_options.graph_optimization_level = ort_optimization_level
     possible_providers = (
         "CUDAExecutionProvider",
         "CPUExecutionProvider",

From 59aa3cb49d21f455dbc2392830526c98a727aaff Mon Sep 17 00:00:00 2001
From: AyoubMDL
Date: Tue, 24 Jun 2025 01:39:26 +0200
Subject: [PATCH 2/7] review: use ir APIs

---
 onnxscript/rewriter/fuse_relus_clips.py      |  6 +-
 onnxscript/rewriter/fuse_relus_clips_test.py | 58 +++++++++-----------
 2 files changed, 29 insertions(+), 35 deletions(-)

diff --git a/onnxscript/rewriter/fuse_relus_clips.py b/onnxscript/rewriter/fuse_relus_clips.py
index ed4c2cb4e3..53c9b4712b 100644
--- a/onnxscript/rewriter/fuse_relus_clips.py
+++ b/onnxscript/rewriter/fuse_relus_clips.py
@@ -158,10 +158,10 @@ def pattern(self, op, x):
 fuse_successive_relu_rule = FuseSuccessiveRelu().rule()
 fuse_successive_clip_rule = FuseSuccessiveClip().rule()
 fuse_successive_clip_relu_rule = FuseSuccessiveClipRelu().rule()
-fuse_sucessive_relu_clip_rule = FuseSuccessiveReluClip().rule()
+fuse_successive_relu_clip_rule = FuseSuccessiveReluClip().rule()
 
 
-def fuse_relus_clips() -> orp.RewriteRuleSet:
+def fuse_relus_clips_rules() -> orp.RewriteRuleSet:
     """Returns a set of rewrite rules that fuse successive Relu/Clip nodes.
 
     Returns:
@@ -172,7 +172,7 @@ def fuse_relus_clips_rules() -> orp.RewriteRuleSet:
     return orp.RewriteRuleSet(
         [
             fuse_successive_clip_relu_rule,
-            fuse_sucessive_relu_clip_rule,
+            fuse_successive_relu_clip_rule,
             fuse_successive_relu_rule,
             fuse_successive_clip_rule,
         ]
diff --git a/onnxscript/rewriter/fuse_relus_clips_test.py b/onnxscript/rewriter/fuse_relus_clips_test.py
index 90fdf706d2..d3cb74db46 100644
--- a/onnxscript/rewriter/fuse_relus_clips_test.py
+++ b/onnxscript/rewriter/fuse_relus_clips_test.py
@@ -14,7 +14,7 @@
 from onnxscript.rewriter.fuse_relus_clips import (
     fuse_successive_clip_relu_rule,
     fuse_successive_clip_rule,
-    fuse_sucessive_relu_clip_rule,
+    fuse_successive_relu_clip_rule,
 )
 
 
@@ -28,16 +28,15 @@ def clone_model(self, model: ir.Model) -> ir.Model:
 
     def run_test(
         self,
-        base_model: onnx.ModelProto | ir.Model,
+        base_model: ir.Model,
         expected_op_type: str,
         dtype: str = "float",
     ):
-        base_model = ir.serde.deserialize_model(base_model)
         onnx_checker.CheckerPass(True)(base_model)
         base_model = shape_inference.infer_shapes(base_model)
 
         updated_model = self.clone_model(base_model)
-        _ = fuse_relus_clips.fuse_relus_clips().apply_to_model(updated_model)
+        _ = fuse_relus_clips.fuse_relus_clips_rules().apply_to_model(updated_model)
 
         # Check Relu/Clip are fused
         self.assertEqual(len(updated_model.graph), 1)
@@ -64,11 +63,10 @@ def run_test(
 
     def run_failed_condition_test(
         self,
-        base_model: onnx.ModelProto,
+        base_model: ir.Model,
         rewrite_rule: orp.RewriteRule,
         expected_message: str,
     ):
-        base_model = ir.serde.deserialize_model(base_model)
         onnx_checker.CheckerPass(True)(base_model)
 
         updated_model = self.clone_model(base_model)
@@ -86,7 +84,7 @@ def run_failed_condition_test(
 
 class FuseSuccessiveReluTest(_FuseReluClipTestBase):
     def test_fuse_succesive_relus(self):
-        model_proto = onnx.parser.parse_model("""
+        model = ir.from_onnx_text("""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
             {
@@ -95,7 +93,7 @@ def test_fuse_succesive_relus(self):
                 Y = Relu(x2)
             }
         """)
-        self.run_test(model_proto, expected_op_type="Relu")
+        self.run_test(model, expected_op_type="Relu")
 
 
 class FuseSuccessiveReluClipTest(_FuseReluClipTestBase):
@@ -136,7 +134,7 @@ class FuseSuccessiveReluClipTest(_FuseReluClipTestBase):
         ]
     )
     def test_fuse_successive_relu_clip(self, _, nodes, dtype):
-        model_proto = onnx.parser.parse_model(f"""
+        model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y)
             <{dtype} min = {{1}}, {dtype} max = {{6}}>
@@ -144,7 +142,7 @@ def test_fuse_successive_relu_clip(self, _, nodes, dtype):
                 {nodes}
             }}
         """)
-        self.run_test(model_proto, expected_op_type="Clip", dtype=dtype)
+        self.run_test(model, expected_op_type="Clip", dtype=dtype)
 
     @parameterized.parameterized.expand(
         [
@@ -165,7 +163,7 @@ def test_fuse_successive_relu_clip(self, _, nodes, dtype):
         ]
     )
     def test_fuse_successive_relu_clip_no_min(self, _, nodes):
-        model_proto = onnx.parser.parse_model(f"""
+        model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
             <float max = {{6}}>
@@ -173,7 +171,7 @@ def test_fuse_successive_relu_clip_no_min(self, _, nodes):
                 {nodes}
             }}
         """)
-        self.run_test(model_proto, expected_op_type="Clip")
+        self.run_test(model, expected_op_type="Clip")
 
     @parameterized.parameterized.expand(
         [
@@ -191,12 +189,12 @@ def test_fuse_successive_relu_clip_no_min(self, _, nodes):
                 x1 = Clip(X, min)
                 Y = Relu(x1)
                 """,
-                fuse_sucessive_relu_clip_rule,
+                fuse_successive_relu_clip_rule,
             ),
         ]
     )
     def test_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite_rule):
-        model_proto = onnx.parser.parse_model(f"""
+        model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
             {{
@@ -204,9 +202,7 @@ def test_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite_rule
                 {nodes}
             }}
         """)
-        self.run_failed_condition_test(
-            model_proto, rewrite_rule, "is not a constant initializer."
-        )
+        self.run_failed_condition_test(model, rewrite_rule, "is not a constant initializer.")
 
     @parameterized.parameterized.expand(
         [
@@ -224,19 +220,19 @@ def test_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite_rule
                 x1 = Clip(X, min)
                 Y = Relu(x1)
                 """,
-                fuse_sucessive_relu_clip_rule,
+                fuse_successive_relu_clip_rule,
             ),
         ]
     )
     def test_fuse_successive_relu_clip_graph_inputs(self, _, nodes, rewrite_rule):
-        model_proto = onnx.parser.parse_model(f"""
+        model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X, float min) => (float [N, ?, ?] Y)
             {{
                 {nodes}
             }}
         """)
-        self.run_failed_condition_test(model_proto, rewrite_rule, "is a graph input.")
+        self.run_failed_condition_test(model, rewrite_rule, "is a graph input.")
 
 
 class FuseSuccessiveClipTest(_FuseReluClipTestBase):
@@ -247,7 +243,7 @@ class FuseSuccessiveClipTest(_FuseReluClipTestBase):
         ]
     )
     def test_fuse_succesive_clips(self, _, dtype):
-        model_proto = onnx.parser.parse_model(f"""
+        model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y)
             <{dtype} min1 = {{0}}, {dtype} max1 = {{4}},
@@ -259,10 +255,10 @@ def test_fuse_succesive_clips(self, _, dtype):
                 Y = Clip(x2, min3, max3)
             }}
         """)
-        self.run_test(model_proto, expected_op_type="Clip", dtype=dtype)
+        self.run_test(model, expected_op_type="Clip", dtype=dtype)
 
     def test_fuse_succesive_clips_no_min(self):
-        model_proto = onnx.parser.parse_model("""
+        model = ir.from_onnx_text("""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
             <float max1 = {4}, float max2 = {6}>
@@ -271,10 +267,10 @@ def test_fuse_succesive_clips_no_min(self):
                 x1 = Clip(X,, max1)
                 Y = Clip(x1,, max2)
             }
         """)
-        self.run_test(model_proto, expected_op_type="Clip")
+        self.run_test(model, expected_op_type="Clip")
 
     def test_fuse_successive_clips_non_initializers(self):
-        model_proto = onnx.parser.parse_model("""
+        model = ir.from_onnx_text("""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
             {
@@ -286,11 +282,11 @@ def test_fuse_successive_clips_non_initializers(self):
             }
         """)
         self.run_failed_condition_test(
-            model_proto, fuse_successive_clip_rule, "is not a constant initializer."
+            model, fuse_successive_clip_rule, "is not a constant initializer."
         )
 
     def test_fuse_successive_clips_graph_inputs(self):
-        model_proto = onnx.parser.parse_model("""
+        model = ir.from_onnx_text("""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X, float min1, float min2) => (float [N, ?, ?] Y)
@@ -299,14 +295,12 @@ def test_fuse_successive_clips_graph_inputs(self):
                 Y = Clip(x1, min2)
             }
         """)
-        self.run_failed_condition_test(
-            model_proto, fuse_successive_clip_rule, "is a graph input."
-        )
+        self.run_failed_condition_test(model, fuse_successive_clip_rule, "is a graph input.")
 
 
 class FuseReluClipIntegrationTest(_FuseReluClipTestBase):
     def test_full_chain_fusion(self):
-        model_proto = onnx.parser.parse_model("""
+        model = ir.from_onnx_text("""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
             {
@@ -319,7 +313,7 @@ def test_full_chain_fusion(self):
                 Y = Clip(x6)
             }
         """)
-        self.run_test(model_proto, expected_op_type="Clip")
+        self.run_test(model, expected_op_type="Clip")
 
 
 if __name__ == "__main__":

From 597f4527840dacbb4866e819eadb01646eef4c93 Mon Sep 17 00:00:00 2001
From: AyoubMDL
Date: Tue, 24 Jun 2025 01:39:48 +0200
Subject: [PATCH 3/7] review: add rewriter to default rules

---
 onnxscript/rewriter/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py
index 7e43f44032..a94e0ffece 100644
--- a/onnxscript/rewriter/__init__.py
+++ b/onnxscript/rewriter/__init__.py
@@ -19,6 +19,7 @@
     broadcast_to_matmul,
     cast_constant_of_shape,
     collapse_slices,
+    fuse_relus_clips,
     no_op,
     pattern,
 )
@@ -29,6 +30,7 @@
     *broadcast_to_matmul.rules.rules,
     *cast_constant_of_shape.rules.rules,
     *collapse_slices.rules.rules,
+    *fuse_relus_clips.fuse_relus_clips_rules().rules,
     *basic_rules.basic_optimization_rules().rules,
 )

From e30950c5ae72d3a290b56777db042dfe4f9b26e3 Mon Sep 17 00:00:00 2001
From: AyoubMDL
Date: Sun, 29 Jun 2025 12:22:46 +0200
Subject: [PATCH 4/7] review: improve code and docs

---
 onnxscript/rewriter/fuse_relus_clips.py      | 39 ++++----
 onnxscript/rewriter/fuse_relus_clips_test.py | 95 ++++++++++++++------
 onnxscript/rewriter/testing.py               |  2 +-
 3 files changed, 94 insertions(+), 42 deletions(-)

diff --git a/onnxscript/rewriter/fuse_relus_clips.py b/onnxscript/rewriter/fuse_relus_clips.py
index 53c9b4712b..2a2999809e 100644
--- a/onnxscript/rewriter/fuse_relus_clips.py
+++ b/onnxscript/rewriter/fuse_relus_clips.py
@@ -10,8 +10,8 @@
 import abc
 
 import numpy as np
+import onnx_ir as ir
 
-from onnxscript import ir
 from onnxscript.rewriter import pattern as orp
 
 
@@ -34,10 +34,10 @@ def rewrite(self, op, x, **kwargs):
             second_clip_node = out_second_clip.producer()
 
         min_clip, max_clip = self.compute_clip_min_max(first_clip_node, second_clip_node)
-        clip_initializers = []
+        clip_min_max = []
 
         if min_clip is not None:
-            clip_initializers.append(
+            clip_min_max.append(
                 op.initializer(min_clip, name=f"{first_clip_node.inputs[0].name}_min")
             )
 
@@ -45,13 +45,13 @@ def rewrite(self, op, x, **kwargs):
             # ONNX Clip expects min and max inputs in order.
             # If min is not provided, we insert None to maintain correct argument positions.
             if min_clip is None:
-                clip_initializers.append(None)
+                clip_min_max.append(None)
 
-            clip_initializers.append(
+            clip_min_max.append(
                 op.initializer(max_clip, name=f"{first_clip_node.inputs[0].name}_max")
             )
 
-        return op.Clip(x, *clip_initializers)
+        return op.Clip(x, *clip_min_max)
 
     @abc.abstractmethod
     def compute_clip_min_max(
@@ -76,27 +76,36 @@ def extract_min_max(self, node: ir.Node):
         return min_clip, max_clip, dtype
 
     def check(self, context, **kwargs):
+        """Condition to check if we need to replace the pattern.
+
+        The pattern is applied only when the min and max inputs of the Clip nodes are
+        not graph inputs and are constant values (i.e., provided by Constant nodes
+        or initializers).
+
+        Returns:
+            MatchResult:
+                Success if we need to replace the pattern, Failure otherwise.
+        """
         del context  # Unused
         check_result = orp.MatchResult()
 
-        # check clip min/max are initializers
-        initializers = []
+        # Check if clip min/max are initializers
+        clip_min_max = []
 
         first_clip_node = kwargs.get("out_first_clip").producer()
-        initializers.extend([inp for inp in first_clip_node.inputs[1:] if inp is not None])
+        clip_min_max.extend([inp for inp in first_clip_node.inputs[1:] if inp is not None])
 
         if out_second_clip := kwargs.get("out_second_clip"):
             second_clip_node = out_second_clip.producer()
-            initializers.extend(
+            clip_min_max.extend(
                 [inp for inp in second_clip_node.inputs[1:] if inp is not None]
             )
 
-        for initializer in initializers:
-            if initializer.is_graph_input():
-                return check_result.fail(f"{initializer.name} is a graph input.")
+        for m in clip_min_max:
+            if m.is_graph_input():
+                return check_result.fail(f"{m.name} is a graph input.")
 
-            if not initializer.is_initializer() or initializer.const_value is None:
-                return check_result.fail(f"{initializer.name} is not a constant initializer.")
+            if ir.convenience.get_const_tensor(m) is None:
+                return check_result.fail(f"{m.name} is not a constant.")
 
         return check_result
diff --git a/onnxscript/rewriter/fuse_relus_clips_test.py b/onnxscript/rewriter/fuse_relus_clips_test.py
index d3cb74db46..22ef26a9f8 100644
--- a/onnxscript/rewriter/fuse_relus_clips_test.py
+++ b/onnxscript/rewriter/fuse_relus_clips_test.py
@@ -4,11 +4,11 @@
 
 import numpy as np
 import onnx
+import onnx_ir as ir
 import onnxruntime as ort
 import parameterized
 from onnx_ir.passes.common import onnx_checker, shape_inference
 
-from onnxscript import ir
 from onnxscript.rewriter import fuse_relus_clips, testing
 from onnxscript.rewriter import pattern as orp
 from onnxscript.rewriter.fuse_relus_clips import (
@@ -29,18 +29,16 @@ def clone_model(self, model: ir.Model) -> ir.Model:
     def run_test(
         self,
         base_model: ir.Model,
-        expected_op_type: str,
+        expected_op_types: list[str],
         dtype: str = "float",
     ):
         onnx_checker.CheckerPass(True)(base_model)
         base_model = shape_inference.infer_shapes(base_model)
-
         updated_model = self.clone_model(base_model)
         _ = fuse_relus_clips.fuse_relus_clips_rules().apply_to_model(updated_model)
 
-        # Check Relu/Clip are fused
-        self.assertEqual(len(updated_model.graph), 1)
-        self.assertEqual(updated_model.graph[0].op_type, expected_op_type)
+        # Check expected op_types
+        self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types)
 
         # Check inference
         inputs = (self.rng.integers(low=-10, high=10, size=(2, 32, 14), dtype=np.int32),)
@@ -83,7 +81,7 @@ def run_failed_condition_test(
 
 
 class FuseSuccessiveReluTest(_FuseReluClipTestBase):
-    def test_fuse_succesive_relus(self):
+    def test_successful_fuse_successive_relus(self):
         model = ir.from_onnx_text("""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
@@ -93,7 +91,7 @@ def test_successful_fuse_successive_relus(self):
                 Y = Relu(x2)
             }
         """)
-        self.run_test(model, expected_op_type="Relu")
+        self.run_test(model, expected_op_types=["Relu"])
 
 
 class FuseSuccessiveReluClipTest(_FuseReluClipTestBase):
@@ -131,7 +129,7 @@ class FuseSuccessiveReluClipTest(_FuseReluClipTestBase):
         ]
     )
-    def test_fuse_successive_relu_clip(self, _, nodes, dtype):
+    def test_successful_fuse_successive_relu_clip(self, _, nodes, dtype):
         model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y)
@@ -140,7 +138,37 @@ def test_successful_fuse_successive_relu_clip(self, _, nodes, dtype):
                 {nodes}
             }}
         """)
-        self.run_test(model, expected_op_type="Clip", dtype=dtype)
+        self.run_test(model, expected_op_types=["Clip"], dtype=dtype)
+
+    @parameterized.parameterized.expand(
+        [
+            (
+                "relu_then_clip",
+                """
+                x1 = Relu(X)
+                min = Constant<value_float = 2.0>()
+                Y = Clip(x1, min)
+                """,
+            ),
+            (
+                "clip_then_relu",
+                """
+                min = Constant<value_float = 2.0>()
+                x1 = Clip(X, min)
+                Y = Relu(x1)
+                """,
+            ),
+        ]
+    )
+    def test_successful_fuse_successive_relu_clip_constant_nodes(self, _, nodes):
+        model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float[N, ?, ?] Y)
+            {{
+                {nodes}
+            }}
+        """)
+        self.run_test(model, expected_op_types=["Constant", "Clip"])
 
     @parameterized.parameterized.expand(
         [
@@ -163,7 +191,7 @@ def test_successful_fuse_successive_relu_clip(self, _, nodes, dtype):
         ]
     )
-    def test_fuse_successive_relu_clip_no_min(self, _, nodes):
+    def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes):
         model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
@@ -171,7 +199,7 @@ def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes):
                 {nodes}
             }}
         """)
-        self.run_test(model, expected_op_type="Clip")
+        self.run_test(model, expected_op_types=["Clip"])
 
     @parameterized.parameterized.expand(
         [
@@ -219,7 +247,7 @@ def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes):
         ]
     )
-    def test_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite_rule):
+    def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite_rule):
         model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
@@ -228,7 +256,7 @@ def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite
                 {nodes}
             }}
         """)
-        self.run_failed_condition_test(model, rewrite_rule, "is not a constant initializer.")
+        self.run_failed_condition_test(model, rewrite_rule, "is not a constant.")
 
     @parameterized.parameterized.expand(
         [
@@ -250,7 +278,7 @@ def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite
         ]
     )
-    def test_fuse_successive_relu_clip_graph_inputs(self, _, nodes, rewrite_rule):
+    def test_fail_fuse_successive_relu_clip_graph_inputs(self, _, nodes, rewrite_rule):
         model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X, float min) => (float [N, ?, ?] Y)
@@ -268,7 +296,7 @@ class FuseSuccessiveClipTest(_FuseReluClipTestBase):
         ]
     )
-    def test_fuse_succesive_clips(self, _, dtype):
+    def test_successful_fuse_successive_clips(self, _, dtype):
         model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y)
@@ -281,9 +309,26 @@ def test_successful_fuse_successive_clips(self, _, dtype):
                 Y = Clip(x2, min3, max3)
             }}
         """)
-        self.run_test(model, expected_op_type="Clip", dtype=dtype)
+        self.run_test(model, expected_op_types=["Clip"], dtype=dtype)
+
+    def test_successful_fuse_successive_clips_node_constants(self):
+        model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            {
+                min1 = Constant<value_float = 0.0>()
+                max1 = Constant<value_float = 6.0>()
+                min2 = Constant<value_float = 1.0>()
+                max2 = Constant<value_float = 11.0>()
+                x1 = Clip(X, min1, max1)
+                Y = Clip(x1, min2, max2)
+            }
+        """)
+        self.run_test(
+            model, expected_op_types=["Constant", "Constant", "Constant", "Constant", "Clip"]
+        )
 
-    def test_fuse_succesive_clips_no_min(self):
+    def test_successful_fuse_successive_clips_no_min(self):
         model = ir.from_onnx_text("""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
@@ -293,9 +338,9 @@ def test_successful_fuse_successive_clips_no_min(self):
                 Y = Clip(x1,, max2)
             }
         """)
-        self.run_test(model, expected_op_type="Clip")
+        self.run_test(model, expected_op_types=["Clip"])
 
-    def test_fuse_successive_clips_non_initializers(self):
+    def test_fail_fuse_successive_clips_non_initializers(self):
         model = ir.from_onnx_text("""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
@@ -306,11 +351,9 @@ def test_fail_fuse_successive_clips_non_initializers(self):
             }
         """)
-        self.run_failed_condition_test(
-            model, fuse_successive_clip_rule, "is not a constant initializer."
-        )
+        self.run_failed_condition_test(model, fuse_successive_clip_rule, "is not a constant.")
 
-    def test_fuse_successive_clips_graph_inputs(self):
+    def test_fail_fuse_successive_clips_graph_inputs(self):
         model = ir.from_onnx_text("""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X, float min1, float min2) => (float [N, ?, ?] Y)
@@ -320,7 +363,7 @@ def test_fail_fuse_successive_clips_graph_inputs(self):
                 Y = Clip(x1, min2)
             }
         """)
-        self.run_failed_condition_test(model, fuse_successive_clip_rule, "is a graph input.")
 
 
 class FuseReluClipIntegrationTest(_FuseReluClipTestBase):
-    def test_full_chain_fusion(self):
+    def test_successful_full_chain_fusion(self):
         model = ir.from_onnx_text("""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
@@ -334,7 +377,7 @@ def test_successful_full_chain_fusion(self):
                 Y = Clip(x6)
             }
         """)
-        self.run_test(model, expected_op_type="Clip")
+        self.run_test(model, expected_op_types=["Clip"])
 
 
 if __name__ == "__main__":
diff --git a/onnxscript/rewriter/testing.py b/onnxscript/rewriter/testing.py
index 52eb3ee894..89cceb1c1d 100644
--- a/onnxscript/rewriter/testing.py
+++ b/onnxscript/rewriter/testing.py
@@ -24,10 +24,10 @@ def assert_numerically_equal(
     Args:
         original_model_proto: The original model proto or ir.Model.
         rewritten_model_proto: The rewritten by the rules model proto or ir.Model.
+        args: The positional arguments to pass to the model.
         ort_optimization_level: Onnxruntime optimization level.
         rtol: Relative tolerance.
         atol: Absolute tolerance.
-        args: The positional arguments to pass to the model.
 
     """
     if isinstance(original_model_proto, ir.Model):

From a5e45ad98fcd3e3480601acd6882af5e85697e62 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 1 Jul 2025 07:52:04 -0700
Subject: [PATCH 5/7] Update noxfile.py

---
 noxfile.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/noxfile.py b/noxfile.py
index 31cb10dc55..cee275ef15 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -42,7 +42,7 @@
     "packaging",
     "protobuf",
 )
-ONNX_IR = "onnx_ir==0.1.1"
+ONNX_IR = "onnx_ir==0.1.3"
 ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"

From b23b814cdafbb0ebc518db77d5b7880c1d11c006 Mon Sep 17 00:00:00 2001
From: AyoubMDL
Date: Tue, 1 Jul 2025 17:40:23 +0200
Subject: [PATCH 6/7] review: update comment

---
 onnxscript/rewriter/fuse_relus_clips.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/rewriter/fuse_relus_clips.py b/onnxscript/rewriter/fuse_relus_clips.py
index 2a2999809e..36b2aaeb56 100644
--- a/onnxscript/rewriter/fuse_relus_clips.py
+++ b/onnxscript/rewriter/fuse_relus_clips.py
@@ -88,7 +88,7 @@ def check(self, context, **kwargs):
         del context  # Unused
         check_result = orp.MatchResult()
 
-        # Check if clip min/max are initializers
+        # Check if Clip min/max are not graph inputs and are constant values
         clip_min_max = []
 
         first_clip_node = kwargs.get("out_first_clip").producer()

From 10a43f9fd0abb540c0fad1e1767395b3e3b68cbc Mon Sep 17 00:00:00 2001
From: AyoubMDL
Date: Tue, 1 Jul 2025 19:59:27 +0200
Subject: [PATCH 7/7] fix(FuseSuccessiveClip): fix compute_clip_min_max

---
 onnxscript/rewriter/fuse_relus_clips.py      | 20 +++++++++++---------
 onnxscript/rewriter/fuse_relus_clips_test.py | 15 +++++++++------
 2 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/onnxscript/rewriter/fuse_relus_clips.py b/onnxscript/rewriter/fuse_relus_clips.py
index 36b2aaeb56..ad2fdf28ef 100644
--- a/onnxscript/rewriter/fuse_relus_clips.py
+++ b/onnxscript/rewriter/fuse_relus_clips.py
@@ -124,15 +124,17 @@ def compute_clip_min_max(self, first_clip_node: ir.Node, second_clip_node: ir.No
         min_clip1, max_clip1, dtype = self.extract_min_max(first_clip_node)
         min_clip2, max_clip2, _ = self.extract_min_max(second_clip_node)
 
-        if min_clip1 is not None and min_clip2 is not None:
-            min_clip = ir.tensor(np.array(np.maximum(min_clip1, min_clip2), dtype=dtype))
-        else:
-            min_clip = min_clip1 if min_clip1 is not None else min_clip2
-
-        if max_clip1 is not None and max_clip2 is not None:
-            max_clip = ir.tensor(np.array(np.minimum(max_clip1, max_clip2), dtype=dtype))
-        else:
-            max_clip = max_clip1 if max_clip1 is not None else max_clip2
+        def combine(val1, val2, op):
+            if val1 is not None and val2 is not None:
+                return ir.tensor(np.array(op(val1, val2), dtype=dtype))
+            elif val1 is not None:
+                return ir.tensor(val1)
+            elif val2 is not None:
+                return ir.tensor(val2)
+            return None
+
+        min_clip = combine(min_clip1, min_clip2, np.maximum)
+        max_clip = combine(max_clip1, max_clip2, np.minimum)
 
         return min_clip, max_clip
 
diff --git a/onnxscript/rewriter/fuse_relus_clips_test.py b/onnxscript/rewriter/fuse_relus_clips_test.py
index 22ef26a9f8..cb3c7c4979 100644
--- a/onnxscript/rewriter/fuse_relus_clips_test.py
+++ b/onnxscript/rewriter/fuse_relus_clips_test.py
@@ -274,13 +274,16 @@ def test_successful_fuse_successive_clips(self, _, dtype):
         model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 20] >
             test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y)
-            <{dtype} min1 = {{0}}, {dtype} max1 = {{4}},
-            {dtype} min2 = {{1}}, {dtype} max2 = {{11}},
-            {dtype} min3 = {{3}}, {dtype} max3 = {{7}}>
+            <{dtype} max1 = {{4}}, {dtype} min2 = {{0}},
+            {dtype} max2 = {{11}}, {dtype} min3 = {{1}},
+            {dtype} max3 = {{7}}, {dtype} max4 = {{13}}>
             {{
-                x1 = Clip(X, min1, max1)
-                x2 = Clip(x1, min2, max2)
-                Y = Clip(x2, min3, max3)
+                x1 = Clip(X)
+                x2 = Clip(x1,,max1)
+                x3 = Clip(x2, min2, max2)
+                x4 = Clip(x3, min3, max3)
+                x5 = Clip(x4,,max4)
+                Y = Clip(x5)
             }}
         """)
         self.run_test(model, expected_op_types=["Clip"], dtype=dtype)
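
Note on the fused bounds: stacked clips compose by interval intersection, e.g.
Clip(Clip(x, 0, 4), 1, 11) == Clip(x, max(0, 1), min(4, 11)) == Clip(x, 1, 4),
and Relu behaves as a Clip with an implicit min of 0. A minimal usage sketch of
the rule set added by this series (assuming the patches are applied; the import
path and rule-set name are taken from the diffs above):

    import onnx_ir as ir

    from onnxscript.rewriter import fuse_relus_clips

    model = ir.from_onnx_text("""
        < ir_version: 10, opset_import: ["" : 20] >
        test_model (float[N] X) => (float[N] Y)
        <float max = {6}>
        {
            x1 = Relu(X)
            Y = Clip(x1,, max)
        }
    """)

    # Relu contributes min=0 and the Clip keeps max=6, so the pair fuses
    # into a single Clip with bounds [0, 6].
    count = fuse_relus_clips.fuse_relus_clips_rules().apply_to_model(model)
    assert count == 1
    assert [node.op_type for node in model.graph] == ["Clip"]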