[Migration][DO NOT MERGE] Support submodule functions in pattern-rewrite #1345

Closed · wants to merge 1 commit
4 changes: 4 additions & 0 deletions onnxscript/_legacy_ir/__init__.py
@@ -319,6 +319,10 @@ def output_names(self):
    def attribute(self):
        return self.original_node_proto.attribute

    def set_version_if_custom_op(self, version_map: dict[str, int]) -> None:
        if self.domain != "" and self.domain in version_map:
            self.version = version_map[self.domain]

    def get_attribute(self, name: str) -> int | float | None:
        return self.attributes.get(name, None)

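For context, a minimal self-contained sketch (not part of the diff) of what the new set_version_if_custom_op hook does; _FakeNode is a hypothetical stand-in for the IR Node class, and the version map mirrors the model's opset_import entries:

from __future__ import annotations

# Stand-in for onnxscript._legacy_ir.Node, only to illustrate the new method.
class _FakeNode:
    def __init__(self, domain: str) -> None:
        self.domain = domain
        self.version = None

    def set_version_if_custom_op(self, version_map: dict[str, int]) -> None:
        # Nodes in the default ONNX domain ("") are left untouched; only
        # custom-domain ops pick up a version from the model's opset_import.
        if self.domain != "" and self.domain in version_map:
            self.version = version_map[self.domain]

version_map = {"": 17, "pkg.custom": 1, "com.microsoft": 1}
custom = _FakeNode("pkg.custom")
custom.set_version_if_custom_op(version_map)
assert custom.version == 1      # resolved from the custom domain's opset_import
default = _FakeNode("")
default.set_version_if_custom_op(version_map)
assert default.version is None  # default domain is skipped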
5 changes: 3 additions & 2 deletions onnxscript/_legacy_ir/irbuilder.py
@@ -35,12 +35,12 @@ def visit_model(self, model_proto: onnx.ModelProto) -> ir.Model:
        self._function_shape_env = visitor.FunctionShapeEnv()
        self._function_shape_env.load_from_model_proto(model_proto)
        self._ir_version = model_proto.ir_version
        version_map = {x.domain: x.version for x in model_proto.opset_import}
        self.version_map = {x.domain: x.version for x in model_proto.opset_import}
        functions = [self.visit_function(function) for function in model_proto.functions]
        self.functions = {function.id: function for function in functions}
        graph = self.visit_graph(model_proto.graph)
        model = ir.Model()
        model.set(model_proto, graph, functions, version_map)
        model.set(model_proto, graph, functions, self.version_map)
        return model

    def visit_graph(self, graph: onnx.GraphProto) -> ir.Graph:
@@ -122,6 +122,7 @@ def process_initializer(self, init: onnx.TensorProto):

    def process_node(self, node):
        node_ir = ir.Node(node)
        node_ir.set_version_if_custom_op(self.version_map)
        self.current_graph_or_function.nodes.append(node_ir)
        for name in node.input:
            value = self.lookup(name)
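As a point of reference (not part of the diff), a tiny runnable illustration of what self.version_map ends up holding; the in-memory ModelProto here is a made-up example:

import onnx

model_proto = onnx.ModelProto()
model_proto.opset_import.add(domain="", version=17)
model_proto.opset_import.add(domain="pkg.custom", version=1)

# Same comprehension as in visit_model above.
version_map = {x.domain: x.version for x in model_proto.opset_import}
assert version_map == {"": 17, "pkg.custom": 1}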
1 change: 1 addition & 0 deletions onnxscript/_legacy_ir/protobuilder.py
@@ -70,6 +70,7 @@ def visit_ir_function(
        # function_proto.metadata_props = ir_function.original_function_proto.metadata_props)

        for node in ir_function.nodes:
            # TODO: deduplicate the function's opset_import entries?
            operator_setid_proto = function_proto.opset_import.add()
            if node.domain in self.opset_imports:
                operator_setid_proto.domain = self.opset_imports[node.domain].domain
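A rough sketch of the deduplication the TODO above hints at; the helper name, the version assignment, and the handling of domains not in opset_imports are assumptions, only the branch visible in this hunk is mirrored:

def _add_function_opset_imports(function_proto, nodes, opset_imports) -> None:
    # Emit one opset_import entry per distinct domain used by the function's
    # nodes, instead of one entry per node as the loop above currently does.
    seen_domains = set()
    for node in nodes:
        if node.domain in seen_domains:
            continue  # this domain already has an entry
        seen_domains.add(node.domain)
        operator_setid_proto = function_proto.opset_import.add()
        if node.domain in opset_imports:
            operator_setid_proto.domain = opset_imports[node.domain].domain
            operator_setid_proto.version = opset_imports[node.domain].version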
2 changes: 1 addition & 1 deletion onnxscript/rewriter/broadcast_to_matmul_test.py
@@ -36,7 +36,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nest
<ir_version: 7, opset_import: [ "" : 17, "pkg.custom": 1]>
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
{
output = afunction (input_x, input_y)
output = pkg.custom.afunction (input_x, input_y)
}
<domain: "pkg.custom", opset_import: [ "" : 17]>
afunction (input_x, input_y) => (output)
8 changes: 7 additions & 1 deletion onnxscript/rewriter/onnxruntime/__init__.py
@@ -3,9 +3,10 @@
import onnx

from onnxscript._legacy_ir import irbuilder, protobuilder
from onnxscript.optimizer import remove_unused
from onnxscript.optimizer import remove_unused, remove_unused_function
from onnxscript.rewriter import function_rule, pattern
from onnxscript.rewriter.onnxruntime import (
    group_normalization_merge_silu,
    instance_to_group_normalization,
    softmax,
    transformers,
@@ -16,6 +17,8 @@
ORT_PATTERN_REWRITE_RULES = [
    *softmax.rules.rules,
    *instance_to_group_normalization.rules.rules,
    # NOTE: group_normalization_merge_silu rules should be applied after the
    # instance_to_group_normalization rules, since they match the GroupNorm nodes those rules produce.
    *group_normalization_merge_silu.rules.rules,
]


@@ -49,5 +52,8 @@ def rewrite(
    count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model_ir)
    print(f"Applied {count} pattern rewrite rules.")
    model = protobuilder.build_model_proto(model_ir)
    # TODO: Does it make more sense to run DCE after each rewrite rule is applied?
    # If so, the IR needs to support DCE.
    remove_unused.remove_unused_nodes(model)
    remove_unused_function.remove_unused_functions(model)
    return model
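The ordering noted above matters because the SiLU-merging pattern matches com.microsoft.GroupNorm nodes, which only exist once the instance-norm rules have run. A minimal sketch of that two-stage application, mirroring the new test below and using only calls that appear in this PR; the wrapper function name is hypothetical:

from onnxscript._legacy_ir import irbuilder
from onnxscript.rewriter.onnxruntime import (
    group_normalization_merge_silu,
    instance_to_group_normalization,
)

def _rewrite_group_norm_then_silu(model_proto):
    model_ir = irbuilder.build_ir(model_proto)
    # Stage 1: fold the simulated instance-norm subgraph into GroupNorm.
    count = instance_to_group_normalization.rules.apply_to_model(model_ir)
    # Stage 2: merge a trailing SiLU submodule into GroupNorm's activation attribute.
    count += group_normalization_merge_silu.rules.apply_to_model(model_ir)
    return model_ir, count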
58 changes: 58 additions & 0 deletions onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py
@@ -0,0 +1,58 @@
from __future__ import annotations

import logging

from onnxscript.rewriter import pattern

op = pattern.onnxop
msft_op = pattern.msft_op
torch_module_op = pattern.torch_module_op

logger = logging.getLogger(__name__)


def group_normalization_and_silu_submodule(
    input,
    weight,
    bias,
    epsilon,
    groups,
):
    group_norm = msft_op.GroupNorm(
        input,
        weight,
        bias,
        activation=0,
        channels_last=1,
        epsilon=epsilon,
        groups=groups,
    )
    transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2])
    return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(transposed)


def group_normalization_with_silu(
    input,
    weight,
    bias,
    epsilon,
    groups,
):
    group_norm = msft_op.GroupNorm(
        input,
        weight,
        bias,
        activation=1,
        channels_last=1,
        epsilon=epsilon,
        groups=groups,
    )
    return op.Transpose(group_norm, perm=[0, 3, 1, 2])


group_normalization_merge_silu_submodule_rule = pattern.RewriteRule(
    group_normalization_and_silu_submodule,
    group_normalization_with_silu,
)

rules = pattern.RewriteRuleSet([group_normalization_merge_silu_submodule_rule])
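Judging from the tests that follow, torch_module_op.submodule("torch_nn_modules_activation_SiLU") appears to match a call to a torch-exported submodule function whose name starts with that module class (for example torch_nn_modules_activation_SiLU_time_embedding_act_19), which is the submodule-function support this PR adds. A short sketch of applying just this rule set and serializing the result; the wrapper function is hypothetical, the calls it uses all appear elsewhere in this PR:

import onnx

from onnxscript._legacy_ir import irbuilder, protobuilder
from onnxscript.rewriter.onnxruntime import group_normalization_merge_silu

def _merge_group_norm_silu(model_proto: onnx.ModelProto) -> onnx.ModelProto:
    model_ir = irbuilder.build_ir(model_proto)
    # Rewrites GroupNorm(activation=0) -> Transpose -> SiLU submodule call
    # into GroupNorm(activation=1) -> Transpose.
    applied = group_normalization_merge_silu.rules.apply_to_model(model_ir)
    print(f"Applied {applied} GroupNorm+SiLU merge(s).")
    return protobuilder.build_model_proto(model_ir)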
125 changes: 125 additions & 0 deletions onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py
@@ -0,0 +1,125 @@
import unittest

import numpy as np
import onnx.parser

from onnxscript._legacy_ir import irbuilder
from onnxscript.rewriter.onnxruntime import (
    group_normalization_merge_silu,
    instance_to_group_normalization,
)


class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase):
    def test_group_norm_with_silu_submodule_is_replaced_by_group_norm(self):
        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: ["" : 17, "pkg.torch230a0git77ef9d4" : 1, "com.microsoft" : 1]>
            agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output)
            {
                group_norm = com.microsoft.GroupNorm <activation=0, channels_last=1, epsilon=0.000001, groups=32>(image, weight, bias)
                transposed = Transpose <perm=[0, 3, 1, 2]>(group_norm)
                output = pkg.torch230a0git77ef9d4.torch_nn_modules_activation_SiLU_time_embedding_act_19 (transposed)
            }
            <domain: "pkg.torch230a0git77ef9d4", opset_import: ["" : 17]>
            torch_nn_modules_activation_SiLU_time_embedding_act_19 (transposed) => (output)
            {
                _to_copy_38 = Cast <to: int = 1> (transposed)
                sigmoid_18 = Sigmoid (_to_copy_38)
                mul_26 = Mul (_to_copy_38, sigmoid_18)
                output = Cast <to: int = 10> (mul_26)
            }
            """
        )
        # Use inserted initializers to avoid manually coding the large constants
        weight_value = np.random.rand(320, 1, 1).astype(np.float16)
        bias_value = np.random.rand(320, 1, 1).astype(np.float16)
        model.graph.initializer.extend(
            [
                onnx.helper.make_tensor(
                    "weight",
                    onnx.TensorProto.FLOAT16,
                    weight_value.shape,
                    weight_value,
                ),
                onnx.helper.make_tensor(
                    "bias",
                    onnx.TensorProto.FLOAT16,
                    bias_value.shape,
                    bias_value,
                ),
            ]
        )

        ir = irbuilder.build_ir(model)
        count = group_normalization_merge_silu.rules.apply_to_model(ir)
        self.assertEqual(count, 1)
        # plus 2 in model constants
        self.assertEqual(len(ir.graph.nodes), 2)

    def test_simulated_instance_norm_is_replaced_by_group_norm_silu(self):
        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17, "pkg.torch230a0git77ef9d4" : 1]>
            agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output)
            {
                adjusted_input_shape = Constant<value: tensor = int64[3] {0, 32, -1}>()
                image_reshape = Reshape (image, adjusted_input_shape)
                instance_norm = InstanceNormalization <epsilon=0.000001>(image_reshape, weight_for_norm, bias_for_norm)
                original_input_shape = Constant<value: tensor = int64[4] {1, 320, 128, 128}>()
                instance_norm_reshape = Reshape (instance_norm, original_input_shape)
                mul_output = Mul (instance_norm_reshape, weight_full)
                add_output = Add (mul_output, bias_full)
                output = pkg.torch230a0git77ef9d4.torch_nn_modules_activation_SiLU_time_embedding_act_19 (add_output)
            }
            <domain: "pkg.torch230a0git77ef9d4", opset_import: ["" : 17]>
            torch_nn_modules_activation_SiLU_time_embedding_act_19 (add_output) => (output)
            {
                _to_copy_38 = Cast <to: int = 1> (add_output)
                sigmoid_18 = Sigmoid (_to_copy_38)
                mul_26 = Mul (_to_copy_38, sigmoid_18)
                output = Cast <to: int = 10> (mul_26)
            }
            """
        )
        # Use inserted initializers to avoid manually coding the large constants
        weight_full_value = np.random.rand(320, 1, 1).astype(np.float16)
        bias_full_value = np.random.rand(320, 1, 1).astype(np.float16)
        weight_for_norm_value = np.ones(32, dtype=np.float16)
        bias_for_norm_value = np.zeros(32, dtype=np.float16)

        model.graph.initializer.extend(
            [
                onnx.helper.make_tensor(
                    "weight_for_norm",
                    onnx.TensorProto.FLOAT16,
                    weight_for_norm_value.shape,
                    weight_for_norm_value,
                ),
                onnx.helper.make_tensor(
                    "bias_for_norm",
                    onnx.TensorProto.FLOAT16,
                    bias_for_norm_value.shape,
                    bias_for_norm_value,
                ),
                onnx.helper.make_tensor(
                    "weight_full",
                    onnx.TensorProto.FLOAT16,
                    weight_full_value.shape,
                    weight_full_value,
                ),
                onnx.helper.make_tensor(
                    "bias_full",
                    onnx.TensorProto.FLOAT16,
                    bias_full_value.shape,
                    bias_full_value,
                ),
            ]
        )

        ir = irbuilder.build_ir(model)
        count = instance_to_group_normalization.rules.apply_to_model(ir)
        count += group_normalization_merge_silu.rules.apply_to_model(ir)
        self.assertEqual(count, 2)
        # plus 2 in model constants
        self.assertEqual(len(ir.graph.nodes), 10)
onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
@@ -11,6 +11,7 @@

op = pattern.onnxop
msft_op = pattern.msft_op
torch_module_op = pattern.torch_module_op

logger = logging.getLogger(__name__)

@@ -146,4 +147,6 @@ def group_normalization(
    check_if_simulated_instance_norm_is_used,
)

# NOTE: instance_norm_to_group_norm_rule is a subset of instance_norm_to_group_norm_with_silu_rule,
# so we need to run instance_norm_to_group_norm_with_silu_rule first.
rules = pattern.RewriteRuleSet([instance_norm_to_group_norm_rule])
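A hypothetical sketch of the ordering the NOTE above implies, assuming instance_norm_to_group_norm_with_silu_rule is defined elsewhere in this module (outside the visible hunk) and that RewriteRuleSet tries rules in the listed order; the diff as written registers only instance_norm_to_group_norm_rule:

# Hypothetical alternative: list the more specific pattern (instance norm + SiLU)
# first, so the general instance-norm-only rule cannot claim the nodes before
# the specific rule has a chance to match.
rules_with_silu_first = pattern.RewriteRuleSet(
    [
        instance_norm_to_group_norm_with_silu_rule,
        instance_norm_to_group_norm_rule,
    ]
)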