@@ -382,6 +382,15 @@ def pad_segmentation_mask():
382
382
yield SampleInput (mask , padding = padding , padding_mode = padding_mode )
383
383
384
384
385
+ @register_kernel_info_from_sample_inputs_fn
386
+ def pad_bounding_box ():
387
+ for bounding_box , padding in itertools .product (
388
+ make_bounding_boxes (),
389
+ [[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]],
390
+ ):
391
+ yield SampleInput (bounding_box , padding = padding , format = bounding_box .format )
392
+
393
+
385
394
@register_kernel_info_from_sample_inputs_fn
386
395
def perspective_bounding_box ():
387
396
for bounding_box , perspective_coeffs in itertools .product (
@@ -1103,22 +1112,67 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
1103
1112
torch .testing .assert_close (out_mask , expected_mask )
1104
1113
1105
1114
1115
+ def _parse_padding (padding ):
1116
+ if isinstance (padding , int ):
1117
+ return [padding ] * 4
1118
+ if isinstance (padding , list ):
1119
+ if len (padding ) == 1 :
1120
+ return padding * 4
1121
+ if len (padding ) == 2 :
1122
+ return padding * 2 # [left, up, right, down]
1123
+
1124
+ return padding
1125
+
1126
+
1127
+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
1128
+ @pytest .mark .parametrize ("padding" , [[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]])
1129
+ def test_correctness_pad_bounding_box (device , padding ):
1130
+ def _compute_expected_bbox (bbox , padding_ ):
1131
+ pad_left , pad_up , _ , _ = _parse_padding (padding_ )
1132
+
1133
+ bbox_format = bbox .format
1134
+ bbox_dtype = bbox .dtype
1135
+ bbox = convert_bounding_box_format (bbox , old_format = bbox_format , new_format = features .BoundingBoxFormat .XYXY )
1136
+
1137
+ bbox [0 ::2 ] += pad_left
1138
+ bbox [1 ::2 ] += pad_up
1139
+
1140
+ bbox = convert_bounding_box_format (
1141
+ bbox , old_format = features .BoundingBoxFormat .XYXY , new_format = bbox_format , copy = False
1142
+ )
1143
+ if bbox .dtype != bbox_dtype :
1144
+ # Temporary cast to original dtype
1145
+ # e.g. float32 -> int
1146
+ bbox = bbox .to (bbox_dtype )
1147
+ return bbox
1148
+
1149
+ for bboxes in make_bounding_boxes ():
1150
+ bboxes = bboxes .to (device )
1151
+ bboxes_format = bboxes .format
1152
+ bboxes_image_size = bboxes .image_size
1153
+
1154
+ output_boxes = F .pad_bounding_box (bboxes , padding , format = bboxes_format )
1155
+
1156
+ if bboxes .ndim < 2 :
1157
+ bboxes = [bboxes ]
1158
+
1159
+ expected_bboxes = []
1160
+ for bbox in bboxes :
1161
+ bbox = features .BoundingBox (bbox , format = bboxes_format , image_size = bboxes_image_size )
1162
+ expected_bboxes .append (_compute_expected_bbox (bbox , padding ))
1163
+
1164
+ if len (expected_bboxes ) > 1 :
1165
+ expected_bboxes = torch .stack (expected_bboxes )
1166
+ else :
1167
+ expected_bboxes = expected_bboxes [0 ]
1168
+ torch .testing .assert_close (output_boxes , expected_bboxes )
1169
+
1170
+
1106
1171
@pytest .mark .parametrize ("padding" , [[1 , 2 , 3 , 4 ], [1 ], 1 , [1 , 2 ]])
1107
1172
def test_correctness_pad_segmentation_mask (padding ):
1108
- def _compute_expected_mask ():
1109
- def parse_padding ():
1110
- if isinstance (padding , int ):
1111
- return [padding ] * 4
1112
- if isinstance (padding , list ):
1113
- if len (padding ) == 1 :
1114
- return padding * 4
1115
- if len (padding ) == 2 :
1116
- return padding * 2 # [left, up, right, down]
1117
-
1118
- return padding
1119
-
1173
+ def _compute_expected_mask (mask , padding_ ):
1120
1174
h , w = mask .shape [- 2 ], mask .shape [- 1 ]
1121
- pad_left , pad_up , pad_right , pad_down = parse_padding ( )
1175
+ pad_left , pad_up , pad_right , pad_down = _parse_padding ( padding_ )
1122
1176
1123
1177
new_h = h + pad_up + pad_down
1124
1178
new_w = w + pad_left + pad_right
@@ -1132,7 +1186,7 @@ def parse_padding():
1132
1186
for mask in make_segmentation_masks ():
1133
1187
out_mask = F .pad_segmentation_mask (mask , padding , "constant" )
1134
1188
1135
- expected_mask = _compute_expected_mask ()
1189
+ expected_mask = _compute_expected_mask (mask , padding )
1136
1190
torch .testing .assert_close (out_mask , expected_mask )
1137
1191
1138
1192
0 commit comments