Skip to content

Commit c2a4e9f

Browse files
committed
Rewrite test and fix masks_to_boxes implementation (#4469)
Co-authored-by: Nicolas Hug <[email protected]> [ghstack-poisoned]
1 parent 5acf580 commit c2a4e9f

File tree

3 files changed

+40
-40
lines changed

3 files changed

+40
-40
lines changed

test/test_masks_to_boxes.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

test/test_ops.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import pytest
55

66
import numpy as np
7+
import os
78

9+
from PIL import Image
810
import torch
911
from functools import lru_cache
1012
from torch import Tensor
@@ -1000,6 +1002,38 @@ def gen_iou_check(box, expected, tolerance=1e-4):
10001002
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)
10011003

10021004

1005+
class TestMasksToBoxes:
1006+
def test_masks_box(self):
1007+
def masks_box_check(masks, expected, tolerance=1e-4):
1008+
out = ops.masks_to_boxes(masks)
1009+
assert out.dtype == torch.float
1010+
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
1011+
1012+
# Check for int type boxes.
1013+
def _get_image():
1014+
assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
1015+
mask_path = os.path.join(assets_directory, "masks.tiff")
1016+
image = Image.open(mask_path)
1017+
return image
1018+
1019+
def _create_masks(image, masks):
1020+
for index in range(image.n_frames):
1021+
image.seek(index)
1022+
frame = np.array(image)
1023+
masks[index] = torch.tensor(frame)
1024+
1025+
return masks
1026+
1027+
expected = torch.tensor([[127, 2, 165, 40], [2, 50, 44, 92], [56, 63, 98, 100], [139, 68, 175, 104],
1028+
[160, 112, 198, 145], [49, 138, 99, 182], [108, 148, 152, 213]], dtype=torch.float)
1029+
1030+
image = _get_image()
1031+
for dtype in [torch.float16, torch.float32, torch.float64]:
1032+
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
1033+
masks = _create_masks(image, masks)
1034+
masks_box_check(masks, expected)
1035+
1036+
10031037
class TestStochasticDepth:
10041038
@pytest.mark.parametrize('p', [0.2, 0.5, 0.8])
10051039
@pytest.mark.parametrize('mode', ["batch", "row"])

torchvision/ops/boxes.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -301,24 +301,24 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
301301

302302
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
303303
"""
304-
Compute the bounding boxes around the provided masks
304+
Compute the bounding boxes around the provided masks.
305305
306-
Returns a [N, 4] tensor. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
306+
Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
307307
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
308308
309309
Args:
310-
masks (Tensor[N, H, W]): masks to transform where N is the number of
311-
masks and (H, W) are the spatial dimensions.
310+
masks (Tensor[N, H, W]): masks to transform where N is the number of masks
311+
and (H, W) are the spatial dimensions.
312312
313313
Returns:
314314
Tensor[N, 4]: bounding boxes
315315
"""
316316
if masks.numel() == 0:
317-
return torch.zeros((0, 4))
317+
return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
318318

319319
n = masks.shape[0]
320320

321-
bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.int)
321+
bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
322322

323323
for index, mask in enumerate(masks):
324324
y, x = torch.where(masks[index] != 0)

0 commit comments

Comments
 (0)