Skip to content

Commit db4cc0b

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] [proto] Added functional rotate_segmentation_mask op (#5692)
Summary: * Added functional affine_bounding_box op with tests * Updated comments and added another test case * Update _geometry.py * Added affine_segmentation_mask with tests * Fixed device mismatch issue Added a cude/cpu test Reduced the number of test samples * Added test_correctness_affine_segmentation_mask_on_fixed_input * Updates according to the review * Replaced [None, ...] by [None, :] * Adressed review comments * Fixed formatting and more updates according to the review * Fixed bad merge * WIP * Fixed tests * Updated warning message Reviewed By: NicolasHug Differential Revision: D35393159 fbshipit-source-id: e6950844f1e0066f879019a72001a673f501281e
1 parent 909f1d6 commit db4cc0b

File tree

3 files changed

+152
-17
lines changed

3 files changed

+152
-17
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 130 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,15 @@ def affine_bounding_box():
266266

267267
@register_kernel_info_from_sample_inputs_fn
268268
def affine_segmentation_mask():
269-
for image, angle, translate, scale, shear in itertools.product(
269+
for mask, angle, translate, scale, shear in itertools.product(
270270
make_segmentation_masks(extra_dims=((), (4,))),
271271
[-87, 15, 90], # angle
272272
[5, -5], # translate
273273
[0.77, 1.27], # scale
274274
[0, 12], # shear
275275
):
276276
yield SampleInput(
277-
image,
277+
mask,
278278
angle=angle,
279279
translate=(translate, translate),
280280
scale=scale,
@@ -285,8 +285,12 @@ def affine_segmentation_mask():
285285
@register_kernel_info_from_sample_inputs_fn
286286
def rotate_bounding_box():
287287
for bounding_box, angle, expand, center in itertools.product(
288-
make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]] # angle # expand # center
288+
make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]]
289289
):
290+
if center is not None and expand:
291+
# Skip warning: The provided center argument is ignored if expand is True
292+
continue
293+
290294
yield SampleInput(
291295
bounding_box,
292296
format=bounding_box.format,
@@ -297,6 +301,26 @@ def rotate_bounding_box():
297301
)
298302

299303

304+
@register_kernel_info_from_sample_inputs_fn
305+
def rotate_segmentation_mask():
306+
for mask, angle, expand, center in itertools.product(
307+
make_segmentation_masks(extra_dims=((), (4,))),
308+
[-87, 15, 90], # angle
309+
[True, False], # expand
310+
[None, [12, 23]], # center
311+
):
312+
if center is not None and expand:
313+
# Skip warning: The provided center argument is ignored if expand is True
314+
continue
315+
316+
yield SampleInput(
317+
mask,
318+
angle=angle,
319+
expand=expand,
320+
center=center,
321+
)
322+
323+
300324
@pytest.mark.parametrize(
301325
"kernel",
302326
[
@@ -411,8 +435,9 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
411435
center=center,
412436
)
413437

414-
if center is None:
415-
center = [s // 2 for s in bboxes_image_size[::-1]]
438+
center_ = center
439+
if center_ is None:
440+
center_ = [s * 0.5 for s in bboxes_image_size[::-1]]
416441

417442
if bboxes.ndim < 2:
418443
bboxes = [bboxes]
@@ -421,7 +446,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
421446
for bbox in bboxes:
422447
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
423448
expected_bboxes.append(
424-
_compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center)
449+
_compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center_)
425450
)
426451
if len(expected_bboxes) > 1:
427452
expected_bboxes = torch.stack(expected_bboxes)
@@ -510,8 +535,10 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
510535
shear=(shear, shear),
511536
center=center,
512537
)
513-
if center is None:
514-
center = [s // 2 for s in mask.shape[-2:][::-1]]
538+
539+
center_ = center
540+
if center_ is None:
541+
center_ = [s * 0.5 for s in mask.shape[-2:][::-1]]
515542

516543
if mask.ndim < 4:
517544
masks = [mask]
@@ -520,7 +547,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
520547

521548
expected_masks = []
522549
for mask in masks:
523-
expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center)
550+
expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center_)
524551
expected_masks.append(expected_mask)
525552
if len(expected_masks) > 1:
526553
expected_masks = torch.stack(expected_masks)
@@ -550,8 +577,7 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device):
550577

551578

552579
@pytest.mark.parametrize("angle", range(-90, 90, 56))
553-
@pytest.mark.parametrize("expand", [True, False])
554-
@pytest.mark.parametrize("center", [None, (12, 14)])
580+
@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
555581
def test_correctness_rotate_bounding_box(angle, expand, center):
556582
def _compute_expected_bbox(bbox, angle_, expand_, center_):
557583
affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
@@ -620,16 +646,17 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
620646
center=center,
621647
)
622648

623-
if center is None:
624-
center = [s // 2 for s in bboxes_image_size[::-1]]
649+
center_ = center
650+
if center_ is None:
651+
center_ = [s * 0.5 for s in bboxes_image_size[::-1]]
625652

626653
if bboxes.ndim < 2:
627654
bboxes = [bboxes]
628655

629656
expected_bboxes = []
630657
for bbox in bboxes:
631658
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
632-
expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center))
659+
expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_))
633660
if len(expected_bboxes) > 1:
634661
expected_bboxes = torch.stack(expected_bboxes)
635662
else:
@@ -638,7 +665,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
638665

639666

640667
@pytest.mark.parametrize("device", cpu_and_gpu())
641-
@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2, analysis in progress
668+
@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2
642669
def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
643670
# Check transformation against known expected output
644671
image_size = (64, 64)
@@ -689,3 +716,91 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
689716
)
690717

691718
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
719+
720+
721+
@pytest.mark.parametrize("angle", range(-90, 90, 37))
722+
@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
723+
def test_correctness_rotate_segmentation_mask(angle, expand, center):
724+
def _compute_expected_mask(mask, angle_, expand_, center_):
725+
assert mask.ndim == 3 and mask.shape[0] == 1
726+
image_size = mask.shape[-2:]
727+
affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
728+
inv_affine_matrix = np.linalg.inv(affine_matrix)
729+
730+
if expand_:
731+
# Pillow implementation on how to perform expand:
732+
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054-L2069
733+
height, width = image_size
734+
points = np.array(
735+
[
736+
[0.0, 0.0, 1.0],
737+
[0.0, 1.0 * height, 1.0],
738+
[1.0 * width, 1.0 * height, 1.0],
739+
[1.0 * width, 0.0, 1.0],
740+
]
741+
)
742+
new_points = points @ inv_affine_matrix.T
743+
min_vals = np.min(new_points, axis=0)[:2]
744+
max_vals = np.max(new_points, axis=0)[:2]
745+
cmax = np.ceil(np.trunc(max_vals * 1e4) * 1e-4)
746+
cmin = np.floor(np.trunc((min_vals + 1e-8) * 1e4) * 1e-4)
747+
new_width, new_height = (cmax - cmin).astype("int32").tolist()
748+
tr = np.array([-(new_width - width) / 2.0, -(new_height - height) / 2.0, 1.0]) @ inv_affine_matrix.T
749+
750+
inv_affine_matrix[:2, 2] = tr[:2]
751+
image_size = [new_height, new_width]
752+
753+
inv_affine_matrix = inv_affine_matrix[:2, :]
754+
expected_mask = torch.zeros(1, *image_size, dtype=mask.dtype)
755+
756+
for out_y in range(expected_mask.shape[1]):
757+
for out_x in range(expected_mask.shape[2]):
758+
output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
759+
input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32)
760+
in_x, in_y = input_pt[:2]
761+
if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
762+
expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
763+
return expected_mask.to(mask.device)
764+
765+
for mask in make_segmentation_masks(extra_dims=((), (4,))):
766+
output_mask = F.rotate_segmentation_mask(
767+
mask,
768+
angle=angle,
769+
expand=expand,
770+
center=center,
771+
)
772+
773+
center_ = center
774+
if center_ is None:
775+
center_ = [s * 0.5 for s in mask.shape[-2:][::-1]]
776+
777+
if mask.ndim < 4:
778+
masks = [mask]
779+
else:
780+
masks = [m for m in mask]
781+
782+
expected_masks = []
783+
for mask in masks:
784+
expected_mask = _compute_expected_mask(mask, -angle, expand, center_)
785+
expected_masks.append(expected_mask)
786+
if len(expected_masks) > 1:
787+
expected_masks = torch.stack(expected_masks)
788+
else:
789+
expected_masks = expected_masks[0]
790+
torch.testing.assert_close(output_mask, expected_masks)
791+
792+
793+
@pytest.mark.parametrize("device", cpu_and_gpu())
794+
def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
795+
# Check transformation against known expected output and CPU/CUDA devices
796+
797+
# Create a fixed input segmentation mask with 2 square masks
798+
# in top-left, bottom-left corners
799+
mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device)
800+
mask[0, 2:10, 2:10] = 1
801+
mask[0, 32 - 9 : 32 - 3, 3:9] = 2
802+
803+
# Rotate 90 degrees
804+
expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))
805+
out_mask = F.rotate_segmentation_mask(mask, 90, expand=False)
806+
torch.testing.assert_close(out_mask, expected_mask)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
rotate_bounding_box,
5757
rotate_image_tensor,
5858
rotate_image_pil,
59+
rotate_segmentation_mask,
5960
pad_image_tensor,
6061
pad_image_pil,
6162
pad_bounding_box,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def rotate_image_tensor(
324324
center_f = [0.0, 0.0]
325325
if center is not None:
326326
if expand:
327-
warnings.warn("The provided center argument is ignored if expand is True")
327+
warnings.warn("The provided center argument has no effect on the result if expand is True")
328328
else:
329329
_, height, width = get_dimensions_image_tensor(img)
330330
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
@@ -345,7 +345,7 @@ def rotate_image_pil(
345345
center: Optional[List[float]] = None,
346346
) -> PIL.Image.Image:
347347
if center is not None and expand:
348-
warnings.warn("The provided center argument is ignored if expand is True")
348+
warnings.warn("The provided center argument has no effect on the result if expand is True")
349349
center = None
350350

351351
return _FP.rotate(
@@ -361,6 +361,10 @@ def rotate_bounding_box(
361361
expand: bool = False,
362362
center: Optional[List[float]] = None,
363363
) -> torch.Tensor:
364+
if center is not None and expand:
365+
warnings.warn("The provided center argument has no effect on the result if expand is True")
366+
center = None
367+
364368
original_shape = bounding_box.shape
365369
bounding_box = convert_bounding_box_format(
366370
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
@@ -373,6 +377,21 @@ def rotate_bounding_box(
373377
).view(original_shape)
374378

375379

380+
def rotate_segmentation_mask(
381+
img: torch.Tensor,
382+
angle: float,
383+
expand: bool = False,
384+
center: Optional[List[float]] = None,
385+
) -> torch.Tensor:
386+
return rotate_image_tensor(
387+
img,
388+
angle=angle,
389+
expand=expand,
390+
interpolation=InterpolationMode.NEAREST,
391+
center=center,
392+
)
393+
394+
376395
pad_image_tensor = _FT.pad
377396
pad_image_pil = _FP.pad
378397

0 commit comments

Comments
 (0)