@@ -95,7 +95,7 @@ def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch
         cx = torch.randint(1, width - 1, ())
         cy = torch.randint(1, height - 1, ())
         w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
-        h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1)
+        h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
         parts = (cx, cy, w, h)
     else:
         raise pytest.UsageError()
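
The fix in this hunk swaps `width` for `height` when bounding the sampled box height: with center `cy`, a box of height `h` spans `[cy - h/2, cy + h/2]`, which stays inside the image only when `h` is small enough on both sides of `cy`; the old `width - cy` bound let `h` overflow on images wider than they are tall. A minimal standalone sketch of the invariant (the names mirror the diff, but the snippet is illustrative and not part of it):

```python
import torch

height, width = 24, 32  # non-square on purpose: `width - cy` would be a wrong bound here
cy = int(torch.randint(1, height - 1, ()))
# sample h in [1, min(cy, height - cy)], so both half-extents h/2 fit vertically
h = int(torch.randint(1, min(cy, height - cy) + 1, ()))
assert 0 <= cy - h / 2 and cy + h / 2 <= height
```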
@@ -413,6 +413,14 @@ def perspective_segmentation_mask():
     )


+@register_kernel_info_from_sample_inputs_fn
+def center_crop_bounding_box():
+    for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
+        yield SampleInput(
+            bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
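
For context on the hunk above (not part of the diff): functions decorated with `register_kernel_info_from_sample_inputs_fn` feed the parametrized kernel smoke tests that follow, with each yielded `SampleInput` replayed as one call to `F.center_crop_bounding_box`. A hedged sketch of that replay, assuming the decorator returns the generator unchanged and that `SampleInput` stores positional args and kwargs as in this file's helper:

```python
# illustrative only: replay the registered samples against the kernel
for sample in center_crop_bounding_box():
    F.center_crop_bounding_box(*sample.args, **sample.kwargs)
```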
@@ -1273,3 +1281,59 @@ def _compute_expected_mask(mask, pcoeffs_):
         else:
             expected_masks = expected_masks[0]
         torch.testing.assert_close(output_mask, expected_masks)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "output_size",
+    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
+)
+def test_correctness_center_crop_bounding_box(device, output_size):
+    def _compute_expected_bbox(bbox, output_size_):
+        format_ = bbox.format
+        image_size_ = bbox.image_size
+        bbox = convert_bounding_box_format(bbox, format_, features.BoundingBoxFormat.XYWH)
+
+        if len(output_size_) == 1:
+            output_size_.append(output_size_[-1])
+
+        cy = int(round((image_size_[0] - output_size_[0]) * 0.5))
+        cx = int(round((image_size_[1] - output_size_[1]) * 0.5))
+        out_bbox = [
+            bbox[0].item() - cx,
+            bbox[1].item() - cy,
+            bbox[2].item(),
+            bbox[3].item(),
+        ]
+        out_bbox = features.BoundingBox(
+            out_bbox,
+            format=features.BoundingBoxFormat.XYWH,
+            image_size=output_size_,
+            dtype=bbox.dtype,
+            device=bbox.device,
+        )
+        return convert_bounding_box_format(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[(32, 32), (24, 33), (32, 25)],
+        extra_dims=((4,),),
+    ):
+        bboxes = bboxes.to(device)
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, output_size, bboxes_image_size)
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
+
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        torch.testing.assert_close(output_boxes, expected_bboxes)
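
To make the expected-value math above easier to follow: `_compute_expected_bbox` converts to XYWH, shifts the top-left corner by the crop offsets (`cy = round((image_height - crop_height) / 2)`, and `cx` likewise for width), keeps `w`/`h` unchanged, and converts back. A minimal sketch of a kernel consistent with that reference, written for XYXY tensors; this illustrates the behavior under test, not torchvision's actual implementation:

```python
import torch

def center_crop_xyxy(boxes: torch.Tensor, image_size, output_size):
    """boxes: (..., 4) in XYXY; image_size and output_size: (height, width)."""
    top = int(round((image_size[0] - output_size[0]) * 0.5))
    left = int(round((image_size[1] - output_size[1]) * 0.5))
    # a center crop only translates coordinates into the crop's frame
    shift = torch.tensor([left, top, left, top], dtype=boxes.dtype, device=boxes.device)
    return boxes - shift
```

Note that, matching the test's expectation, boxes are translated into the crop's coordinate frame but not clipped to the crop window.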