diff --git a/noxfile.py b/noxfile.py
index 31cb10dc55..cee275ef15 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -42,7 +42,7 @@
     "packaging",
     "protobuf",
 )
-ONNX_IR = "onnx_ir==0.1.1"
+ONNX_IR = "onnx_ir==0.1.3"
 ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"
diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py
index 378c5a7c35..97eafc4739 100644
--- a/onnxscript/rewriter/__init__.py
+++ b/onnxscript/rewriter/__init__.py
@@ -20,6 +20,7 @@
     broadcast_to_matmul,
     cast_constant_of_shape,
     collapse_slices,
+    fuse_relus_clips,
     no_op,
     pattern,
     redundant_scatter_nd,
@@ -32,6 +33,7 @@
     *broadcast_to_matmul.rules.rules,
     *cast_constant_of_shape.rules.rules,
     *collapse_slices.rules.rules,
+    *fuse_relus_clips.fuse_relus_clips_rules().rules,
     *basic_rules.basic_optimization_rules().rules,
     *redundant_scatter_nd.rules.rules,
 )
diff --git a/onnxscript/rewriter/fuse_relus_clips.py b/onnxscript/rewriter/fuse_relus_clips.py
new file mode 100644
index 0000000000..ad2fdf28ef
--- /dev/null
+++ b/onnxscript/rewriter/fuse_relus_clips.py
@@ -0,0 +1,190 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Performs the following transformations:
+- Relu(Relu(X)) -> Relu(X)
+- Relu(Clip(X)) -> Clip(X)
+- Clip(Relu(X)) -> Clip(X)
+- Clip(Clip(X)) -> Clip(X)
+"""
+
+import abc
+
+import numpy as np
+import onnx_ir as ir
+
+from onnxscript.rewriter import pattern as orp
+
+
+class FuseSuccessiveRelu(orp.RewriteRuleClassBase):
+    """Replaces ``Relu(Relu(X))`` with ``Relu(X)``."""
+
+    def rewrite(self, op, x):
+        return op.Relu(x)
+
+    def pattern(self, op, x):
+        return op.Relu(op.Relu(x))
+
+
+class _FuseReluClipBase(orp.RewriteRuleClassBase, abc.ABC):
+    def rewrite(self, op, x, **kwargs):
+        first_clip_node = kwargs.get("out_first_clip").producer()
+        second_clip_node = None
+
+        if out_second_clip := kwargs.get("out_second_clip"):
+            second_clip_node = out_second_clip.producer()
+
+        min_clip, max_clip = self.compute_clip_min_max(first_clip_node, second_clip_node)
+        clip_min_max = []
+
+        if min_clip is not None:
+            clip_min_max.append(
+                op.initializer(min_clip, name=f"{first_clip_node.inputs[0].name}_min")
+            )
+
+        if max_clip is not None:
+            # ONNX Clip expects min and max inputs in order.
+            # If min is not provided, we insert None to maintain correct argument positions.
+            if min_clip is None:
+                clip_min_max.append(None)
+
+            clip_min_max.append(
+                op.initializer(max_clip, name=f"{first_clip_node.inputs[0].name}_max")
+            )
+
+        return op.Clip(x, *clip_min_max)
+
+    @abc.abstractmethod
+    def compute_clip_min_max(
+        self, first_clip_node: ir.Node, second_clip_node: ir.Node | None = None
+    ):
+        pass
+
+    def extract_min_max(self, node: ir.Node):
+        # Infer dtype from the node's first input
+        dtype = node.inputs[0].dtype.numpy()
+        min_clip, max_clip = None, None
+
+        if len(node.inputs) > 1:
+            min_input = node.inputs[1]
+            # If only a max is provided, min is implicitly None, so we check for that
+            if min_input is not None:
+                min_clip = min_input.const_value.numpy()
+
+        if len(node.inputs) > 2:
+            max_clip = node.inputs[2].const_value.numpy()
+
+        return min_clip, max_clip, dtype
+
+    def check(self, context, **kwargs):
+        """Condition to check if we need to replace the pattern.
+
+        The pattern is applied only when the min and max inputs of the Clip nodes are
+        not graph inputs and are constant values (i.e., provided by Constant nodes or initializers).
+
+        Returns:
+            MatchResult:
+                Success if we need to replace the pattern, Failure otherwise.
+        """
+        del context  # Unused
+        check_result = orp.MatchResult()
+
+        # Check if Clip min/max are not graph inputs and are constant values
+        clip_min_max = []
+
+        first_clip_node = kwargs.get("out_first_clip").producer()
+        clip_min_max.extend([inp for inp in first_clip_node.inputs[1:] if inp is not None])
+
+        if out_second_clip := kwargs.get("out_second_clip"):
+            second_clip_node = out_second_clip.producer()
+            clip_min_max.extend(
+                [inp for inp in second_clip_node.inputs[1:] if inp is not None]
+            )
+
+        for m in clip_min_max:
+            if m.is_graph_input():
+                return check_result.fail(f"{m.name} is a graph input.")
+
+            if ir.convenience.get_const_tensor(m) is None:
+                return check_result.fail(f"{m.name} is not a constant.")
+
+        return check_result
+
+
+class FuseSuccessiveClip(_FuseReluClipBase):
+    """Replaces ``Clip(Clip(X))`` with ``Clip(X)``."""
+
+    def pattern(self, op, x):
+        return op.Clip(
+            op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]),
+            _allow_other_inputs=True,
+            _outputs=["out_second_clip"],
+        )
+
+    def compute_clip_min_max(self, first_clip_node: ir.Node, second_clip_node: ir.Node):
+        min_clip1, max_clip1, dtype = self.extract_min_max(first_clip_node)
+        min_clip2, max_clip2, _ = self.extract_min_max(second_clip_node)
+
+        def combine(val1, val2, op):
+            if val1 is not None and val2 is not None:
+                return ir.tensor(np.array(op(val1, val2), dtype=dtype))
+            elif val1 is not None:
+                return ir.tensor(val1)
+            elif val2 is not None:
+                return ir.tensor(val2)
+            return None
+
+        min_clip = combine(min_clip1, min_clip2, np.maximum)
+        max_clip = combine(max_clip1, max_clip2, np.minimum)
+
+        return min_clip, max_clip
+
+
+class FuseSuccessiveClipRelu(_FuseReluClipBase):
+    """Replaces ``Clip(Relu(X))`` with ``Clip(X)``."""
+
+    def pattern(self, op, x):
+        return op.Clip(op.Relu(x), _allow_other_inputs=True, _outputs=["out_first_clip"])
+
+    def compute_clip_min_max(self, first_clip_node: ir.Node, _):
+        min_clip, max_clip, dtype = self.extract_min_max(first_clip_node)
+
+        if min_clip is None:
+            # The minimum clipping value is implicitly 0 (Relu clamps at 0)
+            min_clip = 0
+
+        min_clip = ir.tensor(np.array(np.maximum(0.0, min_clip), dtype=dtype))
+
+        if max_clip is not None:
+            max_clip = ir.tensor(max_clip)
+        return min_clip, max_clip
+
+
+class FuseSuccessiveReluClip(FuseSuccessiveClipRelu):
+    """Replaces ``Relu(Clip(X))`` with ``Clip(X)``."""
+
+    def pattern(self, op, x):
+        return op.Relu(op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]))
+
+
+fuse_successive_relu_rule = FuseSuccessiveRelu().rule()
+fuse_successive_clip_rule = FuseSuccessiveClip().rule()
+fuse_successive_clip_relu_rule = FuseSuccessiveClipRelu().rule()
+fuse_successive_relu_clip_rule = FuseSuccessiveReluClip().rule()
+
+
+def fuse_relus_clips_rules() -> orp.RewriteRuleSet:
+    """Returns a set of rewrite rules that fuse successive Relu/Clip nodes.
+
+    Returns:
+        RewriteRuleSet
+    """
+
+    # Order is important
+    return orp.RewriteRuleSet(
+        [
+            fuse_successive_clip_relu_rule,
+            fuse_successive_relu_clip_rule,
+            fuse_successive_relu_rule,
+            fuse_successive_clip_rule,
+        ]
+    )
diff --git a/onnxscript/rewriter/fuse_relus_clips_test.py b/onnxscript/rewriter/fuse_relus_clips_test.py
new file mode 100644
index 0000000000..cb3c7c4979
--- /dev/null
+++ b/onnxscript/rewriter/fuse_relus_clips_test.py
@@ -0,0 +1,366 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import unittest
+
+import numpy as np
+import onnx
+import onnx_ir as ir
+import onnxruntime as ort
+import parameterized
+from onnx_ir.passes.common import onnx_checker, shape_inference
+
+from onnxscript.rewriter import fuse_relus_clips, testing
+from onnxscript.rewriter import pattern as orp
+from onnxscript.rewriter.fuse_relus_clips import (
+    fuse_successive_clip_relu_rule,
+    fuse_successive_clip_rule,
+    fuse_successive_relu_clip_rule,
+)
+
+
+class _FuseReluClipTestBase(unittest.TestCase):
+    @property
+    def rng(self):
+        return np.random.default_rng(20250621)
+
+    def clone_model(self, model: ir.Model) -> ir.Model:
+        return ir.from_proto(ir.to_proto(model))
+
+    def run_test(
+        self,
+        base_model: ir.Model,
+        expected_op_types: list[str],
+        dtype: str = "float",
+    ):
+        onnx_checker.CheckerPass(True)(base_model)
+        base_model = shape_inference.infer_shapes(base_model)
+        updated_model = self.clone_model(base_model)
+        _ = fuse_relus_clips.fuse_relus_clips_rules().apply_to_model(updated_model)
+
+        # Check expected op_types
+        self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types)
+
+        # Check inference
+        inputs = (self.rng.integers(low=-10, high=10, size=(2, 32, 14), dtype=np.int32),)
+        if dtype == "float":
+            inputs = (inputs[0].astype(np.float32),)
+
+        # onnxruntime has an optimization that fuses Clip(Relu), and it doesn't
+        # support int data, which is why we disable ORT optimizations here.
+        # see https://github.com/microsoft/onnxruntime/blob/c98a0e014b641e289ed25f42b792bca1893ccb03/onnxruntime/core/optimizer/relu_clip_fusion.cc#L60
+        testing.assert_numerically_equal(
+            base_model,
+            updated_model,
+            inputs,
+            ort_optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
+        )
+
+        # Validate serialized model
+        output_model_proto = ir.serde.serialize_model(updated_model)
+        onnx.checker.check_model(output_model_proto, full_check=True)
+
+    def run_failed_condition_test(
+        self,
+        base_model: ir.Model,
+        rewrite_rule: orp.RewriteRule,
+        expected_message: str,
+    ):
+        onnx_checker.CheckerPass(True)(base_model)
+
+        updated_model = self.clone_model(base_model)
+        tracer = orp.MatchingTracer()
+        count = rewrite_rule.apply_to_model(updated_model, tracer=tracer)
+
+        # Check that the model is unchanged
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[rewrite_rule][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, expected_message)
+
+
+class FuseSuccessiveReluTest(_FuseReluClipTestBase):
+    def test_successful_fuse_successive_relus(self):
+        model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            {
+                x1 = Relu(X)
+                x2 = Relu(x1)
+                Y = Relu(x2)
+            }
+        """)
+        self.run_test(model, expected_op_types=["Relu"])
+
+
+class FuseSuccessiveReluClipTest(_FuseReluClipTestBase):
+    @parameterized.parameterized.expand(
+        [
+            (
+                "relu_then_clip",
+                """
+                x1 = Relu(X)
+                Y = Clip(x1, min, max)
+                """,
+                "float",
+            ),
+            (
+                "clip_then_relu",
+                """
+                x1 = Clip(X, min, max)
+                Y = Relu(x1)
+                """,
+                "float",
+            ),
+            (
+                "int_relu_then_clip",
+                """
+                x1 = Relu(X)
+                Y = Clip(x1, min, max)
+                """,
+                "int32",
+            ),
+            (
+                "int_clip_then_relu",
+                """
+                x1 = Clip(X, min, max)
+                Y = Relu(x1)
+                """,
+                "int32",
+            ),
+        ]
+    )
+    def test_successful_fuse_successive_relu_clip(self, _, nodes, dtype):
+        model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y)
+            <{dtype} min = {{1}}, {dtype} max = {{6}}>
+            {{
+                {nodes}
+            }}
+        """)
+        self.run_test(model, expected_op_types=["Clip"], dtype=dtype)
+
+    @parameterized.parameterized.expand(
+        [
+            (
+                "relu_then_clip",
+                """
+                x1 = Relu(X)
+                min = Constant()
+                Y = Clip(x1, min)
+                """,
+            ),
+            (
+                "clip_then_relu",
+                """
+                min = Constant()
+                x1 = Clip(X, min)
+                Y = Relu(x1)
+                """,
+            ),
+        ]
+    )
+    def test_successful_fuse_successive_relu_clip_constant_nodes(self, _, nodes):
+        model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float[N, ?, ?] Y)
+            {{
+                {nodes}
+            }}
+        """)
+        self.run_test(model, expected_op_types=["Constant", "Clip"])
+
+    @parameterized.parameterized.expand(
+        [
+            (
+                "relu_then_clip",
+                """
+                x1 = Relu(X)
+                Y = Clip(x1,,max)
+                """,
+            ),
+            (
+                "clip_then_relu",
+                """
+                x1 = Clip(X,,max)
+                Y = Relu(x1)
+                """,
+            ),
+        ]
+    )
+    def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes):
+        model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+
+            {{
+                {nodes}
+            }}
+        """)
+        self.run_test(model, expected_op_types=["Clip"])
+
+    @parameterized.parameterized.expand(
+        [
+            (
+                "relu_then_clip",
+                """
+                x1 = Relu(X)
+                Y = Clip(x1, min)
+                """,
+                fuse_successive_clip_relu_rule,
+            ),
+            (
+                "clip_then_relu",
+                """
+                x1 = Clip(X, min)
+                Y = Relu(x1)
+                """,
+                fuse_successive_relu_clip_rule,
+            ),
+        ]
+    )
+    def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite_rule):
+        model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            {{
+                min = ReduceMean(X)
+                {nodes}
+            }}
+        """)
+        self.run_failed_condition_test(model, rewrite_rule, "is not a constant.")
+
+    @parameterized.parameterized.expand(
+        [
+            (
+                "relu_then_clip",
+                """
+                x1 = Relu(X)
+                Y = Clip(x1, min)
+                """,
+                fuse_successive_clip_relu_rule,
+            ),
+            (
+                "clip_then_relu",
+                """
+                x1 = Clip(X, min)
+                Y = Relu(x1)
+                """,
+                fuse_successive_relu_clip_rule,
+            ),
+        ]
+    )
+    def test_fail_fuse_successive_relu_clip_graph_inputs(self, _, nodes, rewrite_rule):
+        model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X, float min) => (float [N, ?, ?] Y)
+            {{
+                {nodes}
+            }}
+        """)
+        self.run_failed_condition_test(model, rewrite_rule, "is a graph input.")
+
+
+class FuseSuccessiveClipTest(_FuseReluClipTestBase):
+    @parameterized.parameterized.expand(
+        [
+            ("float", "float"),
+            ("int32", "int32"),
+        ]
+    )
+    def test_successful_fuse_successive_clips(self, _, dtype):
+        model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y)
+            <{dtype} max1 = {{4}}, {dtype} min2 = {{0}},
+             {dtype} max2 = {{11}}, {dtype} min3 = {{1}},
+             {dtype} max3 = {{7}}, {dtype} max4 = {{13}}>
+            {{
+                x1 = Clip(X)
+                x2 = Clip(x1,,max1)
+                x3 = Clip(x2, min2, max2)
+                x4 = Clip(x3, min3, max3)
+                x5 = Clip(x4,,max4)
+                Y = Clip(x5)
+            }}
+        """)
+        self.run_test(model, expected_op_types=["Clip"], dtype=dtype)
+
+    def test_successful_fuse_successive_clips_node_constants(self):
+        model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            {
+                min1 = Constant()
+                max1 = Constant()
+                min2 = Constant()
+                max2 = Constant()
+                x1 = Clip(X, min1, max1)
+                Y = Clip(x1, min2, max2)
+            }
+        """)
+        self.run_test(
+            model, expected_op_types=["Constant", "Constant", "Constant", "Constant", "Clip"]
+        )
+
+    def test_successful_fuse_successive_clips_no_min(self):
+        model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+
+            {
+                x1 = Clip(X,, max1)
+                Y = Clip(x1,, max2)
+            }
+        """)
+        self.run_test(model, expected_op_types=["Clip"])
+
+    def test_fail_fuse_successive_clips_non_initializers(self):
+        model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+
+            {
+                min1 = ReduceMean(X)
+                min2 = ReduceMax(X)
+                x1 = Clip(X, min1)
+                Y = Clip(x1, min2)
+            }
+        """)
+        self.run_failed_condition_test(model, fuse_successive_clip_rule, "is not a constant.")
+
+    def test_fail_fuse_successive_clips_graph_inputs(self):
+        model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X, float min1, float min2) => (float [N, ?, ?] Y)
+
+            {
+                x1 = Clip(X, min1)
+                Y = Clip(x1, min2)
+            }
+        """)
+        self.run_failed_condition_test(model, fuse_successive_clip_rule, "is a graph input.")
+
+
+class FuseReluClipIntegrationTest(_FuseReluClipTestBase):
+    def test_successful_full_chain_fusion(self):
+        model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+            {
+                x1 = Relu(X)
+                x2 = Relu(x1)
+                x3 = Relu(x2)
+                x4 = Relu(x3)
+                x5 = Clip(x4)
+                x6 = Relu(x5)
+                Y = Clip(x6)
+            }
+        """)
+        self.run_test(model, expected_op_types=["Clip"])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/rewriter/testing.py b/onnxscript/rewriter/testing.py
index 7c8c5175ee..89cceb1c1d 100644
--- a/onnxscript/rewriter/testing.py
+++ b/onnxscript/rewriter/testing.py
@@ -15,6 +15,7 @@ def assert_numerically_equal(
     original_model_proto: onnx.ModelProto | ir.Model,
     rewritten_model_proto: onnx.ModelProto | ir.Model,
     args: tuple[Any, ...],
+    ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
     rtol: float = 1,
     atol: float = 1e-3,
 ):
@@ -23,9 +24,10 @@
     Args:
         original_model_proto: The original model proto or ir.Model.
         rewritten_model_proto: The rewritten by the rules model proto or ir.Model.
+        args: The positional arguments to pass to the model.
+        ort_optimization_level: The ONNX Runtime graph optimization level to use.
         rtol: Relative tolerance.
         atol: Absolute tolerance.
-        args: The positional arguments to pass to the model.
     """
 
     if isinstance(original_model_proto, ir.Model):
@@ -37,7 +39,7 @@
         k.name: v for k, v in zip(original_model_proto.graph.input, args)
     }
     original_proto_ort_inference_session = _ort_session_initializer(
-        original_model_proto.SerializeToString()
+        original_model_proto.SerializeToString(), ort_optimization_level
     )
     run_options = ort.RunOptions()
     run_options.log_severity_level = 3  # 3: Error
@@ -49,7 +51,7 @@
         k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
     }
     the_rewritten_proto_ort_inference_session = _ort_session_initializer(
-        rewritten_model_proto.SerializeToString()
+        rewritten_model_proto.SerializeToString(), ort_optimization_level
     )
     the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run(
         None, the_rewritten_proto_ort_inputs, run_options=run_options
     )
@@ -60,12 +62,15 @@
     )
 
 
-def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession:
+def _ort_session_initializer(
+    model: str | bytes, ort_optimization_level: ort.GraphOptimizationLevel
+) -> ort.InferenceSession:
     """Initialize an ONNX Runtime inference session with the specified model."""
     import onnxruntime as ort
 
     session_options = ort.SessionOptions()
     session_options.log_severity_level = 3  # 3: Error
+    session_options.graph_optimization_level = ort_optimization_level
     possible_providers = (
         "CUDAExecutionProvider",
         "CPUExecutionProvider",
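
Usage sketch (illustrative, not part of the patch): the snippet below applies the new rule set on its own, using only calls the diff itself exercises (`ir.from_onnx_text`, `fuse_relus_clips_rules()`, `apply_to_model`, iterating `model.graph`); the toy Relu-then-Clip model and the expected values in the comments are assumptions for demonstration.

```python
# Minimal sketch: fuse Clip(Relu(X)) into a single Clip with the new rule set.
# The model text mirrors the fixtures in fuse_relus_clips_test.py.
import onnx_ir as ir

from onnxscript.rewriter import fuse_relus_clips

model = ir.from_onnx_text("""
    < ir_version: 10, opset_import: ["" : 20] >
    test_model (float[N, 32, 14] X) => (float[N, ?, ?] Y)
    <float min = {1}, float max = {6}>
    {
        x1 = Relu(X)
        Y = Clip(x1, min, max)
    }
""")

count = fuse_relus_clips.fuse_relus_clips_rules().apply_to_model(model)
print(count)                                   # expected: 1 rewrite applied
print([node.op_type for node in model.graph])  # expected: ["Clip"]
```

Because the `onnxscript/rewriter/__init__.py` hunk also registers these rules in the default rule list, the same fusions should run as part of the standard rewriter pass without importing `fuse_relus_clips` directly.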