@@ -332,6 +332,20 @@ def crop_bounding_box():
332
332
)
333
333
334
334
335
+ @register_kernel_info_from_sample_inputs_fn
336
+ def crop_segmentation_mask ():
337
+ for mask , top , left , height , width in itertools .product (
338
+ make_segmentation_masks (), [- 8 , 0 , 9 ], [- 8 , 0 , 9 ], [12 , 20 ], [12 , 20 ]
339
+ ):
340
+ yield SampleInput (
341
+ mask ,
342
+ top = top ,
343
+ left = left ,
344
+ height = height ,
345
+ width = width ,
346
+ )
347
+
348
+
335
349
@pytest .mark .parametrize (
336
350
"kernel" ,
337
351
[
@@ -860,3 +874,44 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
860
874
)
861
875
862
876
torch .testing .assert_close (output_boxes .tolist (), expected_bboxes )
877
+
878
+
879
+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
880
+ @pytest .mark .parametrize (
881
+ "top, left, height, width" ,
882
+ [
883
+ [4 , 6 , 30 , 40 ],
884
+ [- 8 , 6 , 70 , 40 ],
885
+ [- 8 , - 6 , 70 , 8 ],
886
+ ],
887
+ )
888
+ def test_correctness_crop_segmentation_mask (device , top , left , height , width ):
889
+ def _compute_expected_mask (mask , top_ , left_ , height_ , width_ ):
890
+ h , w = mask .shape [- 2 ], mask .shape [- 1 ]
891
+ if top_ >= 0 and left_ >= 0 and top_ + height_ < h and left_ + width_ < w :
892
+ expected = mask [..., top_ : top_ + height_ , left_ : left_ + width_ ]
893
+ else :
894
+ # Create output mask
895
+ expected_shape = mask .shape [:- 2 ] + (height_ , width_ )
896
+ expected = torch .zeros (expected_shape , device = mask .device , dtype = mask .dtype )
897
+
898
+ out_y1 = abs (top_ ) if top_ < 0 else 0
899
+ out_y2 = h - top_ if top_ + height_ >= h else height_
900
+ out_x1 = abs (left_ ) if left_ < 0 else 0
901
+ out_x2 = w - left_ if left_ + width_ >= w else width_
902
+
903
+ in_y1 = 0 if top_ < 0 else top_
904
+ in_y2 = h if top_ + height_ >= h else top_ + height_
905
+ in_x1 = 0 if left_ < 0 else left_
906
+ in_x2 = w if left_ + width_ >= w else left_ + width_
907
+ # Paste input mask into output
908
+ expected [..., out_y1 :out_y2 , out_x1 :out_x2 ] = mask [..., in_y1 :in_y2 , in_x1 :in_x2 ]
909
+
910
+ return expected
911
+
912
+ for mask in make_segmentation_masks ():
913
+ if mask .device != torch .device (device ):
914
+ mask = mask .to (device )
915
+ output_mask = F .crop_segmentation_mask (mask , top , left , height , width )
916
+ expected_mask = _compute_expected_mask (mask , top , left , height , width )
917
+ torch .testing .assert_close (output_mask , expected_mask )
0 commit comments