diff --git a/onnxscript/_legacy_ir/__init__.py b/onnxscript/_legacy_ir/__init__.py
index 1bd96961d9..cc1b6af17f 100644
--- a/onnxscript/_legacy_ir/__init__.py
+++ b/onnxscript/_legacy_ir/__init__.py
@@ -319,6 +319,10 @@ def output_names(self):
     def attribute(self):
         return self.original_node_proto.attribute
 
+    def set_version_if_custom_op(self, version_map: dict[str, int]) -> None:
+        if self.domain != "" and self.domain in version_map:
+            self.version = version_map[self.domain]
+
     def get_attribute(self, name: str) -> int | float | None:
         return self.attributes.get(name, None)
 
diff --git a/onnxscript/_legacy_ir/irbuilder.py b/onnxscript/_legacy_ir/irbuilder.py
index 1bb265e23b..5bee6083bc 100644
--- a/onnxscript/_legacy_ir/irbuilder.py
+++ b/onnxscript/_legacy_ir/irbuilder.py
@@ -35,12 +35,12 @@ def visit_model(self, model_proto: onnx.ModelProto) -> ir.Model:
         self._function_shape_env = visitor.FunctionShapeEnv()
         self._function_shape_env.load_from_model_proto(model_proto)
         self._ir_version = model_proto.ir_version
-        version_map = {x.domain: x.version for x in model_proto.opset_import}
+        self.version_map = {x.domain: x.version for x in model_proto.opset_import}
         functions = [self.visit_function(function) for function in model_proto.functions]
         self.functions = {function.id: function for function in functions}
         graph = self.visit_graph(model_proto.graph)
         model = ir.Model()
-        model.set(model_proto, graph, functions, version_map)
+        model.set(model_proto, graph, functions, self.version_map)
         return model
 
     def visit_graph(self, graph: onnx.GraphProto) -> ir.Graph:
@@ -122,6 +122,7 @@ def process_initializer(self, init: onnx.TensorProto):
 
     def process_node(self, node):
         node_ir = ir.Node(node)
+        node_ir.set_version_if_custom_op(self.version_map)
         self.current_graph_or_function.nodes.append(node_ir)
         for name in node.input:
             value = self.lookup(name)
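The hunks above thread the model's opset_import map through the IR builder so that nodes from custom domains carry an explicit opset version. A minimal, self-contained sketch of that stamping logic; the Node class here is a stand-in for the legacy-IR node, not the real class:

```python
# Sketch only: `Node` stands in for the legacy-IR node class.
from __future__ import annotations


class Node:
    def __init__(self, domain: str) -> None:
        self.domain = domain
        self.version: int | None = None

    def set_version_if_custom_op(self, version_map: dict[str, int]) -> None:
        # Only nodes from a non-default (custom) domain are stamped; ops in the
        # default "" domain keep the model-level opset handling.
        if self.domain != "" and self.domain in version_map:
            self.version = version_map[self.domain]


# Mirrors {x.domain: x.version for x in model_proto.opset_import}.
version_map = {"": 17, "com.microsoft": 1, "pkg.custom": 1}

custom_node = Node("pkg.custom")
custom_node.set_version_if_custom_op(version_map)
assert custom_node.version == 1

default_node = Node("")
default_node.set_version_if_custom_op(version_map)
assert default_node.version is None
```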
diff --git a/onnxscript/_legacy_ir/protobuilder.py b/onnxscript/_legacy_ir/protobuilder.py
index 31d7eb8e21..bdaad92dee 100644
--- a/onnxscript/_legacy_ir/protobuilder.py
+++ b/onnxscript/_legacy_ir/protobuilder.py
@@ -70,6 +70,7 @@ def visit_ir_function(
         # function_proto.metadata_props = ir_function.original_function_proto.metadata_props)
 
         for node in ir_function.nodes:
+            # TODO: Deduplicate the opset imports of the function?
            operator_setid_proto = function_proto.opset_import.add()
             if node.domain in self.opset_imports:
                 operator_setid_proto.domain = self.opset_imports[node.domain].domain
diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/broadcast_to_matmul_test.py
index 73cb59e635..b462c46c1d 100644
--- a/onnxscript/rewriter/broadcast_to_matmul_test.py
+++ b/onnxscript/rewriter/broadcast_to_matmul_test.py
@@ -36,7 +36,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nest
             agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
             {
-                output = afunction (input_x, input_y)
+                output = pkg.custom.afunction (input_x, input_y)
             }
             afunction (input_x, input_y) => (output)
diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py
index 5a99d03749..0e6eb613a7 100644
--- a/onnxscript/rewriter/onnxruntime/__init__.py
+++ b/onnxscript/rewriter/onnxruntime/__init__.py
@@ -3,9 +3,10 @@
 import onnx
 
 from onnxscript._legacy_ir import irbuilder, protobuilder
-from onnxscript.optimizer import remove_unused
+from onnxscript.optimizer import remove_unused, remove_unused_function
 from onnxscript.rewriter import function_rule, pattern
 from onnxscript.rewriter.onnxruntime import (
+    group_normalization_merge_silu,
     instance_to_group_normalization,
     softmax,
     transformers,
@@ -16,6 +17,8 @@
 ORT_PATTERN_REWRITE_RULES = [
     *softmax.rules.rules,
     *instance_to_group_normalization.rules.rules,
+    # NOTE: group_normalization_merge_silu must be applied after instance_to_group_normalization.
+    *group_normalization_merge_silu.rules.rules,
 ]
 
 
@@ -49,5 +52,8 @@ def rewrite(
     count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model_ir)
     print(f"Applied {count} pattern rewrite rules.")
     model = protobuilder.build_model_proto(model_ir)
+    # TODO: Would it make more sense to run DCE after each rewrite rule is applied?
+    # If so, the IR needs to support DCE.
     remove_unused.remove_unused_nodes(model)
+    remove_unused_function.remove_unused_functions(model)
     return model
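With the new import, rewrite() prunes unused functions right after unused nodes: a function only becomes dead once its last call site has been removed, so the function-level pass has to run second. A hedged usage sketch, assuming rewrite() accepts an onnx.ModelProto and returns the rewritten proto (the file names are placeholders, and its optional keyword arguments are not shown in this patch):

```python
# Illustrative only; "model.onnx" is a placeholder path.
import onnx

from onnxscript.rewriter import onnxruntime as ort_rewriter

model_proto = onnx.load("model.onnx")
# Applies the ORT function and pattern rewrite rules, then removes unused
# nodes and unused functions before returning the new ModelProto.
rewritten = ort_rewriter.rewrite(model_proto)
onnx.save(rewritten, "model.rewritten.onnx")
```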
diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py
new file mode 100644
index 0000000000..a6dfb54eb5
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import logging
+
+from onnxscript.rewriter import pattern
+
+op = pattern.onnxop
+msft_op = pattern.msft_op
+torch_module_op = pattern.torch_module_op
+
+logger = logging.getLogger(__name__)
+
+
+def group_normalization_and_silu_submodule(
+    input,
+    weight,
+    bias,
+    epsilon,
+    groups,
+):
+    group_norm = msft_op.GroupNorm(
+        input,
+        weight,
+        bias,
+        activation=0,
+        channels_last=1,
+        epsilon=epsilon,
+        groups=groups,
+    )
+    transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2])
+    return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(transposed)
+
+
+def group_normalization_with_silu(
+    input,
+    weight,
+    bias,
+    epsilon,
+    groups,
+):
+    group_norm = msft_op.GroupNorm(
+        input,
+        weight,
+        bias,
+        activation=1,
+        channels_last=1,
+        epsilon=epsilon,
+        groups=groups,
+    )
+    return op.Transpose(group_norm, perm=[0, 3, 1, 2])
+
+
+group_normalization_merge_silu_submodule_rule = pattern.RewriteRule(
+    group_normalization_and_silu_submodule,
+    group_normalization_with_silu,
+)
+
+rules = pattern.RewriteRuleSet([group_normalization_merge_silu_submodule_rule])
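The replacement above is sound because SiLU(x) = x * sigmoid(x) is elementwise, so it commutes with the NHWC-to-NCHW Transpose, and com.microsoft.GroupNorm with activation=1 applies that same Swish/SiLU activation to its normalized output. A NumPy sketch of the equivalence, using a reference NHWC group-norm rather than the ORT kernel:

```python
# Reference sketch only; this is not onnxruntime's GroupNorm implementation.
import numpy as np


def group_norm_nhwc(x, weight, bias, groups, epsilon, activation):
    n, h, w, c = x.shape
    grouped = x.reshape(n, h, w, groups, c // groups)
    mean = grouped.mean(axis=(1, 2, 4), keepdims=True)
    var = grouped.var(axis=(1, 2, 4), keepdims=True)
    normed = ((grouped - mean) / np.sqrt(var + epsilon)).reshape(n, h, w, c)
    out = normed * weight + bias
    if activation:  # activation=1 -> fused Swish/SiLU
        out = out * (1.0 / (1.0 + np.exp(-out)))
    return out


x = np.random.rand(1, 8, 8, 32).astype(np.float32)
weight = np.random.rand(32).astype(np.float32)
bias = np.random.rand(32).astype(np.float32)

plain = group_norm_nhwc(x, weight, bias, groups=4, epsilon=1e-5, activation=0)
silu_after = plain * (1.0 / (1.0 + np.exp(-plain)))
fused = group_norm_nhwc(x, weight, bias, groups=4, epsilon=1e-5, activation=1)
np.testing.assert_allclose(silu_after, fused, rtol=1e-6)
```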
diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py
new file mode 100644
index 0000000000..254e526d4c
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py
@@ -0,0 +1,125 @@
+import unittest
+
+import numpy as np
+import onnx.parser
+
+from onnxscript._legacy_ir import irbuilder
+from onnxscript.rewriter.onnxruntime import (
+    group_normalization_merge_silu,
+    instance_to_group_normalization,
+)
+
+
+class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase):
+    def test_group_norm_with_silu_submodule_is_replaced_by_group_norm(self):
+        model = onnx.parser.parse_model(
+            """
+            agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output)
+            {
+                group_norm = com.microsoft.GroupNorm (image, weight, bias)
+                transposed = Transpose (group_norm)
+                output = pkg.torch230a0git77ef9d4.torch_nn_modules_activation_SiLU_time_embedding_act_19 (transposed)
+            }
+
+            torch_nn_modules_activation_SiLU_time_embedding_act_19 (transposed) => (output)
+            {
+                _to_copy_38 = Cast (transposed)
+                sigmoid_18 = Sigmoid (_to_copy_38)
+                mul_26 = Mul (_to_copy_38, sigmoid_18)
+                output = Cast (mul_26)
+            }
+            """
+        )
+        # Use inserted initializers to avoid manually coding the large constants
+        weight_value = np.random.rand(320, 1, 1).astype(np.float16)
+        bias_value = np.random.rand(320, 1, 1).astype(np.float16)
+        model.graph.initializer.extend(
+            [
+                onnx.helper.make_tensor(
+                    "weight",
+                    onnx.TensorProto.FLOAT16,
+                    weight_value.shape,
+                    weight_value,
+                ),
+                onnx.helper.make_tensor(
+                    "bias",
+                    onnx.TensorProto.FLOAT16,
+                    bias_value.shape,
+                    bias_value,
+                ),
+            ]
+        )
+
+        ir = irbuilder.build_ir(model)
+        count = group_normalization_merge_silu.rules.apply_to_model(ir)
+        self.assertEqual(count, 1)
+        # plus 2 in model constants
+        self.assertEqual(len(ir.graph.nodes), 2)
+
+    def test_simulated_instance_norm_is_replaced_by_group_norm_silu(self):
+        model = onnx.parser.parse_model(
+            """
+            agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output)
+            {
+                adjusted_input_shape = Constant()
+                image_reshape = Reshape (image, adjusted_input_shape)
+                instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm)
+                original_input_shape = Constant()
+                instance_norm_reshape = Reshape (instance_norm, original_input_shape)
+                mul_output = Mul (instance_norm_reshape, weight_full)
+                add_output = Add (mul_output, bias_full)
+                output = pkg.torch230a0git77ef9d4.torch_nn_modules_activation_SiLU_time_embedding_act_19 (add_output)
+            }
+
+            torch_nn_modules_activation_SiLU_time_embedding_act_19 (add_output) => (output)
+            {
+                _to_copy_38 = Cast (add_output)
+                sigmoid_18 = Sigmoid (_to_copy_38)
+                mul_26 = Mul (_to_copy_38, sigmoid_18)
+                output = Cast (mul_26)
+            }
+            """
+        )
+        # Use inserted initializers to avoid manually coding the large constants
+        weight_full_value = np.random.rand(320, 1, 1).astype(np.float16)
+        bias_full_value = np.random.rand(320, 1, 1).astype(np.float16)
+        weight_for_norm_value = np.ones(32, dtype=np.float16)
+        bias_for_norm_value = np.zeros(32, dtype=np.float16)
+
+        model.graph.initializer.extend(
+            [
+                onnx.helper.make_tensor(
+                    "weight_for_norm",
+                    onnx.TensorProto.FLOAT16,
+                    weight_for_norm_value.shape,
+                    weight_for_norm_value,
+                ),
+                onnx.helper.make_tensor(
+                    "bias_for_norm",
+                    onnx.TensorProto.FLOAT16,
+                    bias_for_norm_value.shape,
+                    bias_for_norm_value,
+                ),
+                onnx.helper.make_tensor(
+                    "weight_full",
+                    onnx.TensorProto.FLOAT16,
+                    weight_full_value.shape,
+                    weight_full_value,
+                ),
+                onnx.helper.make_tensor(
+                    "bias_full",
+                    onnx.TensorProto.FLOAT16,
+                    bias_full_value.shape,
+                    bias_full_value,
+                ),
+            ]
+        )
+
+        ir = irbuilder.build_ir(model)
+        count = instance_to_group_normalization.rules.apply_to_model(ir)
+        count += group_normalization_merge_silu.rules.apply_to_model(ir)
+        self.assertEqual(count, 2)
+        # plus 2 in model constants
+        self.assertEqual(len(ir.graph.nodes), 10)
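The second test exercises the intended chaining: instance_to_group_normalization rebuilds GroupNorm out of the simulated InstanceNormalization subgraph, and group_normalization_merge_silu then folds the trailing SiLU submodule into it. Roughly the same pipeline outside a test, reusing the entry points that appear elsewhere in this patch (the model path is a placeholder):

```python
# Placeholder path; the rest uses APIs shown in this patch
# (irbuilder.build_ir, rules.apply_to_model, protobuilder.build_model_proto).
import onnx

from onnxscript._legacy_ir import irbuilder, protobuilder
from onnxscript.rewriter.onnxruntime import (
    group_normalization_merge_silu,
    instance_to_group_normalization,
)

model_proto = onnx.load("unet.onnx")
model_ir = irbuilder.build_ir(model_proto)
# Order matters: recover GroupNorm from the InstanceNormalization pattern first,
# then merge the SiLU submodule into the recovered GroupNorm.
count = instance_to_group_normalization.rules.apply_to_model(model_ir)
count += group_normalization_merge_silu.rules.apply_to_model(model_ir)
rewritten = protobuilder.build_model_proto(model_ir)
print(f"Applied {count} rewrites")
```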
diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
index a85abe1053..0f6e766858 100644
--- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
+++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
@@ -11,6 +11,7 @@
 
 op = pattern.onnxop
 msft_op = pattern.msft_op
+torch_module_op = pattern.torch_module_op
 
 logger = logging.getLogger(__name__)
 
@@ -146,4 +147,6 @@ def group_normalization(
     check_if_simulated_instance_norm_is_used,
 )
 
+# NOTE: instance_norm_to_group_norm_rule is a subset of instance_norm_to_group_norm_with_silu_rule,
+# so instance_norm_to_group_norm_with_silu_rule must be applied first.
 rules = pattern.RewriteRuleSet([instance_norm_to_group_norm_rule])
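Exporter-generated domains (pkg.torch230a0git77ef9d4) and submodule op names (torch_nn_modules_activation_SiLU_time_embedding_act_19) embed version hashes and per-instance suffixes, so exact-name matching would be brittle. The pattern.py changes below therefore add prefix matching; a minimal standalone sketch of that behavior, mirroring the PrefixPattern class introduced below:

```python
# Standalone sketch mirroring the PrefixPattern added to pattern.py below.
class PrefixPattern:
    def __init__(self, value: str) -> None:
        self._value = value

    def matches(self, value: str) -> bool:
        return value.startswith(self._value)


domain = PrefixPattern("pkg.torch")
op_name = PrefixPattern("torch_nn_modules_activation_SiLU")

assert domain.matches("pkg.torch230a0git77ef9d4")
assert op_name.matches("torch_nn_modules_activation_SiLU_time_embedding_act_19")
assert not domain.matches("com.microsoft")
```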
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 1b26e82947..e89057d473 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -38,6 +38,23 @@ def to_ir(self, model, bindings=None) -> int | str | list:
         return self.value
 
 
+class PrefixPattern:
+    """Matches a string by prefix; used for submodule opset and op-name patterns."""
+
+    def __init__(self, value: str) -> None:
+        self._value = value
+
+    @property
+    def value(self) -> str:
+        return self._value
+
+    def matches(self, value: str) -> bool:
+        return value.startswith(self.value)
+
+    def to_ir(self, model, bindings=None) -> str:
+        raise NotImplementedError("PrefixPattern should not be converted to IR")
+
+
 class FloatConstantPattern:
     def __init__(self, value: float, rel_tol: float = 1e-5, abs_tol: float = 1e-8) -> None:
         self._value = value
@@ -152,20 +169,24 @@ class OpsetPattern:
 
     def __init__(
         self,
-        domain_pattern: ConstantPattern,
+        domain_pattern: ConstantPattern | PrefixPattern,
         version_pattern: ConstantPattern | AnyPattern,
     ) -> None:
         self.domain_pattern = domain_pattern
         self.version_pattern = version_pattern
 
     @classmethod
-    def singleton(cls, domain: str, version: int):
+    def singleton(cls, domain: str, version: int) -> OpsetPattern:
         return cls(ConstantPattern(domain), ConstantPattern(version))
 
     @classmethod
     def domain(cls, domain: str) -> OpsetPattern:
         return cls(ConstantPattern(domain), AnyPattern())
 
+    @classmethod
+    def domain_prefix(cls, domain: str) -> OpsetPattern:
+        return cls(PrefixPattern(domain), AnyPattern())
+
     def matches(self, opset):
         domain, version = opset
         return self.domain_pattern.matches(domain) and self.version_pattern.matches(version)
@@ -180,6 +201,10 @@ def to_ir(self, model, bindings=None) -> str:
     def __getattr__(self, name: str) -> Any:
         return OpPattern(self, ConstantPattern(name))
 
+    def submodule(self, name: str) -> Any:
+        """Matches any op whose name starts with the given submodule prefix."""
+        return OpPattern(self, PrefixPattern(name))
+
 
 opset17 = OpsetPattern.singleton("", 17)
 
@@ -187,6 +212,8 @@ def __getattr__(self, name: str) -> Any:
 
 msft_op = OpsetPattern.singleton("com.microsoft", 1)
 
+torch_module_op = OpsetPattern.domain_prefix("pkg.torch")
+
 
 class OpPattern:
     """A utility class to build a NodePattern.
@@ -202,7 +229,11 @@ class OpPattern:
     """
 
-    def __init__(self, opset_pattern: OpsetPattern, op_name_pattern: ConstantPattern) -> None:
+    def __init__(
+        self,
+        opset_pattern: OpsetPattern,
+        op_name_pattern: ConstantPattern | PrefixPattern,
+    ) -> None:
         self.opset_pattern = opset_pattern
         self.op_name_pattern = op_name_pattern
 
@@ -366,7 +397,8 @@ def _make_node(
         outputs = [model.make_new_name() for i in range(num_outputs)]
         node = onnx.helper.make_node(op, inputnames, outputs, domain=domain, **attributes)
         newnode = ir.Node(node)
-        newvalues = [ir.Value(v, newnode, i) for i, v in enumerate(outputs)]
+        newnode.set_version_if_custom_op(model.version_map)
+        newvalues = [ir.Value(name=v, node=newnode, output_index=i) for i, v in enumerate(outputs)]
         newnode.inputs = input
         newnode.outputs = newvalues
         newnode.attributes = attributes  # TODO
@@ -902,9 +934,13 @@ def _apply_deltas(
                 v.node = last_inserted
                 last_inserted.outputs.append(v)
 
-            del nodes[i]  # that assumes unused nodes are removed (see below)
-            for item in reversed(inserted_nodes):
-                nodes.insert(i, item)
+            del nodes[i]
+
+            for new_node in reversed(inserted_nodes):
+                nodes.insert(i, new_node)
+                # Bind the new node's outputs into the graph's value map.
+                for output_name, value in zip(new_node.output_names, new_node.outputs):
+                    graph_or_function.values[output_name] = value
 
             path_2 = True
 
         assert not to_delete or not path_2, (
@@ -919,13 +955,13 @@ def _apply_deltas(
             inserted_input_output = []
             for nd in inserted_nodes:
                 inserted_input_output += nd.inputs + nd.outputs
-            for n in deleted_nodes[0:-1]:
+            for old_node in deleted_nodes[0:-1]:
                 # Delete intermediary outputs from graph that are not used as
                 # outputs of the graph
-                for output in n.outputs:
+                for output in old_node.outputs:
                     if not output.is_output and output not in inserted_input_output:
                         graph_or_function.values.pop(output.name)
-                nodes.remove(n)
+                nodes.remove(old_node)
 
         for i in to_delete:
             position = existing_ids[i][0]
@@ -956,11 +992,13 @@ def _apply_to_graph_or_function(
         graph_or_function: ir.Graph | ir.Function,
     ) -> int:
         count = 0
-        deltas = []
         marked = set()
         bridge = None
-        for i, node in enumerate(graph_or_function.nodes):
-            for rule in self.rules:
+        # NOTE: Rules are applied in the order they were added to the RewriteRuleSet;
+        # each rule sweeps the whole graph before the next rule runs.
+        for rule in self.rules:
+            deltas = []
+            for i, node in enumerate(graph_or_function.nodes):
                 if hasattr(rule, "pattern"):
                     from onnxscript.rewriter.generic_pattern import (
                         GenericRewriteRule,
@@ -998,9 +1036,8 @@ def _apply_to_graph_or_function(
                     deltas.append((i, delta))
                     count += 1
-                    break
 
-        _apply_deltas(graph_or_function, deltas)
+            _apply_deltas(graph_or_function, deltas)
         return count
 
     def apply_to_model(self, model: ir.Model) -> int:
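The _apply_to_graph_or_function change above alters traversal order: the old loop visited each node and stopped at the first rule that matched it, while the new loop applies one rule across the whole graph (and applies its deltas) before moving on to the next rule. Earlier rules in a RewriteRuleSet therefore take priority, and later rules see the already-rewritten graph, which is exactly what the SiLU-merge rule relies on. A toy sketch of the new order, with plain callables standing in for RewriteRules:

```python
# Toy illustration of rule-major application; not the real RewriteRuleSet code.
def apply_rule_major(nodes, rules):
    for rule in rules:  # one rule at a time, in priority order
        deltas = []
        for i, node in enumerate(nodes):
            replacement = rule(node)
            if replacement is not None:
                deltas.append((i, replacement))
        for i, replacement in deltas:  # apply this rule's deltas before the next rule
            nodes[i] = replacement
    return nodes


def to_group_norm(node):
    # Stand-in for instance_to_group_normalization.
    return "group_norm" if node == "instance_norm" else None


def merge_silu(node):
    # Stand-in for group_normalization_merge_silu.
    return "group_norm_silu" if node == "group_norm" else None


print(apply_rule_major(["instance_norm", "relu"], [to_group_norm, merge_silu]))
# ['group_norm_silu', 'relu'] -- merge_silu sees to_group_norm's output in its own pass
```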