
Commit 044dd85

[Migration][DO NOT MERGE] Support submodule functions in pattern-rewrite
From https://github.com/microsoft/onnx-rewriter/commit/d0e2876f2e6765738a1d9ad60b1d55000ffb50f7
Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>
ghstack-source-id: fe6aba8
Pull Request resolved: #1345
1 parent 125ce9f commit 044dd85

File tree

9 files changed, +254 -19 lines changed


onnxscript/_legacy_ir/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -319,6 +319,10 @@ def output_names(self):
     def attribute(self):
         return self.original_node_proto.attribute

+    def set_version_if_custom_op(self, version_map: dict[str, int]) -> None:
+        if self.domain != "" and self.domain in version_map:
+            self.version = version_map[self.domain]
+
     def get_attribute(self, name: str) -> int | float | None:
         return self.attributes.get(name, None)
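
For context, here is a minimal sketch of how the new helper behaves, using a hypothetical stripped-down stand-in for onnxscript._legacy_ir.Node (the real class carries much more state): only nodes from a non-default, custom domain pick up a version from the model's opset imports.

    from __future__ import annotations

    # Hypothetical stand-in for onnxscript._legacy_ir.Node, only to illustrate
    # the version lookup added above.
    class _Node:
        def __init__(self, domain: str) -> None:
            self.domain = domain
            self.version: int | None = None

        def set_version_if_custom_op(self, version_map: dict[str, int]) -> None:
            # Custom (non-empty) domains are stamped with the version declared
            # in the model's opset_import; the default ONNX domain is skipped.
            if self.domain != "" and self.domain in version_map:
                self.version = version_map[self.domain]

    version_map = {"": 17, "com.microsoft": 1}
    custom = _Node("com.microsoft")
    custom.set_version_if_custom_op(version_map)
    assert custom.version == 1        # custom-domain node gets its opset version
    default = _Node("")
    default.set_version_if_custom_op(version_map)
    assert default.version is None    # default-domain nodes are left untouched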

onnxscript/_legacy_ir/irbuilder.py

Lines changed: 3 additions & 2 deletions
@@ -35,12 +35,12 @@ def visit_model(self, model_proto: onnx.ModelProto) -> ir.Model:
         self._function_shape_env = visitor.FunctionShapeEnv()
         self._function_shape_env.load_from_model_proto(model_proto)
         self._ir_version = model_proto.ir_version
-        version_map = {x.domain: x.version for x in model_proto.opset_import}
+        self.version_map = {x.domain: x.version for x in model_proto.opset_import}
         functions = [self.visit_function(function) for function in model_proto.functions]
         self.functions = {function.id: function for function in functions}
         graph = self.visit_graph(model_proto.graph)
         model = ir.Model()
-        model.set(model_proto, graph, functions, version_map)
+        model.set(model_proto, graph, functions, self.version_map)
         return model

     def visit_graph(self, graph: onnx.GraphProto) -> ir.Graph:
@@ -122,6 +122,7 @@ def process_initializer(self, init: onnx.TensorProto):

     def process_node(self, node):
         node_ir = ir.Node(node)
+        node_ir.set_version_if_custom_op(self.version_map)
         self.current_graph_or_function.nodes.append(node_ir)
         for name in node.input:
             value = self.lookup(name)

onnxscript/_legacy_ir/protobuilder.py

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ def visit_ir_function(
         # function_proto.metadata_props = ir_function.original_function_proto.metadata_props)

         for node in ir_function.nodes:
+            # TODO: deduplicate the opset import of function?
             operator_setid_proto = function_proto.opset_import.add()
             if node.domain in self.opset_imports:
                 operator_setid_proto.domain = self.opset_imports[node.domain].domain

onnxscript/rewriter/broadcast_to_matmul_test.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nest
             <ir_version: 7, opset_import: [ "" : 17, "pkg.custom": 1]>
             agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
             {
-                output = afunction (input_x, input_y)
+                output = pkg.custom.afunction (input_x, input_y)
             }
             <domain: "pkg.custom", opset_import: [ "" : 17]>
             afunction (input_x, input_y) => (output)

onnxscript/rewriter/onnxruntime/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -3,9 +3,10 @@
 import onnx

 from onnxscript._legacy_ir import irbuilder, protobuilder
-from onnxscript.optimizer import remove_unused
+from onnxscript.optimizer import remove_unused, remove_unused_function
 from onnxscript.rewriter import function_rule, pattern
 from onnxscript.rewriter.onnxruntime import (
+    group_normalization_merge_silu,
     instance_to_group_normalization,
     softmax,
     transformers,
@@ -16,6 +17,8 @@
 ORT_PATTERN_REWRITE_RULES = [
     *softmax.rules.rules,
     *instance_to_group_normalization.rules.rules,
+    # NOTE: group normalization merge silu should be applied after instance to group normalization
+    *group_normalization_merge_silu.rules.rules,
 ]


@@ -49,5 +52,8 @@ def rewrite(
     count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model_ir)
     print(f"Applied {count} pattern rewrite rules.")
     model = protobuilder.build_model_proto(model_ir)
+    # TODO: Does it make more sense we run DCE after each rewrite rule applied?
+    # If so, we need IR to support DCE.
     remove_unused.remove_unused_nodes(model)
+    remove_unused_function.remove_unused_functions(model)
     return model
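
For orientation, a minimal usage sketch of the updated entry point; it assumes the module-level rewrite helper keeps its current defaults, and the file names are illustrative only.

    import onnx

    from onnxscript.rewriter import onnxruntime as ort_rewriter

    model = onnx.load("model.onnx")          # illustrative input path
    rewritten = ort_rewriter.rewrite(model)  # function rules, then ORT_PATTERN_REWRITE_RULES
    # With this change, functions left unreferenced after rewriting (for example a
    # matched SiLU submodule) are pruned by remove_unused_function.remove_unused_functions.
    onnx.save(rewritten, "model_rewritten.onnx")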

onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import logging
+
+from onnxscript.rewriter import pattern
+
+op = pattern.onnxop
+msft_op = pattern.msft_op
+torch_module_op = pattern.torch_module_op
+
+logger = logging.getLogger(__name__)
+
+
+def group_normalization_and_silu_submodule(
+    input,
+    weight,
+    bias,
+    epsilon,
+    groups,
+):
+    group_norm = msft_op.GroupNorm(
+        input,
+        weight,
+        bias,
+        activation=0,
+        channels_last=1,
+        epsilon=epsilon,
+        groups=groups,
+    )
+    transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2])
+    return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(transposed)
+
+
+def group_normalization_with_silu(
+    input,
+    weight,
+    bias,
+    epsilon,
+    groups,
+):
+    group_norm = msft_op.GroupNorm(
+        input,
+        weight,
+        bias,
+        activation=1,
+        channels_last=1,
+        epsilon=epsilon,
+        groups=groups,
+    )
+    return op.Transpose(group_norm, perm=[0, 3, 1, 2])
+
+
+group_normalization_merge_silu_submodule_rule = pattern.RewriteRule(
+    group_normalization_and_silu_submodule,
+    group_normalization_with_silu,
+)
+
+rules = pattern.RewriteRuleSet([group_normalization_merge_silu_submodule_rule])
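
A minimal sketch of applying just this rule set and rebuilding the proto, mirroring the test file that follows; `model` is assumed to be an onnx.ModelProto containing the GroupNorm -> Transpose -> SiLU-submodule pattern.

    from onnxscript._legacy_ir import irbuilder, protobuilder
    from onnxscript.rewriter.onnxruntime import group_normalization_merge_silu

    # `model` is assumed to be an onnx.ModelProto with the matched pattern.
    model_ir = irbuilder.build_ir(model)
    count = group_normalization_merge_silu.rules.apply_to_model(model_ir)
    print(f"Merged {count} GroupNorm + SiLU submodule pairs.")
    rewritten = protobuilder.build_model_proto(model_ir)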

onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+import unittest
+
+import numpy as np
+import onnx.parser
+
+from onnxscript._legacy_ir import irbuilder
+from onnxscript.rewriter.onnxruntime import (
+    group_normalization_merge_silu,
+    instance_to_group_normalization,
+)
+
+
+class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase):
+    def test_group_norm_with_silu_submodule_is_replaced_by_group_norm(self):
+        model = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: ["" : 17, "pkg.torch230a0git77ef9d4" : 1, "com.microsoft" : 1]>
+            agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output)
+            {
+                group_norm = com.microsoft.GroupNorm <activation=0, channels_last=1, epsilon=0.000001, groups=32>(image, weight, bias)
+                transposed = Transpose <perm=[0, 3, 1, 2]>(group_norm)
+                output = pkg.torch230a0git77ef9d4.torch_nn_modules_activation_SiLU_time_embedding_act_19 (transposed)
+            }
+            <domain: "pkg.torch230a0git77ef9d4", opset_import: ["" : 17]>
+            torch_nn_modules_activation_SiLU_time_embedding_act_19 (transposed) => (output)
+            {
+                _to_copy_38 = Cast <to: int = 1> (transposed)
+                sigmoid_18 = Sigmoid (_to_copy_38)
+                mul_26 = Mul (_to_copy_38, sigmoid_18)
+                output = Cast <to: int = 10> (mul_26)
+            }
+            """
+        )
+        # Use inserted initializers to avoid manually coding the large constants
+        weight_value = np.random.rand(320, 1, 1).astype(np.float16)
+        bias_value = np.random.rand(320, 1, 1).astype(np.float16)
+        model.graph.initializer.extend(
+            [
+                onnx.helper.make_tensor(
+                    "weight",
+                    onnx.TensorProto.FLOAT16,
+                    weight_value.shape,
+                    weight_value,
+                ),
+                onnx.helper.make_tensor(
+                    "bias",
+                    onnx.TensorProto.FLOAT16,
+                    bias_value.shape,
+                    bias_value,
+                ),
+            ]
+        )
+
+        ir = irbuilder.build_ir(model)
+        count = group_normalization_merge_silu.rules.apply_to_model(ir)
+        self.assertEqual(count, 1)
+        # plus 2 in model constants
+        self.assertEqual(len(ir.graph.nodes), 2)
+
+    def test_simulated_instance_norm_is_replaced_by_group_norm_silu(self):
+        model = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17, "pkg.torch230a0git77ef9d4" : 1]>
+            agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output)
+            {
+                adjusted_input_shape = Constant<value: tensor = int64[3] {0, 32, -1}>()
+                image_reshape = Reshape (image, adjusted_input_shape)
+                instance_norm = InstanceNormalization <epsilon=0.000001>(image_reshape, weight_for_norm, bias_for_norm)
+                original_input_shape = Constant<value: tensor = int64[4] {1, 320, 128, 128}>()
+                instance_norm_reshape = Reshape (instance_norm, original_input_shape)
+                mul_output = Mul (instance_norm_reshape, weight_full)
+                add_output = Add (mul_output, bias_full)
+                output = pkg.torch230a0git77ef9d4.torch_nn_modules_activation_SiLU_time_embedding_act_19 (add_output)
+            }
+            <domain: "pkg.torch230a0git77ef9d4", opset_import: ["" : 17]>
+            torch_nn_modules_activation_SiLU_time_embedding_act_19 (add_output) => (output)
+            {
+                _to_copy_38 = Cast <to: int = 1> (add_output)
+                sigmoid_18 = Sigmoid (_to_copy_38)
+                mul_26 = Mul (_to_copy_38, sigmoid_18)
+                output = Cast <to: int = 10> (mul_26)
+            }
+            """
+        )
+        # Use inserted initializers to avoid manually coding the large constants
+        weight_full_value = np.random.rand(320, 1, 1).astype(np.float16)
+        bias_full_value = np.random.rand(320, 1, 1).astype(np.float16)
+        weight_for_norm_value = np.ones(32, dtype=np.float16)
+        bias_for_norm_value = np.zeros(32, dtype=np.float16)
+
+        model.graph.initializer.extend(
+            [
+                onnx.helper.make_tensor(
+                    "weight_for_norm",
+                    onnx.TensorProto.FLOAT16,
+                    weight_for_norm_value.shape,
+                    weight_for_norm_value,
+                ),
+                onnx.helper.make_tensor(
+                    "bias_for_norm",
+                    onnx.TensorProto.FLOAT16,
+                    bias_for_norm_value.shape,
+                    bias_for_norm_value,
+                ),
+                onnx.helper.make_tensor(
+                    "weight_full",
+                    onnx.TensorProto.FLOAT16,
+                    weight_full_value.shape,
+                    weight_full_value,
+                ),
+                onnx.helper.make_tensor(
+                    "bias_full",
+                    onnx.TensorProto.FLOAT16,
+                    bias_full_value.shape,
+                    bias_full_value,
+                ),
+            ]
+        )
+
+        ir = irbuilder.build_ir(model)
+        count = instance_to_group_normalization.rules.apply_to_model(ir)
+        count += group_normalization_merge_silu.rules.apply_to_model(ir)
+        self.assertEqual(count, 2)
+        # plus 2 in model constants
+        self.assertEqual(len(ir.graph.nodes), 10)

onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,7 @@

 op = pattern.onnxop
 msft_op = pattern.msft_op
+torch_module_op = pattern.torch_module_op

 logger = logging.getLogger(__name__)

@@ -146,4 +147,6 @@ def group_normalization(
     check_if_simulated_instance_norm_is_used,
 )

+# NOTE: instance_norm_to_group_norm_rule is subset of instance_norm_to_group_norm_with_silu_rule,
+# so we need to run instance_norm_to_group_norm_with_silu_rule first.
 rules = pattern.RewriteRuleSet([instance_norm_to_group_norm_rule])
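
As the NOTE in onnxruntime/__init__.py above states, the SiLU merge must run after the instance-to-group-normalization rewrite, since it matches the com.microsoft.GroupNorm node that the first pass produces. A minimal ordering sketch, mirroring the second test above:

    from onnxscript._legacy_ir import irbuilder
    from onnxscript.rewriter.onnxruntime import (
        group_normalization_merge_silu,
        instance_to_group_normalization,
    )

    # `model` is assumed to be an onnx.ModelProto with the simulated InstanceNormalization pattern.
    model_ir = irbuilder.build_ir(model)
    count = instance_to_group_normalization.rules.apply_to_model(model_ir)   # produce GroupNorm first
    count += group_normalization_merge_silu.rules.apply_to_model(model_ir)   # then fold the SiLU submodule into it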
