@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    """Pad a 3x3 all-ones mask by one pixel per side and expect a zero border."""
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])

    # A 5x5 result whose interior is the original ones and whose border is zero.
    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
@pytest.mark.parametrize("padding_mode", ["constant", "edge", "reflect", "symmetric"])
def test_correctness_pad_segmentation_mask(padding, padding_mode):
    """Compare ``F.pad_segmentation_mask`` against a slice-based reference.

    The reference is built independently for every parametrized ``padding``
    and ``padding_mode`` and checked with ``torch.testing.assert_close`` for
    each mask yielded by ``make_segmentation_masks()``.
    """

    def _compute_expected_mask(mask, padding_, padding_mode_):
        """Build the expected padded mask using explicit slice assignments.

        Args:
            mask: tensor whose last two dimensions are (height, width).
            padding_: padding spec; ``_parse_padding`` expands it to
                ``(left, up, right, down)``.
            padding_mode_: one of "constant", "edge", "reflect", "symmetric".

        Returns:
            A tensor with the same dtype and device as ``mask`` whose last two
            dimensions are enlarged by the padding.

        Raises:
            pytest.UsageError: if any pad value is not strictly positive.
        """
        h, w = mask.shape[-2], mask.shape[-1]
        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)

        # The negative-stop slices below (e.g. ``pad_up:-pad_down``) only
        # select the intended region when every pad is strictly positive:
        # a zero pad turns the stop into ``-0`` and yields an empty slice.
        if any(pad <= 0 for pad in [pad_left, pad_up, pad_right, pad_down]):
            raise pytest.UsageError(
                "Expected output can be computed on positive pad values only, "
                "but F.pad_* can also crop for negative values"
            )

        new_h = h + pad_up + pad_down
        new_w = w + pad_left + pad_right

        # Unpacking an empty leading shape already yields (new_h, new_w) for
        # 2-D masks, so no special case is needed.
        new_shape = (*mask.shape[:-2], new_h, new_w)
        # new_zeros inherits dtype *and* device from the input, so the
        # reference lives on the same device as the mask under test.
        output = mask.new_zeros(new_shape)
        output[..., pad_up:-pad_down, pad_left:-pad_right] = mask

        if padding_mode_ == "edge":
            # Length-one slices keep both spatial dims and broadcast over the
            # padded region.
            # pad top-left corner, left vertical block, bottom-left corner
            output[..., :pad_up, :pad_left] = mask[..., :1, :1]
            output[..., pad_up:-pad_down, :pad_left] = mask[..., :, :1]
            output[..., -pad_down:, :pad_left] = mask[..., -1:, :1]
            # pad top-right corner, right vertical block, bottom-right corner
            output[..., :pad_up, -pad_right:] = mask[..., :1, -1:]
            output[..., pad_up:-pad_down, -pad_right:] = mask[..., :, -1:]
            output[..., -pad_down:, -pad_right:] = mask[..., -1:, -1:]
            # pad top and bottom horizontal blocks
            output[..., :pad_up, pad_left:-pad_right] = mask[..., :1, :]
            output[..., -pad_down:, pad_left:-pad_right] = mask[..., -1:, :]
        elif padding_mode_ in ("reflect", "symmetric"):
            # "reflect" mirrors around the border element without repeating
            # it (source window shifted by one); "symmetric" repeats the
            # border element (no shift). d2 is the matching slice stop at the
            # far end: -1 drops the last element, None keeps it.
            d1 = 1 if padding_mode_ == "reflect" else 0
            d2 = -1 if padding_mode_ == "reflect" else None
            both = (-1, -2)
            # pad top-left corner, left vertical block, bottom-left corner
            output[..., :pad_up, :pad_left] = mask[..., d1 : pad_up + d1, d1 : pad_left + d1].flip(both)
            output[..., pad_up:-pad_down, :pad_left] = mask[..., :, d1 : pad_left + d1].flip(-1)
            output[..., -pad_down:, :pad_left] = mask[..., -pad_down - d1 : d2, d1 : pad_left + d1].flip(both)
            # pad top-right corner, right vertical block, bottom-right corner
            output[..., :pad_up, -pad_right:] = mask[..., d1 : pad_up + d1, -pad_right - d1 : d2].flip(both)
            output[..., pad_up:-pad_down, -pad_right:] = mask[..., :, -pad_right - d1 : d2].flip(-1)
            output[..., -pad_down:, -pad_right:] = mask[..., -pad_down - d1 : d2, -pad_right - d1 : d2].flip(both)
            # pad top and bottom horizontal blocks
            output[..., :pad_up, pad_left:-pad_right] = mask[..., d1 : pad_up + d1, :].flip(-2)
            output[..., -pad_down:, pad_left:-pad_right] = mask[..., -pad_down - d1 : d2, :].flip(-2)

        return output

    for mask in make_segmentation_masks():
        out_mask = F.pad_segmentation_mask(mask, padding, padding_mode=padding_mode)

        expected_mask = _compute_expected_mask(mask, padding, padding_mode)
        torch.testing.assert_close(out_mask, expected_mask)