Skip to content

Commit c66da5e

Browse files
authored
Added crop_segmentation_mask op (#5851)
* Added `crop_segmentation_mask` op * Fixed failing mypy checks
1 parent ca26537 commit c66da5e

File tree

3 files changed

+60
-0
lines changed

3 files changed

+60
-0
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,20 @@ def crop_bounding_box():
332332
)
333333

334334

335+
@register_kernel_info_from_sample_inputs_fn
def crop_segmentation_mask():
    """Yield SampleInputs for the crop_segmentation_mask kernel.

    Covers crops that start before the mask (negative offsets), at the
    origin, and inside it, combined with two output sizes per dimension.
    """
    offsets = [-8, 0, 9]
    sizes = [12, 20]
    for mask in make_segmentation_masks():
        for top, left, height, width in itertools.product(offsets, offsets, sizes, sizes):
            yield SampleInput(mask, top=top, left=left, height=height, width=width)
347+
348+
335349
@pytest.mark.parametrize(
336350
"kernel",
337351
[
@@ -860,3 +874,44 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
860874
)
861875

862876
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
877+
878+
879+
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "top, left, height, width",
    [
        [4, 6, 30, 40],
        [-8, 6, 70, 40],
        [-8, -6, 70, 8],
    ],
)
def test_correctness_crop_segmentation_mask(device, top, left, height, width):
    """Compare crop_segmentation_mask against a reference crop-with-padding."""

    def _compute_expected_mask(mask, top_, left_, height_, width_):
        h, w = mask.shape[-2], mask.shape[-1]
        if top_ >= 0 and left_ >= 0 and top_ + height_ < h and left_ + width_ < w:
            # Crop window lies strictly inside the mask: plain slicing suffices.
            return mask[..., top_ : top_ + height_, left_ : left_ + width_]

        # Crop window crosses a border: paste the overlapping region into zeros.
        out = torch.zeros(mask.shape[:-2] + (height_, width_), device=mask.device, dtype=mask.dtype)

        # Overlap rectangle in output coordinates ...
        out_y1, out_x1 = max(-top_, 0), max(-left_, 0)
        out_y2, out_x2 = min(h - top_, height_), min(w - left_, width_)
        # ... and the matching rectangle in input coordinates.
        in_y1, in_x1 = max(top_, 0), max(left_, 0)
        in_y2, in_x2 = min(top_ + height_, h), min(left_ + width_, w)

        out[..., out_y1:out_y2, out_x1:out_x2] = mask[..., in_y1:in_y2, in_x1:in_x2]
        return out

    for mask in make_segmentation_masks():
        if mask.device != torch.device(device):
            mask = mask.to(device)
        actual = F.crop_segmentation_mask(mask, top, left, height, width)
        expected = _compute_expected_mask(mask, top, left, height, width)
        torch.testing.assert_close(actual, expected)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
crop_bounding_box,
6464
crop_image_tensor,
6565
crop_image_pil,
66+
crop_segmentation_mask,
6667
perspective_image_tensor,
6768
perspective_image_pil,
6869
vertical_flip_image_tensor,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,10 @@ def crop_bounding_box(
440440
).view(shape)
441441

442442

443+
def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a segmentation mask to the window (top, left, height, width).

    Cropping is a purely spatial operation, so a mask is cropped exactly
    like an image tensor; this delegates to ``crop_image_tensor``.
    """
    return crop_image_tensor(img, top, left, height, width)
445+
446+
443447
def perspective_image_tensor(
444448
img: torch.Tensor,
445449
perspective_coeffs: List[float],

0 commit comments

Comments
 (0)