
Commit 7f7d17e

[Rewriter] improve message and code (#2301)
1 parent 0098323 commit 7f7d17e

2 files changed: +84 -21 lines

onnxscript/rewriter/fuse_pad_into_conv.py

Lines changed: 80 additions & 17 deletions
@@ -18,6 +18,20 @@


 def fill_pads_with_axes(pads: Sequence[int], axes: Sequence[int], rank: int) -> List[int]:
+    """Converts the parameters of the ONNX Pad operator into an explicit list of values.
+
+    A filled list of pads will be returned following the format:
+    [x1_begin, x2_begin, ..., x{rank}_begin, x1_end, x2_end, ..., x{rank}_end]
+
+    Args:
+        pads: list of integers indicating the number of padding elements to add at
+            the beginning and end of each axis.
+        axes: list of axes that pads apply to.
+        rank: value to compute the size of the filled list (2 * rank).
+
+    Returns:
+        The filled list of pads.
+    """
     new_pads = [0] * 2 * rank
     N = len(axes)
     for start_idx, axis in enumerate(axes):
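
The new docstring pins down the layout that fill_pads_with_axes produces. As a minimal sketch (not the repository's exact implementation, and the helper name below is only illustrative), the conversion it describes behaves like this:

from typing import List, Sequence

def fill_pads_sketch(pads: Sequence[int], axes: Sequence[int], rank: int) -> List[int]:
    # ONNX Pad lists pads as [axis_1_begin, ..., axis_N_begin, axis_1_end, ..., axis_N_end]
    # for the given axes only; expand that to every axis of a rank-`rank` tensor.
    new_pads = [0] * 2 * rank
    n = len(axes)
    for i, axis in enumerate(axes):
        axis = axis % rank  # negative axes count from the end
        new_pads[axis] = pads[i]
        new_pads[axis + rank] = pads[i + n]
    return new_pads

# Pad 1 element before and 2 after on the H and W axes of an NCHW tensor (rank 4):
assert fill_pads_sketch([1, 1, 2, 2], axes=[2, 3], rank=4) == [0, 0, 1, 1, 0, 0, 2, 2]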
@@ -42,11 +56,13 @@ def read_conv_attributes(ir_conv: ir.Node) -> dict[str, Sequence[int] | str]:
     return attributes


-class _FusePadConvBase(orp.RewriteRuleClassBase):
+class _FuseConvPadBase(orp.RewriteRuleClassBase):
     """Interface for PadConv nodes fusion."""

     def __init__(self, as_function: bool = False):
-        # Remove nodes is set to False to remove unused nodes after the rewrite.
+        # Remove nodes is set to False to remove unused nodes after the rewrite, since
+        # Pad or Conv inputs can come from constant nodes.
+        # With remove_nodes=False these nodes are removed if these nodes are no longer needed.
         super().__init__(remove_nodes=False, as_function=as_function)

     def rewrite(
@@ -84,14 +100,32 @@ def rewrite(
         )

     def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult:
+        """Condition to check if we need to replace the pattern.
+
+        If Pad inputs can be added in 'pads' attribute of the Conv operator.
+
+        To validate this, we need to check the following:
+        1. `Pad<mode>` attribute has 'constant' as value
+        2. `Pad` operator inputs are constants ('pads', 'constant_value', 'axes')
+        3. 'constant_value' is equal to 0.0.
+        4. `Pad` operator is only used for the spatial dimensions (batch dimension and channels
+           remain unchanged).
+
+        If the above are true, then we don't need the reshapes.
+
+        Returns:
+            True if we need to replace the pattern, False otherwise.
+        """
         del context  # Unused
         check_result = orp.MatchResult()
         pad_node = pad.producer()
         x_rank = len(x.shape)

         # Pad constraints: attributes
         if (mode := pad_node.attributes.get("mode", None)) and mode.as_string() != "constant":
-            return check_result.fail(f"{pad_node.name} mode must be 'constant'.")
+            return check_result.fail(
+                f"{pad_node.name} ({pad_node.op_type}) mode must be 'constant'."
+            )

         # Pad constraints: inputs
         if (pads := pad_node.inputs[1]).const_value is None:
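
Once those four conditions hold, the constant zero padding can be folded into the Conv's spatial 'pads'. A hedged illustration of that folding (the helper below is hypothetical, not the rewriter's own code):

def merge_pad_into_conv_pads(filled_pad_pads, conv_pads, rank):
    # filled_pad_pads covers all `rank` axes ([x1_begin, ..., x{rank}_begin, x1_end, ..., x{rank}_end]);
    # conv_pads covers spatial axes only, in the Conv layout [d1_begin, d2_begin, ..., d1_end, d2_end, ...].
    spatial_begin = filled_pad_pads[2:rank]      # skip batch and channel begins
    spatial_end = filled_pad_pads[rank + 2:]     # skip batch and channel ends
    return [p + c for p, c in zip(spatial_begin + spatial_end, conv_pads)]

# Pad adds 1 before and 2 after H/W; the Conv already pads 1 everywhere:
assert merge_pad_into_conv_pads([0, 0, 1, 1, 0, 0, 2, 2], [1, 1, 1, 1], rank=4) == [2, 2, 3, 3]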
@@ -118,8 +152,8 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
         return check_result


-class FusePadConv(_FusePadConvBase):
-    """Replaces ``Pad(Conv(x))`` with ``Conv(x)``."""
+class FuseConvPad(_FuseConvPadBase):
+    """Replaces ``Conv(Pad(x))`` with ``Conv(x)``."""

     def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
         return op.Conv(
@@ -138,12 +172,14 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
         if (
             apad := conv_node.attributes.get("auto_pad", None)
         ) and apad.as_string() != "NOTSET":
-            return check_result.fail(f"{conv_node.name} auto_pad must be 'NOTSET'.")
+            return check_result.fail(
+                f"{conv_node.name} ({conv_node.op_type}) auto_pad must be 'NOTSET'."
+            )
         return check_result


-class FusePadConvInteger(FusePadConv):
-    """Replaces ``Pad(ConvInteger(x))`` with ``ConvInteger(x)``."""
+class FuseConvIntegerPad(FuseConvPad):
+    """Replaces ``ConvInteger(Pad(x))`` with ``ConvInteger(x)``."""

     def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
         return op.ConvInteger(
@@ -190,36 +226,63 @@ def rewrite(self, op: ir.tape.Tape, conv: ir.Value, **__) -> ir.Value:
         )

     def check(self, context, conv: ir.Value, **__) -> orp.MatchResult:
+        """Condition to check if we need to replace the pattern.
+
+        If it is possible to deduce 'pads'.
+
+        To validate this, we need to check the following:
+        1. `Conv<auto_pad != "NOTSET">` (nothing to do in this case, since 'pads' are
+           already explicit)
+        2. it is possible to deduce the input rank when `Conv<auto_pad == "VALID">`
+        3. When `Conv<auto_pad != "VALID">`:
+            * spatial input/output shapes are static
+            * it is possible to infer `kernel_shape` either from the `Conv` operator attribute
+              or from the kernel input
+
+        If the above are true, then we don't need the reshapes.
+
+        Returns:
+            True if we need to replace the pattern, False otherwise.
+        """
         del context
         check_result = orp.MatchResult()

         # Conv constraints: attributes
         conv_node = conv.producer()
         auto_pad = conv_node.attributes.get_string("auto_pad", None)
-        if auto_pad in [None, "NOTSET"]:
+        if auto_pad in {None, "NOTSET"}:
             return check_result.fail(
-                f"{conv_node.name} auto_pad must be different to 'NOTSET'."
+                f"{conv_node.name} ({conv_node.op_type}) auto_pad must be different to 'NOTSET'."
             )

         # Conv constraints: inputs/outputs
         input_shape = conv_node.inputs[0].shape
         output_shape = conv_node.outputs[0].shape
         if len(input_shape) <= 2:
-            return check_result.fail(f"Input shapes are not defined on {conv_node.name}.")
+            return check_result.fail(
+                f"Input shapes are not defined on {conv_node.name} ({conv_node.op_type})."
+            )
         if len(output_shape) <= 2:
-            return check_result.fail(f"Output shapes are not defined on {conv_node.name}.")
+            return check_result.fail(
+                f"Output shapes are not defined on {conv_node.name} ({conv_node.op_type})."
+            )

         # Conv constraints: values
         if auto_pad != "VALID":
-            error_msg = "Expected static spatial {} shapes on " + conv_node.name + "."
+            error_msg = (
+                "Expected static spatial {} shapes on "
+                + conv_node.name
+                + f" ({conv_node.op_type})."
+            )
             if not all(isinstance(x, int) for x in input_shape[2:]):
                 return check_result.fail(error_msg.format("input"))
             if not all(isinstance(x, int) for x in output_shape[2:]):
                 return check_result.fail(error_msg.format("output"))
             attributes = read_conv_attributes(conv_node)
             if len(attributes["kernel_shape"]) != len(attributes["strides"]):
                 return check_result.fail(
-                    f"strides must have the same length than kernel_shape on {conv_node.name}."
+                    "strides must have the same length than kernel_shape on "
+                    f"{conv_node.name} ({conv_node.op_type})."
                 )
         return check_result

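For the normalize rules, these checks guarantee that explicit pads can be recomputed from static shapes when auto_pad is SAME_UPPER or SAME_LOWER. A hedged sketch of that deduction, following the ONNX auto_pad definition rather than the file's compute_pads internals (dilations assumed to be 1, function name is illustrative):

def same_pads_from_shapes(input_sp, output_sp, kernel, strides, auto_pad):
    # Total padding per spatial axis is (out - 1) * stride + kernel - in (never negative);
    # SAME_UPPER puts the extra element at the end, SAME_LOWER at the beginning.
    begins, ends = [], []
    for x_in, x_out, k, s in zip(input_sp, output_sp, kernel, strides):
        total = max((x_out - 1) * s + k - x_in, 0)
        half = total // 2
        if auto_pad == "SAME_UPPER":
            begins.append(half)
            ends.append(total - half)
        else:  # SAME_LOWER
            begins.append(total - half)
            ends.append(half)
    return begins + ends

# 5x5 input, 3x3 kernel, stride 2 -> 3x3 output; one pad element on every side:
assert same_pads_from_shapes([5, 5], [3, 3], [3, 3], [2, 2], "SAME_UPPER") == [1, 1, 1, 1]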
@@ -234,7 +297,7 @@ def compute_pads(
     attributes: dict[str, Sequence[int] | str],
 ) -> Sequence[int]:
     # Compute pads, following auto_pad/pads attributes
-    if attributes["auto_pad"] in ["NOTSET", "VALID"]:
+    if attributes["auto_pad"] in {"NOTSET", "VALID"}:
         assert len(input_shape) > 0
         return attributes.get("pads", [0] * len(input_shape) * 2)

@@ -269,8 +332,8 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:

 normalize_pad_format_conv = NormalizePadFormatConv.rule()
 normalize_pad_format_conv_integer = NormalizePadFormatConvInteger.rule()
-fuse_pad_into_conv = FusePadConv.rule()
-fuse_pad_into_conv_integer = FusePadConvInteger.rule()
+fuse_pad_into_conv = FuseConvPad.rule()
+fuse_pad_into_conv_integer = FuseConvIntegerPad.rule()


 def fuse_pad_into_conv_rule_set() -> orp.RewriteRuleSet:
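
After the renames, the exported names (fuse_pad_into_conv, fuse_pad_into_conv_integer, fuse_pad_into_conv_rule_set) are unchanged, so existing callers keep working. A hedged usage sketch, not part of the commit: ir.load/ir.save and RewriteRuleSet.apply_to_model are assumed entry points here, and "model.onnx" is a placeholder path.

from onnxscript import ir
from onnxscript.rewriter.fuse_pad_into_conv import fuse_pad_into_conv_rule_set

model = ir.load("model.onnx")  # placeholder model path
applied = fuse_pad_into_conv_rule_set().apply_to_model(model)
print(f"rewrites applied: {applied}")
ir.save(model, "model_fused.onnx")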

onnxscript/rewriter/fuse_pad_into_conv_test.py

Lines changed: 4 additions & 4 deletions
@@ -23,7 +23,7 @@ def _clone_model(model: ir.Model) -> ir.Model:
     return ir.from_proto(ir.to_proto(model))


-class FusePadConvBaseTest(unittest.TestCase):
+class FuseConvPadBaseTest(unittest.TestCase):
     @property
     def rng(self):
         return np.random.default_rng(20250522)
@@ -89,7 +89,7 @@ def build_model(
         return ir_model


-class FusePadConvTest(FusePadConvBaseTest):
+class FuseConvPadTest(FuseConvPadBaseTest):
     @parameterized.parameterized.expand(
         [
             (pad_pads, const_value, axes, conv_pads, conv_auto_pad)
@@ -218,7 +218,7 @@ def test_unsupported_fuse_pad_into_conv(
         self.assertRegex(tracer_match.match_result.reason, err_msg)


-class FusePadConvIntegerTest(FusePadConvBaseTest):
+class FuseConvIntegerPadTest(FuseConvPadBaseTest):
     def get_conv_weights(self, shape: Sequence[int], tape: ir.tape.Tape = None):
         w = ir.tensor(self.rng.integers(0, 256, shape).astype("uint8"), name="W")
         if tape is not None:
@@ -267,7 +267,7 @@ def test_fuse_pad_into_conv_integer(
         testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)


-class NormalizePadFormatTest(FusePadConvBaseTest):
+class NormalizePadFormatTest(FuseConvPadBaseTest):
     def build_model(
         self,
         input_shape: ir.Shape,
