@@ -266,15 +266,15 @@ def affine_bounding_box():

@register_kernel_info_from_sample_inputs_fn
def affine_segmentation_mask():
-    for image, angle, translate, scale, shear in itertools.product(
+    for mask, angle, translate, scale, shear in itertools.product(
        make_segmentation_masks(extra_dims=((), (4,))),
        [-87, 15, 90],  # angle
        [5, -5],  # translate
        [0.77, 1.27],  # scale
        [0, 12],  # shear
    ):
        yield SampleInput(
-            image,
+            mask,
            angle=angle,
            translate=(translate, translate),
            scale=scale,
@@ -285,8 +285,12 @@ def affine_segmentation_mask():
@register_kernel_info_from_sample_inputs_fn
def rotate_bounding_box():
    for bounding_box, angle, expand, center in itertools.product(
-        make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]]  # angle  # expand  # center
+        make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]]
    ):
+        if center is not None and expand:
+            # Skip warning: The provided center argument is ignored if expand is True
+            continue
+
        yield SampleInput(
            bounding_box,
            format=bounding_box.format,
@@ -297,6 +301,26 @@ def rotate_bounding_box():
        )


+@register_kernel_info_from_sample_inputs_fn
+def rotate_segmentation_mask():
+    for mask, angle, expand, center in itertools.product(
+        make_segmentation_masks(extra_dims=((), (4,))),
+        [-87, 15, 90],  # angle
+        [True, False],  # expand
+        [None, [12, 23]],  # center
+    ):
+        if center is not None and expand:
+            # Skip warning: The provided center argument is ignored if expand is True
+            continue
+
+        yield SampleInput(
+            mask,
+            angle=angle,
+            expand=expand,
+            center=center,
+        )
+
+
@pytest.mark.parametrize(
    "kernel",
    [
@@ -411,8 +435,9 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
            center=center,
        )

-        if center is None:
-            center = [s // 2 for s in bboxes_image_size[::-1]]
+        center_ = center
+        if center_ is None:
+            center_ = [s * 0.5 for s in bboxes_image_size[::-1]]

        if bboxes.ndim < 2:
            bboxes = [bboxes]
@@ -421,7 +446,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
        for bbox in bboxes:
            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
            expected_bboxes.append(
-                _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center)
+                _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center_)
            )
        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
@@ -510,8 +535,10 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
            shear=(shear, shear),
            center=center,
        )
-        if center is None:
-            center = [s // 2 for s in mask.shape[-2:][::-1]]
+
+        center_ = center
+        if center_ is None:
+            center_ = [s * 0.5 for s in mask.shape[-2:][::-1]]

        if mask.ndim < 4:
            masks = [mask]
@@ -520,7 +547,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):

        expected_masks = []
        for mask in masks:
-            expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center)
+            expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center_)
            expected_masks.append(expected_mask)
        if len(expected_masks) > 1:
            expected_masks = torch.stack(expected_masks)
@@ -550,8 +577,7 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device):


@pytest.mark.parametrize("angle", range(-90, 90, 56))
-@pytest.mark.parametrize("expand", [True, False])
-@pytest.mark.parametrize("center", [None, (12, 14)])
+@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
def test_correctness_rotate_bounding_box(angle, expand, center):
    def _compute_expected_bbox(bbox, angle_, expand_, center_):
        affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
@@ -620,16 +646,17 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
            center=center,
        )

-        if center is None:
-            center = [s // 2 for s in bboxes_image_size[::-1]]
+        center_ = center
+        if center_ is None:
+            center_ = [s * 0.5 for s in bboxes_image_size[::-1]]

        if bboxes.ndim < 2:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center))
+            expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_))
        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
@@ -638,7 +665,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):


@pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("expand", [False])  # expand=True does not match D2, analysis in progress
+@pytest.mark.parametrize("expand", [False])  # expand=True does not match D2
def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
    # Check transformation against known expected output
    image_size = (64, 64)
@@ -689,3 +716,91 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
    )

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
+
+
+@pytest.mark.parametrize("angle", range(-90, 90, 37))
+@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
+def test_correctness_rotate_segmentation_mask(angle, expand, center):
+    def _compute_expected_mask(mask, angle_, expand_, center_):
+        assert mask.ndim == 3 and mask.shape[0] == 1
+        image_size = mask.shape[-2:]
+        affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
+        inv_affine_matrix = np.linalg.inv(affine_matrix)
+
+        if expand_:
+            # Pillow implementation on how to perform expand:
+            # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054-L2069
+            height, width = image_size
+            points = np.array(
+                [
+                    [0.0, 0.0, 1.0],
+                    [0.0, 1.0 * height, 1.0],
+                    [1.0 * width, 1.0 * height, 1.0],
+                    [1.0 * width, 0.0, 1.0],
+                ]
+            )
+            new_points = points @ inv_affine_matrix.T
+            min_vals = np.min(new_points, axis=0)[:2]
+            max_vals = np.max(new_points, axis=0)[:2]
+            cmax = np.ceil(np.trunc(max_vals * 1e4) * 1e-4)
+            cmin = np.floor(np.trunc((min_vals + 1e-8) * 1e4) * 1e-4)
+            new_width, new_height = (cmax - cmin).astype("int32").tolist()
+            tr = np.array([-(new_width - width) / 2.0, -(new_height - height) / 2.0, 1.0]) @ inv_affine_matrix.T
+
+            inv_affine_matrix[:2, 2] = tr[:2]
+            image_size = [new_height, new_width]
+
+        inv_affine_matrix = inv_affine_matrix[:2, :]
+        expected_mask = torch.zeros(1, *image_size, dtype=mask.dtype)
+
+        for out_y in range(expected_mask.shape[1]):
+            for out_x in range(expected_mask.shape[2]):
+                output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
+                input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32)
+                in_x, in_y = input_pt[:2]
+                if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
+                    expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
+        return expected_mask.to(mask.device)
+
+    for mask in make_segmentation_masks(extra_dims=((), (4,))):
+        output_mask = F.rotate_segmentation_mask(
+            mask,
+            angle=angle,
+            expand=expand,
+            center=center,
+        )
+
+        center_ = center
+        if center_ is None:
+            center_ = [s * 0.5 for s in mask.shape[-2:][::-1]]
+
+        if mask.ndim < 4:
+            masks = [mask]
+        else:
+            masks = [m for m in mask]
+
+        expected_masks = []
+        for mask in masks:
+            expected_mask = _compute_expected_mask(mask, -angle, expand, center_)
+            expected_masks.append(expected_mask)
+        if len(expected_masks) > 1:
+            expected_masks = torch.stack(expected_masks)
+        else:
+            expected_masks = expected_masks[0]
+        torch.testing.assert_close(output_mask, expected_masks)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
+    # Check transformation against known expected output and CPU/CUDA devices
+
+    # Create a fixed input segmentation mask with 2 square masks
+    # in top-left, bottom-left corners
+    mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device)
+    mask[0, 2:10, 2:10] = 1
+    mask[0, 32 - 9 : 32 - 3, 3:9] = 2
+
+    # Rotate 90 degrees
+    expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))
+    out_mask = F.rotate_segmentation_mask(mask, 90, expand=False)
+    torch.testing.assert_close(out_mask, expected_mask)
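
Note: the expected-value helpers in this diff all call _compute_affine_matrix, which is defined earlier in the test file and does not appear in these hunks. As a rough reference only, a minimal sketch of such a helper is shown below, assuming the usual composition T(translate) @ C(center) @ (rotate-scale-shear) @ C(center)^-1 with angles in degrees; the names and exact conventions are illustrative, not the file's actual code.

import math

import numpy as np


def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
    # Illustrative sketch only; the real helper earlier in the test file may differ in details.
    rot = math.radians(angle_)
    cx, cy = center_
    tx, ty = translate_
    sx, sy = [math.radians(s) for s in shear_]

    # Elementary 3x3 affine blocks in homogeneous coordinates
    c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]], dtype=np.float64)
    t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]], dtype=np.float64)
    rs_matrix = np.array(
        [
            [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
            [scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
            [0, 0, 1],
        ],
        dtype=np.float64,
    )
    shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]], dtype=np.float64)
    shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]], dtype=np.float64)

    # Compose: translate @ center @ rotate-scale-shear @ center^-1
    rss_matrix = rs_matrix @ shear_y_matrix @ shear_x_matrix
    return t_matrix @ c_matrix @ rss_matrix @ np.linalg.inv(c_matrix)

The rotate tests then invert this matrix with np.linalg.inv to map output pixel centers back to input coordinates, which is why the expand branch above operates on inv_affine_matrix directly.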