Skip to content

Commit ee557d0

Browse files
YosuaMichaelfacebook-github-bot
authored andcommitted
[fbsync] [proto] Added tests for other padding modes (#6104)
Summary: * Added tests for other padding modes * Fixed expected mask dtype * Applied comments from review Reviewed By: NicolasHug Differential Revision: D36760913 fbshipit-source-id: ddb12deacb1f6215538ea37e816df670c7b1cb2e
1 parent b5e4cd1 commit ee557d0

File tree

1 file changed

+53
-18
lines changed

1 file changed

+53
-18
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,17 +1101,6 @@ def _compute_expected_mask(mask, top_, left_, height_, width_, size_):
11011101
torch.testing.assert_close(output_mask, expected_mask)
11021102

11031103

1104-
@pytest.mark.parametrize("device", cpu_and_gpu())
1105-
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
1106-
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
1107-
1108-
out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])
1109-
1110-
expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
1111-
expected_mask[:, 1:-1, 1:-1] = 1
1112-
torch.testing.assert_close(out_mask, expected_mask)
1113-
1114-
11151104
def _parse_padding(padding):
11161105
if isinstance(padding, int):
11171106
return [padding] * 4
@@ -1168,25 +1157,71 @@ def _compute_expected_bbox(bbox, padding_):
11681157
torch.testing.assert_close(output_boxes, expected_bboxes)
11691158

11701159

1160+
@pytest.mark.parametrize("device", cpu_and_gpu())
1161+
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
1162+
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
1163+
1164+
out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])
1165+
1166+
expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
1167+
expected_mask[:, 1:-1, 1:-1] = 1
1168+
torch.testing.assert_close(out_mask, expected_mask)
1169+
1170+
11711171
@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
1172-
def test_correctness_pad_segmentation_mask(padding):
1173-
def _compute_expected_mask(mask, padding_):
1172+
@pytest.mark.parametrize("padding_mode", ["constant", "edge", "reflect", "symmetric"])
1173+
def test_correctness_pad_segmentation_mask(padding, padding_mode):
1174+
def _compute_expected_mask(mask, padding_, padding_mode_):
11741175
h, w = mask.shape[-2], mask.shape[-1]
11751176
pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
11761177

1178+
if any(pad <= 0 for pad in [pad_left, pad_up, pad_right, pad_down]):
1179+
raise pytest.UsageError(
1180+
"Expected output can be computed on positive pad values only, "
1181+
"but F.pad_* can also crop for negative values"
1182+
)
1183+
11771184
new_h = h + pad_up + pad_down
11781185
new_w = w + pad_left + pad_right
11791186

11801187
new_shape = (*mask.shape[:-2], new_h, new_w) if len(mask.shape) > 2 else (new_h, new_w)
1181-
expected_mask = torch.zeros(new_shape, dtype=torch.long)
1182-
expected_mask[..., pad_up:-pad_down, pad_left:-pad_right] = mask
1188+
output = torch.zeros(new_shape, dtype=mask.dtype)
1189+
output[..., pad_up:-pad_down, pad_left:-pad_right] = mask
1190+
1191+
if padding_mode_ == "edge":
1192+
# pad top-left corner, left vertical block, bottom-left corner
1193+
output[..., :pad_up, :pad_left] = mask[..., 0, 0].unsqueeze(-1).unsqueeze(-2)
1194+
output[..., pad_up:-pad_down, :pad_left] = mask[..., :, 0].unsqueeze(-1)
1195+
output[..., -pad_down:, :pad_left] = mask[..., -1, 0].unsqueeze(-1).unsqueeze(-2)
1196+
# pad top-right corner, right vertical block, bottom-right corner
1197+
output[..., :pad_up, -pad_right:] = mask[..., 0, -1].unsqueeze(-1).unsqueeze(-2)
1198+
output[..., pad_up:-pad_down, -pad_right:] = mask[..., :, -1].unsqueeze(-1)
1199+
output[..., -pad_down:, -pad_right:] = mask[..., -1, -1].unsqueeze(-1).unsqueeze(-2)
1200+
# pad top and bottom horizontal blocks
1201+
output[..., :pad_up, pad_left:-pad_right] = mask[..., 0, :].unsqueeze(-2)
1202+
output[..., -pad_down:, pad_left:-pad_right] = mask[..., -1, :].unsqueeze(-2)
1203+
elif padding_mode_ in ("reflect", "symmetric"):
1204+
d1 = 1 if padding_mode_ == "reflect" else 0
1205+
d2 = -1 if padding_mode_ == "reflect" else None
1206+
both = (-1, -2)
1207+
# pad top-left corner, left vertical block, bottom-left corner
1208+
output[..., :pad_up, :pad_left] = mask[..., d1 : pad_up + d1, d1 : pad_left + d1].flip(both)
1209+
output[..., pad_up:-pad_down, :pad_left] = mask[..., :, d1 : pad_left + d1].flip(-1)
1210+
output[..., -pad_down:, :pad_left] = mask[..., -pad_down - d1 : d2, d1 : pad_left + d1].flip(both)
1211+
# pad top-right corner, right vertical block, bottom-right corner
1212+
output[..., :pad_up, -pad_right:] = mask[..., d1 : pad_up + d1, -pad_right - d1 : d2].flip(both)
1213+
output[..., pad_up:-pad_down, -pad_right:] = mask[..., :, -pad_right - d1 : d2].flip(-1)
1214+
output[..., -pad_down:, -pad_right:] = mask[..., -pad_down - d1 : d2, -pad_right - d1 : d2].flip(both)
1215+
# pad top and bottom horizontal blocks
1216+
output[..., :pad_up, pad_left:-pad_right] = mask[..., d1 : pad_up + d1, :].flip(-2)
1217+
output[..., -pad_down:, pad_left:-pad_right] = mask[..., -pad_down - d1 : d2, :].flip(-2)
11831218

1184-
return expected_mask
1219+
return output
11851220

11861221
for mask in make_segmentation_masks():
1187-
out_mask = F.pad_segmentation_mask(mask, padding, "constant")
1222+
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode=padding_mode)
11881223

1189-
expected_mask = _compute_expected_mask(mask, padding)
1224+
expected_mask = _compute_expected_mask(mask, padding, padding_mode)
11901225
torch.testing.assert_close(out_mask, expected_mask)
11911226

11921227

0 commit comments

Comments
 (0)