@@ -1101,17 +1101,6 @@ def _compute_expected_mask(mask, top_, left_, height_, width_, size_):
1101
1101
torch .testing .assert_close (output_mask , expected_mask )
1102
1102
1103
1103
1104
- @pytest .mark .parametrize ("device" , cpu_and_gpu ())
1105
- def test_correctness_pad_segmentation_mask_on_fixed_input (device ):
1106
- mask = torch .ones ((1 , 3 , 3 ), dtype = torch .long , device = device )
1107
-
1108
- out_mask = F .pad_segmentation_mask (mask , padding = [1 , 1 , 1 , 1 ])
1109
-
1110
- expected_mask = torch .zeros ((1 , 5 , 5 ), dtype = torch .long , device = device )
1111
- expected_mask [:, 1 :- 1 , 1 :- 1 ] = 1
1112
- torch .testing .assert_close (out_mask , expected_mask )
1113
-
1114
-
1115
1104
def _parse_padding (padding ):
1116
1105
if isinstance (padding , int ):
1117
1106
return [padding ] * 4
@@ -1168,25 +1157,71 @@ def _compute_expected_bbox(bbox, padding_):
1168
1157
torch .testing .assert_close (output_boxes , expected_bboxes )
1169
1158
1170
1159
1160
+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
1161
+ def test_correctness_pad_segmentation_mask_on_fixed_input (device ):
1162
+ mask = torch .ones ((1 , 3 , 3 ), dtype = torch .long , device = device )
1163
+
1164
+ out_mask = F .pad_segmentation_mask (mask , padding = [1 , 1 , 1 , 1 ])
1165
+
1166
+ expected_mask = torch .zeros ((1 , 5 , 5 ), dtype = torch .long , device = device )
1167
+ expected_mask [:, 1 :- 1 , 1 :- 1 ] = 1
1168
+ torch .testing .assert_close (out_mask , expected_mask )
1169
+
1170
+
1171
1171
@pytest .mark .parametrize ("padding" , [[1 , 2 , 3 , 4 ], [1 ], 1 , [1 , 2 ]])
1172
- def test_correctness_pad_segmentation_mask (padding ):
1173
- def _compute_expected_mask (mask , padding_ ):
1172
+ @pytest .mark .parametrize ("padding_mode" , ["constant" , "edge" , "reflect" , "symmetric" ])
1173
+ def test_correctness_pad_segmentation_mask (padding , padding_mode ):
1174
+ def _compute_expected_mask (mask , padding_ , padding_mode_ ):
1174
1175
h , w = mask .shape [- 2 ], mask .shape [- 1 ]
1175
1176
pad_left , pad_up , pad_right , pad_down = _parse_padding (padding_ )
1176
1177
1178
+ if any (pad <= 0 for pad in [pad_left , pad_up , pad_right , pad_down ]):
1179
+ raise pytest .UsageError (
1180
+ "Expected output can be computed on positive pad values only, "
1181
+ "but F.pad_* can also crop for negative values"
1182
+ )
1183
+
1177
1184
new_h = h + pad_up + pad_down
1178
1185
new_w = w + pad_left + pad_right
1179
1186
1180
1187
new_shape = (* mask .shape [:- 2 ], new_h , new_w ) if len (mask .shape ) > 2 else (new_h , new_w )
1181
- expected_mask = torch .zeros (new_shape , dtype = torch .long )
1182
- expected_mask [..., pad_up :- pad_down , pad_left :- pad_right ] = mask
1188
+ output = torch .zeros (new_shape , dtype = mask .dtype )
1189
+ output [..., pad_up :- pad_down , pad_left :- pad_right ] = mask
1190
+
1191
+ if padding_mode_ == "edge" :
1192
+ # pad top-left corner, left vertical block, bottom-left corner
1193
+ output [..., :pad_up , :pad_left ] = mask [..., 0 , 0 ].unsqueeze (- 1 ).unsqueeze (- 2 )
1194
+ output [..., pad_up :- pad_down , :pad_left ] = mask [..., :, 0 ].unsqueeze (- 1 )
1195
+ output [..., - pad_down :, :pad_left ] = mask [..., - 1 , 0 ].unsqueeze (- 1 ).unsqueeze (- 2 )
1196
+ # pad top-right corner, right vertical block, bottom-right corner
1197
+ output [..., :pad_up , - pad_right :] = mask [..., 0 , - 1 ].unsqueeze (- 1 ).unsqueeze (- 2 )
1198
+ output [..., pad_up :- pad_down , - pad_right :] = mask [..., :, - 1 ].unsqueeze (- 1 )
1199
+ output [..., - pad_down :, - pad_right :] = mask [..., - 1 , - 1 ].unsqueeze (- 1 ).unsqueeze (- 2 )
1200
+ # pad top and bottom horizontal blocks
1201
+ output [..., :pad_up , pad_left :- pad_right ] = mask [..., 0 , :].unsqueeze (- 2 )
1202
+ output [..., - pad_down :, pad_left :- pad_right ] = mask [..., - 1 , :].unsqueeze (- 2 )
1203
+ elif padding_mode_ in ("reflect" , "symmetric" ):
1204
+ d1 = 1 if padding_mode_ == "reflect" else 0
1205
+ d2 = - 1 if padding_mode_ == "reflect" else None
1206
+ both = (- 1 , - 2 )
1207
+ # pad top-left corner, left vertical block, bottom-left corner
1208
+ output [..., :pad_up , :pad_left ] = mask [..., d1 : pad_up + d1 , d1 : pad_left + d1 ].flip (both )
1209
+ output [..., pad_up :- pad_down , :pad_left ] = mask [..., :, d1 : pad_left + d1 ].flip (- 1 )
1210
+ output [..., - pad_down :, :pad_left ] = mask [..., - pad_down - d1 : d2 , d1 : pad_left + d1 ].flip (both )
1211
+ # pad top-right corner, right vertical block, bottom-right corner
1212
+ output [..., :pad_up , - pad_right :] = mask [..., d1 : pad_up + d1 , - pad_right - d1 : d2 ].flip (both )
1213
+ output [..., pad_up :- pad_down , - pad_right :] = mask [..., :, - pad_right - d1 : d2 ].flip (- 1 )
1214
+ output [..., - pad_down :, - pad_right :] = mask [..., - pad_down - d1 : d2 , - pad_right - d1 : d2 ].flip (both )
1215
+ # pad top and bottom horizontal blocks
1216
+ output [..., :pad_up , pad_left :- pad_right ] = mask [..., d1 : pad_up + d1 , :].flip (- 2 )
1217
+ output [..., - pad_down :, pad_left :- pad_right ] = mask [..., - pad_down - d1 : d2 , :].flip (- 2 )
1183
1218
1184
- return expected_mask
1219
+ return output
1185
1220
1186
1221
for mask in make_segmentation_masks ():
1187
- out_mask = F .pad_segmentation_mask (mask , padding , "constant" )
1222
+ out_mask = F .pad_segmentation_mask (mask , padding , padding_mode = padding_mode )
1188
1223
1189
- expected_mask = _compute_expected_mask (mask , padding )
1224
+ expected_mask = _compute_expected_mask (mask , padding , padding_mode )
1190
1225
torch .testing .assert_close (out_mask , expected_mask )
1191
1226
1192
1227
0 commit comments