Skip to content

Commit f0422e7

Browse files
0x00b1datumboxoke-aditya
authored
masks_to_bounding_boxes op (#4290)
* ops.masks_to_bounding_boxes * test fixtures * unit test * ignore lint e201 and e202 for in-lined matrix * ignore e121 and e241 linting rules for in-lined matrix * draft gallery example text * removed type annotations from pytest fixtures * inlined fixture * renamed masks_to_bounding_boxes to masks_to_boxes * reformat inline array * import cleanup * moved masks_to_boxes into boxes module * docstring cleanup * updated docstring * fix formatting issue * gallery example * use torch * use torch * use torch * use torch * updated docs and test * cleanup * updated import * use torch * Update gallery/plot_repurposing_annotations.py Co-authored-by: Aditya Oke <[email protected]> * Update gallery/plot_repurposing_annotations.py Co-authored-by: Aditya Oke <[email protected]> * Update gallery/plot_repurposing_annotations.py Co-authored-by: Aditya Oke <[email protected]> * Autodoc * use torch instead of numpy in tests * fix build_docs failure * Closing quotes. Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Aditya Oke <[email protected]>
1 parent 8a83cf2 commit f0422e7

File tree

8 files changed

+154
-11
lines changed

8 files changed

+154
-11
lines changed

docs/requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
sphinx==3.5.4
2-
sphinx-gallery>=0.9.0
3-
sphinx-copybutton>=0.3.1
41
matplotlib
52
numpy
3+
sphinx-copybutton>=0.3.1
4+
sphinx-gallery>=0.9.0
5+
sphinx==3.5.4
66
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme

docs/source/ops.rst

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,20 @@ torchvision.ops
99
All operators have native support for TorchScript.
1010

1111

12-
.. autofunction:: nms
1312
.. autofunction:: batched_nms
14-
.. autofunction:: remove_small_boxes
15-
.. autofunction:: clip_boxes_to_image
16-
.. autofunction:: box_convert
1713
.. autofunction:: box_area
14+
.. autofunction:: box_convert
1815
.. autofunction:: box_iou
16+
.. autofunction:: clip_boxes_to_image
17+
.. autofunction:: deform_conv2d
1918
.. autofunction:: generalized_box_iou
20-
.. autofunction:: roi_align
19+
.. autofunction:: masks_to_boxes
20+
.. autofunction:: nms
2121
.. autofunction:: ps_roi_align
22-
.. autofunction:: roi_pool
2322
.. autofunction:: ps_roi_pool
24-
.. autofunction:: deform_conv2d
23+
.. autofunction:: remove_small_boxes
24+
.. autofunction:: roi_align
25+
.. autofunction:: roi_pool
2526
.. autofunction:: sigmoid_focal_loss
2627
.. autofunction:: stochastic_depth
2728

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
=======================
3+
Repurposing annotations
4+
=======================
5+
6+
The following example illustrates the operations available in the torchvision.ops module for repurposing object
7+
localization annotations for different tasks (e.g. transforming masks used by instance and panoptic segmentation
8+
methods into bounding boxes used by object detection methods).
9+
"""
10+
import os.path
11+
12+
import PIL.Image
13+
import matplotlib.patches
14+
import matplotlib.pyplot
15+
import numpy
16+
import torch
17+
from torchvision.ops import masks_to_boxes
18+
19+
ASSETS_DIRECTORY = "../test/assets"
20+
21+
matplotlib.pyplot.rcParams["savefig.bbox"] = "tight"
22+
23+
####################################
24+
# Masks
25+
# -----
26+
# In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package,
27+
# as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape:
28+
#
29+
# (objects, height, width)
30+
#
31+
# Where objects is the number of annotated objects in the image. Each (height, width) object corresponds to exactly
32+
# one object. For example, if your input image has the dimensions 224 x 224 and has four annotated objects the shape
33+
# of your masks annotation has the following shape:
34+
#
35+
# (4, 224, 224).
36+
#
37+
# A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object
38+
# localization tasks.
39+
#
40+
# Masks to bounding boxes
41+
# ----------------------------------------
42+
# For example, the masks to bounding_boxes operation can be used to transform masks into bounding boxes that can be
43+
# used in methods like Faster RCNN and YOLO.
44+
45+
with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
46+
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)
47+
48+
for index in range(image.n_frames):
49+
image.seek(index)
50+
51+
frame = numpy.array(image)
52+
53+
masks[index] = torch.tensor(frame)
54+
55+
bounding_boxes = masks_to_boxes(masks)
56+
57+
figure = matplotlib.pyplot.figure()
58+
59+
a = figure.add_subplot(121)
60+
b = figure.add_subplot(122)
61+
62+
labeled_image = torch.sum(masks, 0)
63+
64+
a.imshow(labeled_image)
65+
b.imshow(labeled_image)
66+
67+
for bounding_box in bounding_boxes:
68+
x0, y0, x1, y1 = bounding_box
69+
70+
rectangle = matplotlib.patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1, edgecolor="r", facecolor="none")
71+
72+
b.add_patch(rectangle)
73+
74+
a.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
75+
b.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

test/assets/labeled_image.png

896 Bytes
Loading

test/assets/masks.tiff

344 KB
Binary file not shown.

test/test_masks_to_boxes.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os.path
2+
3+
import PIL.Image
4+
import numpy
5+
import torch
6+
7+
from torchvision.ops import masks_to_boxes
8+
9+
ASSETS_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
10+
11+
12+
def test_masks_to_boxes():
13+
with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
14+
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)
15+
16+
for index in range(image.n_frames):
17+
image.seek(index)
18+
19+
frame = numpy.array(image)
20+
21+
masks[index] = torch.tensor(frame)
22+
23+
expected = torch.tensor(
24+
[[127, 2, 165, 40],
25+
[2, 50, 44, 92],
26+
[56, 63, 98, 100],
27+
[139, 68, 175, 104],
28+
[160, 112, 198, 145],
29+
[49, 138, 99, 182],
30+
[108, 148, 152, 213]],
31+
dtype=torch.int32
32+
)
33+
34+
torch.testing.assert_close(masks_to_boxes(masks), expected)

torchvision/ops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou
1+
from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou, \
2+
masks_to_boxes
23
from .boxes import box_convert
34
from .deform_conv import deform_conv2d, DeformConv2d
45
from .roi_align import roi_align, RoIAlign

torchvision/ops/boxes.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,35 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
297297
areai = whi[:, :, 0] * whi[:, :, 1]
298298

299299
return iou - (areai - union) / areai
300+
301+
302+
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
303+
"""
304+
Compute the bounding boxes around the provided masks
305+
306+
Returns a [N, 4] tensor. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
307+
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
308+
309+
Args:
310+
masks (Tensor[N, H, W]): masks to transform where N is the number of
311+
masks and (H, W) are the spatial dimensions.
312+
313+
Returns:
314+
Tensor[N, 4]: bounding boxes
315+
"""
316+
if masks.numel() == 0:
317+
return torch.zeros((0, 4))
318+
319+
n = masks.shape[0]
320+
321+
bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.int)
322+
323+
for index, mask in enumerate(masks):
324+
y, x = torch.where(masks[index] != 0)
325+
326+
bounding_boxes[index, 0] = torch.min(x)
327+
bounding_boxes[index, 1] = torch.min(y)
328+
bounding_boxes[index, 2] = torch.max(x)
329+
bounding_boxes[index, 3] = torch.max(y)
330+
331+
return bounding_boxes

0 commit comments

Comments
 (0)