@@ -68,17 +68,22 @@ def test_scripted_vs_eager(self, info, args_kwargs, device):

         assert_close(actual, expected, **info.closeness_kwargs)

-    def _unbind_batch_dims(self, batched_tensor, *, data_dims):
-        if batched_tensor.ndim == data_dims:
-            return batched_tensor
-
-        return [self._unbind_batch_dims(t, data_dims=data_dims) for t in batched_tensor.unbind(0)]
+    def _unbatch(self, batch, *, data_dims):
+        if isinstance(batch, torch.Tensor):
+            batched_tensor = batch
+            metadata = ()
+        else:
+            batched_tensor, *metadata = batch

-    def _stack_batch_dims(self, unbound_tensor):
-        if isinstance(unbound_tensor[0], torch.Tensor):
-            return torch.stack(unbound_tensor)
+        if batched_tensor.ndim == data_dims:
+            return batch

-        return torch.stack([self._stack_batch_dims(t) for t in unbound_tensor])
+        return [
+            self._unbatch(unbatched, data_dims=data_dims)
+            for unbatched in (
+                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+            )
+        ]

     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
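
Aside: a minimal sketch of what the new _unbatch yields once a kernel output carries metadata next to the tensor. The shapes and the image size below are invented for illustration:

    import torch

    boxes = torch.rand(2, 4)  # hypothetical batch of two boxes
    batched_output = (boxes, (32, 38))  # (tensor, image_size) pair

    # With data_dims=1, _unbatch unbinds the single batch dim and re-attaches
    # the metadata to every element:
    #     [(boxes[0], (32, 38)), (boxes[1], (32, 38))]
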
@@ -106,11 +111,11 @@ def test_batched_vs_single(self, info, args_kwargs, device):
         elif not all(batched_input.shape[:-data_dims]):
             pytest.skip("Input has a degenerate batch shape.")

-        actual = info.kernel(batched_input, *other_args, **kwargs)
+        batched_output = info.kernel(batched_input, *other_args, **kwargs)
+        actual = self._unbatch(batched_output, data_dims=data_dims)

-        single_inputs = self._unbind_batch_dims(batched_input, data_dims=data_dims)
-        single_outputs = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)
-        expected = self._stack_batch_dims(single_outputs)
+        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
+        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

         assert_close(actual, expected, **info.closeness_kwargs)

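
Because _unbatch returns plain nested lists, tree_map can apply the kernel to every single-sample leaf, and expected comes out with the same nested structure as actual, so the old re-stacking pass is no longer needed. A quick illustration of that leaf-wise mapping, assuming tree_map is the pytree helper from the private torch.utils._pytree module:

    from torch.utils._pytree import tree_map

    nested = [[1, 2], [3]]
    # The container structure is preserved; only the leaves are transformed.
    assert tree_map(lambda x: x * 2, nested) == [[2, 4], [6]]
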
@@ -123,9 +128,9 @@ def test_no_inplace(self, info, args_kwargs, device):
             pytest.skip("The input has a degenerate shape.")

         input_version = input._version
-        output = info.kernel(input, *other_args, **kwargs)
+        info.kernel(input, *other_args, **kwargs)

-        assert output is not input or output._version == input_version
+        assert input._version == input_version

     @sample_inputs
     @needs_cuda
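
The tightened assertion leans on PyTorch's version counter: every in-place mutation bumps a tensor's _version, so an unchanged counter proves the kernel left its input untouched. The old check could be satisfied by a kernel that returned a fresh tensor while still mutating the input. A minimal sketch of the counter's behavior:

    import torch

    t = torch.zeros(3)
    before = t._version
    t.add_(1)  # in-place ops increment the internal version counter
    assert t._version == before + 1
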
@@ -144,6 +149,9 @@ def test_dtype_and_device_consistency(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)

         output = info.kernel(input, *other_args, **kwargs)
+        # Most kernels just return a tensor, but some also return some additional metadata
+        if not isinstance(output, torch.Tensor):
+            output, *_ = output

         assert output.dtype == input.dtype
         assert output.device == input.device
@@ -324,7 +332,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
         affine_matrix = affine_matrix[:2, :]

-        image_size = bbox.image_size
+        height, width = bbox.image_size
         bbox_xyxy = convert_format_bounding_box(
             bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
         )
@@ -336,9 +344,9 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
                 [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
                 # image frame
                 [0.0, 0.0, 1.0],
-                [0.0, image_size[0], 1.0],
-                [image_size[1], image_size[0], 1.0],
-                [image_size[1], 0.0, 1.0],
+                [0.0, height, 1.0],
+                [width, height, 1.0],
+                [width, 0.0, 1.0],
             ]
         )
         transformed_points = np.matmul(points, affine_matrix.T)
@@ -356,18 +364,21 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         out_bbox[2] -= tr_x
         out_bbox[3] -= tr_y

-        # image_size should be updated, but it is OK here to skip its computation
-        # as we do not compute it in F.rotate_bounding_box
+        height = int(height - 2 * tr_y)
+        width = int(width - 2 * tr_x)

         out_bbox = features.BoundingBox(
             out_bbox,
             format=features.BoundingBoxFormat.XYXY,
-            image_size=image_size,
+            image_size=(height, width),
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(
-            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+        return (
+            convert_format_bounding_box(
+                out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+            ),
+            (height, width),
         )

     image_size = (32, 38)
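
Why height - 2 * tr_y is the expanded height: the rotation is about the image center, so the transformed frame corners sit symmetrically around it, min_y + max_y == height, and the expanded extent max_y - min_y reduces to height - 2 * tr_y (assuming tr_x/tr_y are the minima of the transformed corner coordinates, which is what the unshown preceding lines of this helper suggest). A worked instance with invented numbers:

    # Rotate a height=32, width=38 frame by 90 degrees about its center:
    # corner (0, 0) maps to (35, -3) and corner (38, 32) maps to (3, 35),
    # so tr_y = -3 and the expanded height is 32 - 2 * (-3) = 38 == old width.
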
@@ -376,7 +387,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         bboxes_format = bboxes.format
         bboxes_image_size = bboxes.image_size

-        output_bboxes = F.rotate_bounding_box(
+        output_bboxes, output_image_size = F.rotate_bounding_box(
             bboxes,
             bboxes_format,
             image_size=bboxes_image_size,
@@ -395,12 +406,14 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         expected_bboxes = []
         for bbox in bboxes:
             bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_))
+            expected_bbox, expected_image_size = _compute_expected_bbox(bbox, -angle, expand, center_)
+            expected_bboxes.append(expected_bbox)
         if len(expected_bboxes) > 1:
             expected_bboxes = torch.stack(expected_bboxes)
         else:
             expected_bboxes = expected_bboxes[0]
         torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0)
+        torch.testing.assert_close(output_image_size, expected_image_size, atol=1, rtol=0)


 @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -445,7 +458,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
         [18.36396103, 1.07968978, 46.64823228, 29.36396103],
     ]

-    output_boxes = F.rotate_bounding_box(
+    output_boxes, _ = F.rotate_bounding_box(
         in_boxes,
         in_boxes.format,
         in_boxes.image_size,
@@ -510,17 +523,20 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     if format != features.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)

-    output_boxes = F.crop_bounding_box(
+    output_boxes, output_image_size = F.crop_bounding_box(
         in_boxes,
         format,
         top,
         left,
+        size[0],
+        size[1],
     )

     if format != features.BoundingBoxFormat.XYXY:
         output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)

     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
+    torch.testing.assert_close(output_image_size, size)


 @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -585,12 +601,13 @@ def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
     if format != features.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)

-    output_boxes = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
+    output_boxes, output_image_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)

     if format != features.BoundingBoxFormat.XYXY:
         output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)

     torch.testing.assert_close(output_boxes, expected_bboxes)
+    torch.testing.assert_close(output_image_size, size)


 def _parse_padding(padding):
@@ -627,12 +644,21 @@ def _compute_expected_bbox(bbox, padding_):
             bbox = bbox.to(bbox_dtype)
         return bbox

+    def _compute_expected_image_size(bbox, padding_):
+        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
+        height, width = bbox.image_size
+        return height + pad_up + pad_down, width + pad_left + pad_right
+
     for bboxes in make_bounding_boxes():
         bboxes = bboxes.to(device)
         bboxes_format = bboxes.format
         bboxes_image_size = bboxes.image_size

-        output_boxes = F.pad_bounding_box(bboxes, format=bboxes_format, padding=padding)
+        output_boxes, output_image_size = F.pad_bounding_box(
+            bboxes, format=bboxes_format, image_size=bboxes_image_size, padding=padding
+        )
+
+        torch.testing.assert_close(output_image_size, _compute_expected_image_size(bboxes, padding))

         if bboxes.ndim < 2 or bboxes.shape[0] == 0:
             bboxes = [bboxes]
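
A quick sanity check of _compute_expected_image_size, assuming _parse_padding normalizes to the (left, up, right, down) order its unpacking implies:

    height, width = 32, 38
    pad_left, pad_up, pad_right, pad_down = 1, 2, 3, 4  # hypothetical padding
    assert (height + pad_up + pad_down, width + pad_left + pad_right) == (38, 42)
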
@@ -781,7 +807,9 @@ def _compute_expected_bbox(bbox, output_size_):
         bboxes_format = bboxes.format
         bboxes_image_size = bboxes.image_size

-        output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, bboxes_image_size, output_size)
+        output_boxes, output_image_size = F.center_crop_bounding_box(
+            bboxes, bboxes_format, bboxes_image_size, output_size
+        )

         if bboxes.ndim < 2:
             bboxes = [bboxes]
@@ -796,6 +824,7 @@ def _compute_expected_bbox(bbox, output_size_):
         else:
             expected_bboxes = expected_bboxes[0]
         torch.testing.assert_close(output_boxes, expected_bboxes)
+        torch.testing.assert_close(output_image_size, output_size)


 @pytest.mark.parametrize("device", cpu_and_gpu())