18
18
19
19
20
20
def fill_pads_with_axes (pads : Sequence [int ], axes : Sequence [int ], rank : int ) -> List [int ]:
21
+ """Converts the parameters of the ONNX Pad operator into an explicit list of values.
22
+
23
+ A filled list of pads will be returned following the format:
24
+ [x1_begin, x2_begin, ..., x{rank}_begin, x1_end, x2_end, ..., x{rank}_end]
25
+
26
+ Args:
27
+ pads: list of integers indicating the number of padding elements to add at
28
+ the beginning and end of each axis.
29
+ axes: list of axes that pads apply to.
30
+ rank: value to compute the size of the filled list (2 * rank).
31
+
32
+ Returns:
33
+ The filled list of pads.
34
+ """
21
35
new_pads = []
22
36
for axis in range (rank ):
23
37
if axis not in axes :
@@ -47,11 +61,13 @@ def read_conv_attributes(ir_conv: ir.Node) -> dict[str, Sequence[int] | str]:
47
61
return attributes
48
62
49
63
50
- class _FusePadConvBase (orp .RewriteRuleClassBase ):
64
+ class _FuseConvPadBase (orp .RewriteRuleClassBase ):
51
65
"""Interface for PadConv nodes fusion."""
52
66
53
67
def __init__ (self , as_function : bool = False ):
54
- # Remove nodes is set to False to remove unused nodes after the rewrite.
68
+ # Remove nodes is set to False to remove unused nodes after the rewrite, since
69
+ # Pad or Conv inputs can come from constant nodes.
70
+ # With remove_nodes=False these nodes are removed if these nodes are no longer needed.
55
71
super ().__init__ (remove_nodes = False , as_function = as_function )
56
72
57
73
def rewrite (
@@ -89,6 +105,22 @@ def rewrite(
89
105
)
90
106
91
107
def check (self , context , x : ir .Value , pad : ir .Value , conv : ir .Value ) -> orp .MatchResult :
108
+ """Condition to check if we need to replace the pattern.
109
+
110
+ If Pad inputs can be added in 'pads' attribute of the Conv operator.
111
+
112
+ To validate this, we need to check the following:
113
+ 1. `Pad<mode>` attribute has 'constant' as value
114
+ 2. `Pad` operator inputs are constants ('pads', 'constant_value', 'axes')
115
+ 3. 'constant_value' is equal to 0.0.
116
+ 4. `Pad` operator is only used for the spatial dimensions (batch dimension and channels
117
+ remain unchanged).
118
+
119
+ If the above are true, then we don't need the reshapes.
120
+
121
+ Returns:
122
+ True if we need to replace the pattern, False otherwise.
123
+ """
92
124
del context # Unused
93
125
check_result = orp .MatchResult ()
94
126
pad_node = pad .producer ()
@@ -123,8 +155,8 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
123
155
return check_result
124
156
125
157
126
- class FusePadConv ( _FusePadConvBase ):
127
- """Replaces ``Pad( Conv(x))`` with ``Conv(x)``."""
158
+ class FuseConvPad ( _FuseConvPadBase ):
159
+ """Replaces ``Conv(Pad (x))`` with ``Conv(x)``."""
128
160
129
161
def pattern (self , op : ir .tape .Tape , x : ir .Value ) -> ir .Value :
130
162
return op .Conv (
@@ -147,8 +179,8 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
147
179
return check_result
148
180
149
181
150
- class FusePadConvInteger ( FusePadConv ):
151
- """Replaces ``Pad( ConvInteger(x))`` with ``ConvInteger(x)``."""
182
+ class FuseConvIntegerPad ( FuseConvPad ):
183
+ """Replaces ``ConvInteger(Pad (x))`` with ``ConvInteger(x)``."""
152
184
153
185
def pattern (self , op : ir .tape .Tape , x : ir .Value ) -> ir .Value :
154
186
return op .ConvInteger (
@@ -195,6 +227,24 @@ def rewrite(self, op: ir.tape.Tape, conv: ir.Value, **__) -> ir.Value:
195
227
)
196
228
197
229
def check (self , context , conv : ir .Value , ** __ ) -> orp .MatchResult :
230
+ """Condition to check if we need to replace the pattern.
231
+
232
+ If it is possible to deduce 'pads'.
233
+
234
+ To validate this, we need to check the following:
235
+ 1. `Conv<auto_pad != "NOTSET">` (nothing to do in this case, since 'pads' are
236
+ already explicit)
237
+ 2. it is possible to deduce the input rank when `Conv<auto_pad == "VALID">`
238
+ 3. When `Conv<auto_pad != "VALID">`:
239
+ * spatial input/output shapes are static
240
+ * it is possible to infer `kernel_shape` either from the `Conv` operator attribute
241
+ or from the kernel input
242
+
243
+ If the above are true, then we don't need the reshapes.
244
+
245
+ Returns:
246
+ True if we need to replace the pattern, False otherwise.
247
+ """
198
248
del context
199
249
check_result = orp .MatchResult ()
200
250
@@ -274,8 +324,8 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
274
324
275
325
normalize_pad_format_conv = NormalizePadFormatConv .rule ()
276
326
normalize_pad_format_conv_integer = NormalizePadFormatConvInteger .rule ()
277
- fuse_pad_into_conv = FusePadConv .rule ()
278
- fuse_pad_into_conv_integer = FusePadConvInteger .rule ()
327
+ fuse_pad_into_conv = FuseConvPad .rule ()
328
+ fuse_pad_into_conv_integer = FuseConvIntegerPad .rule ()
279
329
280
330
281
331
def fuse_pad_into_conv_rule_set () -> orp .RewriteRuleSet :
0 commit comments