18
18
19
19
20
20
def fill_pads_with_axes (pads : Sequence [int ], axes : Sequence [int ], rank : int ) -> List [int ]:
21
+ """Converts the parameters of the ONNX Pad operator into an explicit list of values.
22
+
23
+ A filled list of pads will be returned following the format:
24
+ [x1_begin, x2_begin, ..., x{rank}_begin, x1_end, x2_end, ..., x{rank}_end]
25
+
26
+ Args:
27
+ pads: list of integers indicating the number of padding elements to add at
28
+ the beginning and end of each axis.
29
+ axes: list of axes that pads apply to.
30
+ rank: value to compute the size of the filled list (2 * rank).
31
+
32
+ Returns:
33
+ The filled list of pads.
34
+ """
21
35
new_pads = [0 ] * 2 * rank
22
36
N = len (axes )
23
37
for start_idx , axis in enumerate (axes ):
@@ -42,11 +56,13 @@ def read_conv_attributes(ir_conv: ir.Node) -> dict[str, Sequence[int] | str]:
42
56
return attributes
43
57
44
58
45
- class _FusePadConvBase (orp .RewriteRuleClassBase ):
59
+ class _FuseConvPadBase (orp .RewriteRuleClassBase ):
46
60
"""Interface for PadConv nodes fusion."""
47
61
48
62
def __init__ (self , as_function : bool = False ):
49
- # Remove nodes is set to False to remove unused nodes after the rewrite.
63
+ # Remove nodes is set to False to remove unused nodes after the rewrite, since
64
+ # Pad or Conv inputs can come from constant nodes.
65
+ # With remove_nodes=False these nodes are removed if these nodes are no longer needed.
50
66
super ().__init__ (remove_nodes = False , as_function = as_function )
51
67
52
68
def rewrite (
@@ -84,14 +100,32 @@ def rewrite(
84
100
)
85
101
86
102
def check (self , context , x : ir .Value , pad : ir .Value , conv : ir .Value ) -> orp .MatchResult :
103
+ """Condition to check if we need to replace the pattern.
104
+
105
+ If Pad inputs can be added in 'pads' attribute of the Conv operator.
106
+
107
+ To validate this, we need to check the following:
108
+ 1. `Pad<mode>` attribute has 'constant' as value
109
+ 2. `Pad` operator inputs are constants ('pads', 'constant_value', 'axes')
110
+ 3. 'constant_value' is equal to 0.0.
111
+ 4. `Pad` operator is only used for the spatial dimensions (batch dimension and channels
112
+ remain unchanged).
113
+
114
+ If the above are true, then we don't need the reshapes.
115
+
116
+ Returns:
117
+ True if we need to replace the pattern, False otherwise.
118
+ """
87
119
del context # Unused
88
120
check_result = orp .MatchResult ()
89
121
pad_node = pad .producer ()
90
122
x_rank = len (x .shape )
91
123
92
124
# Pad constraints: attributes
93
125
if (mode := pad_node .attributes .get ("mode" , None )) and mode .as_string () != "constant" :
94
- return check_result .fail (f"{ pad_node .name } mode must be 'constant'." )
126
+ return check_result .fail (
127
+ f"{ pad_node .name } ({ pad_node .op_type } ) mode must be 'constant'."
128
+ )
95
129
96
130
# Pad constraints: inputs
97
131
if (pads := pad_node .inputs [1 ]).const_value is None :
@@ -118,8 +152,8 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
118
152
return check_result
119
153
120
154
121
- class FusePadConv ( _FusePadConvBase ):
122
- """Replaces ``Pad( Conv(x))`` with ``Conv(x)``."""
155
+ class FuseConvPad ( _FuseConvPadBase ):
156
+ """Replaces ``Conv(Pad (x))`` with ``Conv(x)``."""
123
157
124
158
def pattern (self , op : ir .tape .Tape , x : ir .Value ) -> ir .Value :
125
159
return op .Conv (
@@ -138,12 +172,14 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
138
172
if (
139
173
apad := conv_node .attributes .get ("auto_pad" , None )
140
174
) and apad .as_string () != "NOTSET" :
141
- return check_result .fail (f"{ conv_node .name } auto_pad must be 'NOTSET'." )
175
+ return check_result .fail (
176
+ f"{ conv_node .name } ({ conv_node .op_type } ) auto_pad must be 'NOTSET'."
177
+ )
142
178
return check_result
143
179
144
180
145
- class FusePadConvInteger ( FusePadConv ):
146
- """Replaces ``Pad( ConvInteger(x))`` with ``ConvInteger(x)``."""
181
+ class FuseConvIntegerPad ( FuseConvPad ):
182
+ """Replaces ``ConvInteger(Pad (x))`` with ``ConvInteger(x)``."""
147
183
148
184
def pattern (self , op : ir .tape .Tape , x : ir .Value ) -> ir .Value :
149
185
return op .ConvInteger (
@@ -190,36 +226,63 @@ def rewrite(self, op: ir.tape.Tape, conv: ir.Value, **__) -> ir.Value:
190
226
)
191
227
192
228
def check (self , context , conv : ir .Value , ** __ ) -> orp .MatchResult :
229
+ """Condition to check if we need to replace the pattern.
230
+
231
+ If it is possible to deduce 'pads'.
232
+
233
+ To validate this, we need to check the following:
234
+ 1. `Conv<auto_pad != "NOTSET">` (nothing to do in this case, since 'pads' are
235
+ already explicit)
236
+ 2. it is possible to deduce the input rank when `Conv<auto_pad == "VALID">`
237
+ 3. When `Conv<auto_pad != "VALID">`:
238
+ * spatial input/output shapes are static
239
+ * it is possible to infer `kernel_shape` either from the `Conv` operator attribute
240
+ or from the kernel input
241
+
242
+ If the above are true, then we don't need the reshapes.
243
+
244
+ Returns:
245
+ True if we need to replace the pattern, False otherwise.
246
+ """
193
247
del context
194
248
check_result = orp .MatchResult ()
195
249
196
250
# Conv constraints: attributes
197
251
conv_node = conv .producer ()
198
252
auto_pad = conv_node .attributes .get_string ("auto_pad" , None )
199
- if auto_pad in [ None , "NOTSET" ] :
253
+ if auto_pad in { None , "NOTSET" } :
200
254
return check_result .fail (
201
- f"{ conv_node .name } auto_pad must be different to 'NOTSET'."
255
+ f"{ conv_node .name } ( { conv_node . op_type } ) auto_pad must be different to 'NOTSET'."
202
256
)
203
257
204
258
# Conv constraints: inputs/outputs
205
259
input_shape = conv_node .inputs [0 ].shape
206
260
output_shape = conv_node .outputs [0 ].shape
207
261
if len (input_shape ) <= 2 :
208
- return check_result .fail (f"Input shapes are not defined on { conv_node .name } ." )
262
+ return check_result .fail (
263
+ f"Input shapes are not defined on { conv_node .name } ({ conv_node .op_type } )."
264
+ )
209
265
if len (output_shape ) <= 2 :
210
- return check_result .fail (f"Output shapes are not defined on { conv_node .name } ." )
266
+ return check_result .fail (
267
+ f"Output shapes are not defined on { conv_node .name } ({ conv_node .op_type } )."
268
+ )
211
269
212
270
# Conv constraints: values
213
271
if auto_pad != "VALID" :
214
- error_msg = "Expected static spatial {} shapes on " + conv_node .name + "."
272
+ error_msg = (
273
+ "Expected static spatial {} shapes on "
274
+ + conv_node .name
275
+ + f" ({ conv_node .op_type } )."
276
+ )
215
277
if not all (isinstance (x , int ) for x in input_shape [2 :]):
216
278
return check_result .fail (error_msg .format ("input" ))
217
279
if not all (isinstance (x , int ) for x in output_shape [2 :]):
218
280
return check_result .fail (error_msg .format ("output" ))
219
281
attributes = read_conv_attributes (conv_node )
220
282
if len (attributes ["kernel_shape" ]) != len (attributes ["strides" ]):
221
283
return check_result .fail (
222
- f"strides must have the same length than kernel_shape on { conv_node .name } ."
284
+ "strides must have the same length than kernel_shape on "
285
+ f"{ conv_node .name } ({ conv_node .op_type } )."
223
286
)
224
287
return check_result
225
288
@@ -234,7 +297,7 @@ def compute_pads(
234
297
attributes : dict [str , Sequence [int ] | str ],
235
298
) -> Sequence [int ]:
236
299
# Compute pads, following auto_pad/pads attributes
237
- if attributes ["auto_pad" ] in [ "NOTSET" , "VALID" ] :
300
+ if attributes ["auto_pad" ] in { "NOTSET" , "VALID" } :
238
301
assert len (input_shape ) > 0
239
302
return attributes .get ("pads" , [0 ] * len (input_shape ) * 2 )
240
303
@@ -269,8 +332,8 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
269
332
270
333
normalize_pad_format_conv = NormalizePadFormatConv .rule ()
271
334
normalize_pad_format_conv_integer = NormalizePadFormatConvInteger .rule ()
272
- fuse_pad_into_conv = FusePadConv .rule ()
273
- fuse_pad_into_conv_integer = FusePadConvInteger .rule ()
335
+ fuse_pad_into_conv = FuseConvPad .rule ()
336
+ fuse_pad_into_conv_integer = FuseConvIntegerPad .rule ()
274
337
275
338
276
339
def fuse_pad_into_conv_rule_set () -> orp .RewriteRuleSet :
0 commit comments