import numpy as np
import onnx

+from onnxscript import ir
 from onnxscript.rewriter import _ir_utils, pattern

torch_module_op = pattern.torch_module_op

logger = logging.getLogger(__name__)


-def check_if_simulated_instance_norm_is_used(
+def _simulated_instance_norm(
     context,
-    input_x,
-    adjusted_input_shape,
-    original_input_shape,
-    weight_for_norm,
-    bias_for_norm,
-    weight_full,
-    bias_full,
+    input_x: ir.Value,
+    adjusted_input_shape: ir.Value,
+    original_input_shape: ir.Value,
+    weight_for_norm: ir.Value,
+    bias_for_norm: ir.Value,
+    weight_full: ir.Value,
+    bias_full: ir.Value,
     **_,
 ) -> bool:
     """Check if the simulated instance normalization is used.
@@ -38,16 +39,16 @@ def check_if_simulated_instance_norm_is_used(
    6. original_input_shape is the same as input_x shape.

    Returns:
-        bool: True if the simulated instance normalization is used, False otherwise.
+        True if the simulated instance normalization is used, False otherwise.
     """
-    weight_for_norm_prop = _ir_utils.propagate_const_value(weight_for_norm)
-    weight_for_norm_const_value = weight_for_norm_prop.const_value
+    _ir_utils.propagate_const_value(weight_for_norm)
+    weight_for_norm_const_value = weight_for_norm.const_value
     if weight_for_norm_const_value is None:
         return False
     weight_for_norm = weight_for_norm_const_value.numpy()

-    bias_for_norm_prop = _ir_utils.propagate_const_value(bias_for_norm)
-    bias_for_norm_const_value = bias_for_norm_prop.const_value
+    _ir_utils.propagate_const_value(bias_for_norm)
+    bias_for_norm_const_value = bias_for_norm.const_value
     if bias_for_norm_const_value is None:
         return False
     bias_for_norm = bias_for_norm_const_value.numpy()
@@ -57,7 +58,7 @@ def check_if_simulated_instance_norm_is_used(
    if not np.all(bias_for_norm == 0):
        return False

-    input_rank_minus_one = len(input_x.shape) - 1
+    input_rank_minus_one = input_x.shape.rank() - 1
     weight_full_rank = len(weight_full.shape)
     bias_full_rank = len(bias_full.shape)
     if weight_full_rank != input_rank_minus_one or bias_full_rank != input_rank_minus_one:
@@ -74,7 +75,7 @@ def check_if_simulated_instance_norm_is_used(
    if not all(dim == 1 for dim in bias_full_shape[1:]):
        return False

-    adjusted_input_shape = _ir_utils.propagate_const_value(adjusted_input_shape)
+    _ir_utils.propagate_const_value(adjusted_input_shape)
     adjusted_input_shape_const_value = adjusted_input_shape.const_value

    g = weight_for_norm.shape[0]
@@ -85,7 +86,7 @@ def check_if_simulated_instance_norm_is_used(
        return False

    # NOTE: Restrict the rule to only support constant shape
-    original_input_shape = _ir_utils.propagate_const_value(original_input_shape)
+    _ir_utils.propagate_const_value(original_input_shape)
     original_input_shape_const_value = original_input_shape.const_value
     if (
         original_input_shape_const_value is None
@@ -149,7 +150,7 @@ def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, ep
instance_norm_to_group_norm_rule = pattern.RewriteRule(
    instance_simulates_group_normalization_pattern,
    group_normalization,
-    check_if_simulated_instance_norm_is_used,
+    _simulated_instance_norm,
 )

# NOTE: instance_norm_to_group_norm_rule is subset of instance_norm_to_group_norm_with_silu_rule,
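
Usage sketch (not part of the diff): assuming this file is the onnxscript.rewriter.onnxruntime.instance_to_group_normalization module (the file path is not shown in this excerpt) and that the generic rewriter.rewrite(..., pattern_rewrite_rules=...) entry point accepts a list of RewriteRule objects, the renamed rule could be applied to a model roughly as follows; the model paths are hypothetical.

import onnx
from onnxscript import rewriter
from onnxscript.rewriter.onnxruntime import instance_to_group_normalization

# Load an ONNX model that contains the InstanceNormalization-based GroupNorm pattern.
model = onnx.load("model.onnx")  # hypothetical input path

# Apply only the instance-norm-to-group-norm rule defined in this module.
rewritten = rewriter.rewrite(
    model,
    pattern_rewrite_rules=[
        instance_to_group_normalization.instance_norm_to_group_norm_rule,
    ],
)

onnx.save(rewritten, "model_rewritten.onnx")  # hypothetical output path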