Skip to content

Commit 0098323

Browse files
committed
[Rewriter] improve NormalizePadFormat test (#2301)
Fix silent bugs
1 parent 1c81147 commit 0098323

File tree

2 files changed

+120
-30
lines changed

2 files changed

+120
-30
lines changed

onnxscript/rewriter/fuse_pad_into_conv.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,32 @@ def check(self, context, conv: ir.Value, **__) -> orp.MatchResult:
195195

196196
# Conv constraints: attributes
197197
conv_node = conv.producer()
198-
auto_pad = conv_node.attributes.get("auto_pad", None)
199-
if auto_pad is None or auto_pad.as_string() == "NOTSET":
198+
auto_pad = conv_node.attributes.get_string("auto_pad", None)
199+
if auto_pad in [None, "NOTSET"]:
200200
return check_result.fail(
201201
f"{conv_node.name} auto_pad must be different to 'NOTSET'."
202202
)
203203

204204
# Conv constraints: inputs/outputs
205-
if conv_node.inputs[0].shape is None:
205+
input_shape = conv_node.inputs[0].shape
206+
output_shape = conv_node.outputs[0].shape
207+
if len(input_shape) <= 2:
206208
return check_result.fail(f"Input shapes are not defined on {conv_node.name}.")
207-
if conv_node.outputs[0].shape is None:
209+
if len(output_shape) <= 2:
208210
return check_result.fail(f"Output shapes are not defined on {conv_node.name}.")
211+
212+
# Conv constraints: values
213+
if auto_pad != "VALID":
214+
error_msg = "Expected static spatial {} shapes on " + conv_node.name + "."
215+
if not all(isinstance(x, int) for x in input_shape[2:]):
216+
return check_result.fail(error_msg.format("input"))
217+
if not all(isinstance(x, int) for x in output_shape[2:]):
218+
return check_result.fail(error_msg.format("output"))
219+
attributes = read_conv_attributes(conv_node)
220+
if len(attributes["kernel_shape"]) != len(attributes["strides"]):
221+
return check_result.fail(
222+
f"strides must have the same length than kernel_shape on {conv_node.name}."
223+
)
209224
return check_result
210225

211226

@@ -220,10 +235,12 @@ def compute_pads(
220235
) -> Sequence[int]:
221236
# Compute pads, following auto_pad/pads attributes
222237
if attributes["auto_pad"] in ["NOTSET", "VALID"]:
238+
assert len(input_shape) > 0
223239
return attributes.get("pads", [0] * len(input_shape) * 2)
224240

225241
bottom_pads, top_pads = [], []
226242
kernel_shape, strides = attributes["kernel_shape"], attributes["strides"]
243+
assert len(kernel_shape) == len(strides) == len(input_shape) == len(output_shape)
227244
for x, y, k, s in zip(input_shape, output_shape, kernel_shape, strides):
228245
# Compute the output shape and the total padding to apply
229246
total_pads = max(0, (y - 1) * s + k - x)

onnxscript/rewriter/fuse_pad_into_conv_test.py

Lines changed: 99 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -268,24 +268,73 @@ def test_fuse_pad_into_conv_integer(
268268

269269

270270
class NormalizePadFormatTest(FusePadConvBaseTest):
271+
def build_model(
272+
self,
273+
input_shape: ir.Shape,
274+
conv_inputs: Sequence[int],
275+
conv_attributes: Mapping[str, ir.Attr] | None = None,
276+
infer_shapes=True,
277+
) -> ir.Model:
278+
tape = ir.tape.Tape()
279+
inputs = []
280+
output_shape = ir.Shape(("?",) * len(input_shape))
281+
282+
# Convert conv_inputs to initializers (if needed)
283+
conv_inputs = list(conv_inputs)
284+
for idx, x in enumerate(conv_inputs):
285+
if isinstance(x, ir.TensorProtocol):
286+
conv_inputs[idx] = tape.initializer(x)
287+
elif isinstance(x, ir.Value):
288+
inputs.append(x)
289+
elif x is not None:
290+
raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")
291+
292+
# Register operations in the tape
293+
x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
294+
y = tape.op(
295+
"Conv",
296+
inputs=[x, *conv_inputs],
297+
attributes=conv_attributes,
298+
output=ir.Input("Y", shape=output_shape, type=x.type),
299+
)
300+
301+
# Build the model
302+
ir_model = ir.Model(
303+
ir.Graph(
304+
inputs=[x, *inputs],
305+
outputs=[y],
306+
nodes=tape.nodes,
307+
initializers=tape.initializers,
308+
opset_imports={"": 20},
309+
name="model",
310+
),
311+
ir_version=10,
312+
)
313+
if len(input_shape) > 0 and infer_shapes:
314+
onnx_checker.CheckerPass(True)(ir_model)
315+
ir_model = shape_inference.infer_shapes(ir_model)
316+
else:
317+
onnx_checker.CheckerPass(False)(ir_model)
318+
return ir_model
319+
271320
@parameterized.parameterized.expand(
272321
[
273-
(strides, kernel_shape, auto_pad)
322+
(dynamic_shape, strides, kernel_shape, auto_pad)
274323
for strides, kernel_shape in [((2, 3), (1, 4)), ((2, 1), (5, 2))]
275-
for auto_pad in ["SAME_UPPER", "SAME_LOWER", "VALID"]
324+
for dynamic_shape, auto_pad in [
325+
(False, "SAME_UPPER"),
326+
(False, "SAME_LOWER"),
327+
(True, "VALID"),
328+
]
276329
]
277330
)
278-
def test_normalize_pad_format(self, strides, kernel_shape, auto_pad):
279-
pad_inputs = [
280-
ir.tensor([1, 1, 1, 1], name="pads"),
281-
None,
282-
ir.tensor([2, 3], name="axes"),
283-
]
331+
def test_normalize_pad_format(self, dynamic_shape, strides, kernel_shape, auto_pad):
332+
input_shape = (
333+
ir.Shape(("N", "A", "B", "C")) if dynamic_shape else ir.Shape(("N", 32, 22, 27))
334+
)
284335
base_model = self.build_model(
285-
op_type="Conv",
286-
input_shape=ir.Shape(("N", 32, 22, 27)),
287-
weight_shape=(32, 32, *kernel_shape),
288-
pad_inputs=pad_inputs,
336+
input_shape=input_shape,
337+
conv_inputs=[ir.tensor(self.get_conv_weights((32, 32, *kernel_shape)), name="W")],
289338
conv_attributes={
290339
"strides": strides,
291340
"auto_pad": auto_pad,
@@ -296,27 +345,51 @@ def test_normalize_pad_format(self, strides, kernel_shape, auto_pad):
296345

297346
# Apply rule
298347
count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)
299-
300-
# Check that Pad was fused
301-
self.assertEqual(count, 2)
302-
self.assertEqual(updated_model.graph.num_nodes(), 1)
303348
onnx_checker.CheckerPass(True)(updated_model)
304349

350+
# Check conv has changed
351+
self.assertEqual(count, 1)
352+
self.assertEqual(updated_model.graph[0].attributes.get_string("auto_pad"), "NOTSET")
353+
305354
# Check inference
306355
inputs = self.rng.random((1, 32, 22, 27), dtype="float32")
307356
testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
308357

309-
def test_unsupported_normalize_pad_format(self):
358+
@parameterized.parameterized.expand(
359+
[
360+
(ir.Shape([]), False, "Input shapes are not defined"),
361+
(ir.Shape(("N", "C", "A")), False, "Expected static spatial input shapes"),
362+
(ir.Shape(("N", "C", 32)), False, "Expected static spatial output shapes"),
363+
]
364+
)
365+
def test_unsupported_normalize_pad_format(self, input_shape, infer_shapes, error_msg):
310366
base_model = self.build_model(
311-
op_type="Conv",
312-
input_shape=ir.Shape(("N", 32, 14)),
313-
weight_shape=(32, 11, 4),
314-
pad_inputs=[ir.tensor([0, 0, 0, 0, 0, 0], name="pads")],
315-
conv_attributes={"auto_pad": "VALID"},
367+
input_shape=input_shape,
368+
conv_inputs=[ir.tensor(np.ones((32, 11, 4)), name="W")],
369+
conv_attributes={"auto_pad": "SAME_UPPER"},
370+
infer_shapes=infer_shapes,
371+
)
372+
373+
# Apply rule and check it was not applied
374+
tracer = orp.MatchingTracer()
375+
count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer)
376+
self.assertEqual(count, 0)
377+
378+
# Check that the error message is the expected one
379+
tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0]
380+
self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
381+
self.assertRegex(tracer_match.match_result.reason, error_msg)
382+
383+
def test_unsupported_normalize_pad_format_on_weights(self):
384+
W = ir.Value(name="W", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.FLOAT))
385+
base_model = self.build_model(
386+
input_shape=ir.Shape(("N", 2, 32)),
387+
conv_inputs=[W],
388+
conv_attributes={"auto_pad": "SAME_UPPER"},
389+
infer_shapes=False,
316390
)
317-
# Drop convolutional input shape
318-
base_model.graph[0].outputs[0].shape = None
319-
onnx_checker.CheckerPass(True)(base_model)
391+
# Set output shape to analyze error due to weights
392+
base_model.graph[0].outputs[0].shape = ir.Shape(("N", 10, 32))
320393

321394
# Apply rule and check it was not applied
322395
tracer = orp.MatchingTracer()
@@ -326,7 +399,7 @@ def test_unsupported_normalize_pad_format(self):
326399
# Check that the error message is the expected one
327400
tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0]
328401
self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
329-
self.assertRegex(tracer_match.match_result.reason, "Input shapes are not defined")
402+
self.assertRegex(tracer_match.match_result.reason, "same length than kernel_shape")
330403

331404

332405
if __name__ == "__main__":

0 commit comments

Comments (0)