Commit 5156668

snap
1 parent feea637 commit 5156668

File tree

1 file changed: +18 −17 lines changed

onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py

Lines changed: 18 additions & 17 deletions
@@ -5,22 +5,23 @@
 import numpy as np
 import onnx
 
+from onnxscript import ir
 from onnxscript.rewriter import _ir_utils, pattern
 
 torch_module_op = pattern.torch_module_op
 
 logger = logging.getLogger(__name__)
 
 
-def check_if_simulated_instance_norm_is_used(
+def _simulated_instance_norm(
     context,
-    input_x,
-    adjusted_input_shape,
-    original_input_shape,
-    weight_for_norm,
-    bias_for_norm,
-    weight_full,
-    bias_full,
+    input_x: ir.Value,
+    adjusted_input_shape: ir.Value,
+    original_input_shape: ir.Value,
+    weight_for_norm: ir.Value,
+    bias_for_norm: ir.Value,
+    weight_full: ir.Value,
+    bias_full: ir.Value,
     **_,
 ) -> bool:
     """Check if the simulated instance normalization is used.
@@ -38,16 +39,16 @@ def check_if_simulated_instance_norm_is_used(
     6. original_input_shape is the same as input_x shape.
 
     Returns:
-        bool: True if the simulated instance normalization is used, False otherwise.
+        True if the simulated instance normalization is used, False otherwise.
     """
-    weight_for_norm_prop = _ir_utils.propagate_const_value(weight_for_norm)
-    weight_for_norm_const_value = weight_for_norm_prop.const_value
+    _ir_utils.propagate_const_value(weight_for_norm)
+    weight_for_norm_const_value = weight_for_norm.const_value
     if weight_for_norm_const_value is None:
         return False
     weight_for_norm = weight_for_norm_const_value.numpy()
 
-    bias_for_norm_prop = _ir_utils.propagate_const_value(bias_for_norm)
-    bias_for_norm_const_value = bias_for_norm_prop.const_value
+    _ir_utils.propagate_const_value(bias_for_norm)
+    bias_for_norm_const_value = bias_for_norm.const_value
     if bias_for_norm_const_value is None:
         return False
     bias_for_norm = bias_for_norm_const_value.numpy()
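The change above drops the `_prop` intermediates: `propagate_const_value` is now called for its side effect of annotating the value, and the constant is read back from `const_value` on the original `ir.Value`. As a minimal sketch of the new pattern (a hypothetical helper, not part of this commit):

def _read_const_or_none(value: ir.Value):
    # Propagate constants in place; value.const_value is populated
    # (or left as None) as a side effect.
    _ir_utils.propagate_const_value(value)
    const_value = value.const_value
    if const_value is None:
        # Not statically known; the check above returns False in this case.
        return None
    return const_value.numpy()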
@@ -57,7 +58,7 @@ def check_if_simulated_instance_norm_is_used(
     if not np.all(bias_for_norm == 0):
         return False
 
-    input_rank_minus_one = len(input_x.shape) - 1
+    input_rank_minus_one = input_x.shape.rank() - 1
     weight_full_rank = len(weight_full.shape)
     bias_full_rank = len(bias_full.shape)
     if weight_full_rank != input_rank_minus_one or bias_full_rank != input_rank_minus_one:
@@ -74,7 +75,7 @@ def check_if_simulated_instance_norm_is_used(
     if not all(dim == 1 for dim in bias_full_shape[1:]):
         return False
 
-    adjusted_input_shape = _ir_utils.propagate_const_value(adjusted_input_shape)
+    _ir_utils.propagate_const_value(adjusted_input_shape)
     adjusted_input_shape_const_value = adjusted_input_shape.const_value
 
     g = weight_for_norm.shape[0]
@@ -85,7 +86,7 @@ def check_if_simulated_instance_norm_is_used(
         return False
 
     # NOTE: Restrict the rule to only support constant shape
-    original_input_shape = _ir_utils.propagate_const_value(original_input_shape)
+    _ir_utils.propagate_const_value(original_input_shape)
     original_input_shape_const_value = original_input_shape.const_value
     if (
         original_input_shape_const_value is None
@@ -149,7 +150,7 @@ def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, ep
 instance_norm_to_group_norm_rule = pattern.RewriteRule(
     instance_simulates_group_normalization_pattern,
     group_normalization,
-    check_if_simulated_instance_norm_is_used,
+    _simulated_instance_norm,
 )
 
 # NOTE: instance_norm_to_group_norm_rule is subset of instance_norm_to_group_norm_with_silu_rule,
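For context, a rule registered this way is typically applied through the rewriter entry point. The sketch below assumes a `rewriter.rewrite(model, pattern_rewrite_rules=...)` API and a placeholder model path, neither of which is shown in this commit:

import onnx
from onnxscript import rewriter
from onnxscript.rewriter.onnxruntime import instance_to_group_normalization

# Placeholder path; any ONNX model containing the simulated
# instance-norm subgraph would do.
model = onnx.load("model.onnx")
rewritten = rewriter.rewrite(
    model,
    pattern_rewrite_rules=[instance_to_group_normalization.instance_norm_to_group_norm_rule],
)
onnx.save(rewritten, "model_rewritten.onnx")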
