@@ -268,24 +268,73 @@ def test_fuse_pad_into_conv_integer(
 
 
 class NormalizePadFormatTest(FusePadConvBaseTest):
+    def build_model(
+        self,
+        input_shape: ir.Shape,
+        conv_inputs: Sequence[int],
+        conv_attributes: Mapping[str, ir.Attr] | None = None,
+        infer_shapes=True,
+    ) -> ir.Model:
+        tape = ir.tape.Tape()
+        inputs = []
+        output_shape = ir.Shape(("?",) * len(input_shape))
+
+        # Convert conv_inputs to initializers (if needed)
+        conv_inputs = list(conv_inputs)
+        for idx, x in enumerate(conv_inputs):
+            if isinstance(x, ir.TensorProtocol):
+                conv_inputs[idx] = tape.initializer(x)
+            elif isinstance(x, ir.Value):
+                inputs.append(x)
+            elif x is not None:
+                raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")
+
+        # Register operations in the tape
+        x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
+        y = tape.op(
+            "Conv",
+            inputs=[x, *conv_inputs],
+            attributes=conv_attributes,
+            output=ir.Input("Y", shape=output_shape, type=x.type),
+        )
+
+        # Build the model
+        ir_model = ir.Model(
+            ir.Graph(
+                inputs=[x, *inputs],
+                outputs=[y],
+                nodes=tape.nodes,
+                initializers=tape.initializers,
+                opset_imports={"": 20},
+                name="model",
+            ),
+            ir_version=10,
+        )
+        if len(input_shape) > 0 and infer_shapes:
+            onnx_checker.CheckerPass(True)(ir_model)
+            ir_model = shape_inference.infer_shapes(ir_model)
+        else:
+            onnx_checker.CheckerPass(False)(ir_model)
+        return ir_model
+
     @parameterized.parameterized.expand(
         [
-            (strides, kernel_shape, auto_pad)
+            (dynamic_shape, strides, kernel_shape, auto_pad)
             for strides, kernel_shape in [((2, 3), (1, 4)), ((2, 1), (5, 2))]
-            for auto_pad in ["SAME_UPPER", "SAME_LOWER", "VALID"]
+            for dynamic_shape, auto_pad in [
+                (False, "SAME_UPPER"),
+                (False, "SAME_LOWER"),
+                (True, "VALID"),
+            ]
         ]
     )
-    def test_normalize_pad_format(self, strides, kernel_shape, auto_pad):
-        pad_inputs = [
-            ir.tensor([1, 1, 1, 1], name="pads"),
-            None,
-            ir.tensor([2, 3], name="axes"),
-        ]
+    def test_normalize_pad_format(self, dynamic_shape, strides, kernel_shape, auto_pad):
+        input_shape = (
+            ir.Shape(("N", "A", "B", "C")) if dynamic_shape else ir.Shape(("N", 32, 22, 27))
+        )
         base_model = self.build_model(
-            op_type="Conv",
-            input_shape=ir.Shape(("N", 32, 22, 27)),
-            weight_shape=(32, 32, *kernel_shape),
-            pad_inputs=pad_inputs,
+            input_shape=input_shape,
+            conv_inputs=[ir.tensor(self.get_conv_weights((32, 32, *kernel_shape)), name="W")],
             conv_attributes={
                 "strides": strides,
                 "auto_pad": auto_pad,
@@ -296,27 +345,51 @@ def test_normalize_pad_format(self, strides, kernel_shape, auto_pad):
 
         # Apply rule
         count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)
-
-        # Check that Pad was fused
-        self.assertEqual(count, 2)
-        self.assertEqual(updated_model.graph.num_nodes(), 1)
         onnx_checker.CheckerPass(True)(updated_model)
 
+        # Check conv has changed
+        self.assertEqual(count, 1)
+        self.assertEqual(updated_model.graph[0].attributes.get_string("auto_pad"), "NOTSET")
+
         # Check inference
         inputs = self.rng.random((1, 32, 22, 27), dtype="float32")
         testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
 
-    def test_unsupported_normalize_pad_format(self):
+    @parameterized.parameterized.expand(
+        [
+            (ir.Shape([]), False, "Input shapes are not defined"),
+            (ir.Shape(("N", "C", "A")), False, "Expected static spatial input shapes"),
+            (ir.Shape(("N", "C", 32)), False, "Expected static spatial output shapes"),
+        ]
+    )
+    def test_unsupported_normalize_pad_format(self, input_shape, infer_shapes, error_msg):
         base_model = self.build_model(
-            op_type="Conv",
-            input_shape=ir.Shape(("N", 32, 14)),
-            weight_shape=(32, 11, 4),
-            pad_inputs=[ir.tensor([0, 0, 0, 0, 0, 0], name="pads")],
-            conv_attributes={"auto_pad": "VALID"},
+            input_shape=input_shape,
+            conv_inputs=[ir.tensor(np.ones((32, 11, 4)), name="W")],
+            conv_attributes={"auto_pad": "SAME_UPPER"},
+            infer_shapes=infer_shapes,
+        )
+
+        # Apply rule and check it was not applied
+        tracer = orp.MatchingTracer()
+        count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer)
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, error_msg)
+
+    def test_unsupported_normalize_pad_format_on_weights(self):
+        W = ir.Value(name="W", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.FLOAT))
+        base_model = self.build_model(
+            input_shape=ir.Shape(("N", 2, 32)),
+            conv_inputs=[W],
+            conv_attributes={"auto_pad": "SAME_UPPER"},
+            infer_shapes=False,
         )
-        # Drop convolutional input shape
-        base_model.graph[0].outputs[0].shape = None
-        onnx_checker.CheckerPass(True)(base_model)
+        # Set output shape to analyze error due to weights
+        base_model.graph[0].outputs[0].shape = ir.Shape(("N", 10, 32))
 
         # Apply rule and check it was not applied
         tracer = orp.MatchingTracer()
@@ -326,7 +399,7 @@ def test_unsupported_normalize_pad_format(self):
         # Check that the error message is the expected one
         tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0]
         self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
-        self.assertRegex(tracer_match.match_result.reason, "Input shapes are not defined")
+        self.assertRegex(tracer_match.match_result.reason, "same length than kernel_shape")
 
 
 if __name__ == "__main__":