diff --git a/docs/source/ops.rst b/docs/source/ops.rst
index 5fd4b75e59d..f4f03bdb298 100644
--- a/docs/source/ops.rst
+++ b/docs/source/ops.rst
@@ -1,3 +1,5 @@
+.. _ops:
+
 torchvision.ops
 ===============
 
diff --git a/docs/source/utils.rst b/docs/source/utils.rst
index acaf785d817..b0a2d743d4e 100644
--- a/docs/source/utils.rst
+++ b/docs/source/utils.rst
@@ -1,3 +1,5 @@
+.. _utils:
+
 torchvision.utils
 =================
 
diff --git a/gallery/assets/FudanPed00054.png b/gallery/assets/FudanPed00054.png
new file mode 100644
index 00000000000..951682abb93
Binary files /dev/null and b/gallery/assets/FudanPed00054.png differ
diff --git a/gallery/assets/FudanPed00054_mask.png b/gallery/assets/FudanPed00054_mask.png
new file mode 100644
index 00000000000..4d5aa4e4020
Binary files /dev/null and b/gallery/assets/FudanPed00054_mask.png differ
diff --git a/gallery/plot_repurposing_annotations.py b/gallery/plot_repurposing_annotations.py
index 2decefcc815..012bd9857ef 100644
--- a/gallery/plot_repurposing_annotations.py
+++ b/gallery/plot_repurposing_annotations.py
@@ -1,24 +1,39 @@
 """
-=======================
-Repurposing annotations
-=======================
-
-The following example illustrates the operations available in the torchvision.ops module for repurposing object
-localization annotations for different tasks (e.g. transforming masks used by instance and panoptic segmentation
+=====================================
+Repurposing masks into bounding boxes
+=====================================
+
+The following example illustrates the operations available in
+the :ref:`torchvision.ops <ops>` module for repurposing
+segmentation masks into object localization annotations for different tasks
+(e.g. transforming masks used by instance and panoptic segmentation
 methods into bounding boxes used by object detection methods).
 """
-import os.path
-import PIL.Image
-import matplotlib.patches
-import matplotlib.pyplot
-import numpy
+
+import os
+import numpy as np
 import torch
-from torchvision.ops import masks_to_boxes
+import matplotlib.pyplot as plt
+
+import torchvision.transforms.functional as F
+
+
+ASSETS_DIRECTORY = "assets"
 
-ASSETS_DIRECTORY = "../test/assets"
+plt.rcParams["savefig.bbox"] = "tight"
+
+
+def show(imgs):
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
+    for i, img in enumerate(imgs):
+        img = img.detach()
+        img = F.to_pil_image(img)
+        axs[0, i].imshow(np.asarray(img))
+        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
 
-matplotlib.pyplot.rcParams["savefig.bbox"] = "tight"
 
 ####################################
 # Masks
@@ -26,9 +41,9 @@
 # In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package,
 # as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape:
 #
-#      (objects, height, width)
+#      (num_objects, height, width)
 #
-# Where objects is the number of annotated objects in the image. Each (height, width) object corresponds to exactly
+# Where num_objects is the number of annotated objects in the image. Each (height, width) object corresponds to exactly
 # one object. For example, if your input image has the dimensions 224 x 224 and has four annotated objects the shape
 # of your masks annotation has the following shape:
 #
@@ -36,40 +51,154 @@
 #
 # A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object
 # localization tasks.
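+
+####################################
+# For instance, a toy masks annotation for two objects in a 4 x 4 image
+# (a minimal illustration, unrelated to the dataset used below) can be built
+# as follows:
+
+toy_masks = torch.zeros((2, 4, 4), dtype=torch.bool)
+toy_masks[0, :2, :2] = True  # first object covers the top-left 2 x 2 corner
+toy_masks[1, 2:, 2:] = True  # second object covers the bottom-right 2 x 2 corner
+print(toy_masks.size())  # torch.Size([2, 4, 4]), i.e. (num_objects, height, width)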
-#
-# Masks to bounding boxes
-# ----------------------------------------
-# For example, the masks to bounding_boxes operation can be used to transform masks into bounding boxes that can be
-# used in methods like Faster RCNN and YOLO.
-with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
-    masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)
+####################################
+# Converting Masks to Bounding Boxes
+# -----------------------------------------------
+# For example, the :func:`~torchvision.ops.masks_to_boxes` operation can be used to
+# transform masks into bounding boxes that can be
+# used as input to detection models such as FasterRCNN and RetinaNet.
+# We will take images and masks from the `PenFudan Dataset <https://www.cis.upenn.edu/~jshi/ped_html/>`_.
+
+
+from torchvision.io import read_image
+
+img_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054.png")
+mask_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054_mask.png")
+img = read_image(img_path)
+mask = read_image(mask_path)
+
+
+#########################
+# Here the mask is represented as a PNG image, with integer values.
+# Each object instance is encoded as a different pixel value, with 0 being background.
+# Notice that the spatial dimensions of image and mask match.
+
+print(mask.size())
+print(img.size())
+print(mask)
+
+############################
+
+# We get the unique colors, as these would be the object ids.
+obj_ids = torch.unique(mask)
+
+# first id is the background, so remove it.
+obj_ids = obj_ids[1:]
+
+# split the color-encoded mask into a set of boolean masks.
+# Note that this snippet would work as well if the masks were float values instead of ints.
+masks = mask == obj_ids[:, None, None]
+
+########################
+# Now the masks are a boolean tensor.
+# The first dimension is 3 in this case, denoting the number of instances: there are 3 people in the image.
+# The other two dimensions are height and width, which are equal to the dimensions of the image.
+# For each instance, the boolean tensor indicates whether a given pixel belongs
+# to that instance's segmentation mask.
+
+print(masks.size())
+print(masks)
+
+####################################
+# Let us visualize an image and plot its corresponding segmentation masks.
+# We will use the :func:`~torchvision.utils.draw_segmentation_masks` utility to draw the segmentation masks.
+
+from torchvision.utils import draw_segmentation_masks
+
+drawn_masks = []
+for mask in masks:
+    drawn_masks.append(draw_segmentation_masks(img, mask, alpha=0.8, colors="blue"))
+
+show(drawn_masks)
+
+####################################
+# To convert the boolean masks into bounding boxes, we will use the
+# :func:`~torchvision.ops.masks_to_boxes` operation from the :ref:`torchvision.ops <ops>` module.
+# It returns the boxes in ``(xmin, ymin, xmax, ymax)`` format.
+
+from torchvision.ops import masks_to_boxes
+
+boxes = masks_to_boxes(masks)
+print(boxes.size())
+print(boxes)
+
+####################################
+# As the shape shows, there are 3 boxes, each in ``(xmin, ymin, xmax, ymax)`` format.
+# These can be visualized very easily with the :func:`~torchvision.utils.draw_bounding_boxes` utility
+# provided in :ref:`torchvision.utils <utils>`.
+
+from torchvision.utils import draw_bounding_boxes
+
+drawn_boxes = draw_bounding_boxes(img, boxes, colors="red")
+show(drawn_boxes)
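+
+####################################
+# As a quick sanity check (an optional sketch, not needed for the rest of the
+# example), we can verify that each box tightly encloses its mask: the box
+# corners should match the extreme coordinates of the mask pixels.
+
+for single_mask, box in zip(masks, boxes):
+    ys, xs = torch.where(single_mask)
+    assert box[0] == xs.min() and box[1] == ys.min()  # (xmin, ymin)
+    assert box[2] == xs.max() and box[3] == ys.max()  # (xmax, ymax)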
+
+###################################
+# These boxes can now directly be used by detection models in torchvision.
+# Here is a demo with a Faster R-CNN model loaded from
+# :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.
+
+from torchvision.models.detection import fasterrcnn_resnet50_fpn
+
+model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
+print(img.size())
+
+img = F.convert_image_dtype(img, torch.float)
+target = {}
+target["boxes"] = boxes
+target["labels"] = torch.ones((masks.size(0),), dtype=torch.int64)
+detection_outputs = model(img.unsqueeze(0), [target])
+
+
+####################################
+# Converting Segmentation Dataset to Detection Dataset
+# ----------------------------------------------------
+#
+# With this utility it becomes very simple to convert a segmentation dataset to a detection dataset.
+# We can now use a segmentation dataset to train a detection model.
+# One can similarly convert a panoptic dataset to a detection dataset.
+# Here is an example where we re-purpose the dataset from the
+# `PenFudan Detection Tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_.
 
-    for index in range(image.n_frames):
-        image.seek(index)
 
+class SegmentationToDetectionDataset(torch.utils.data.Dataset):
+    def __init__(self, root, transforms):
+        self.root = root
+        self.transforms = transforms
+        # load all image files, sorting them to
+        # ensure that they are aligned
+        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
+        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))
 
-        frame = numpy.array(image)
+    def __getitem__(self, idx):
+        # load images and masks
+        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
+        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
 
-        masks[index] = torch.tensor(frame)
+        img = read_image(img_path)
+        mask = read_image(mask_path)
 
-bounding_boxes = masks_to_boxes(masks)
+        img = F.convert_image_dtype(img, dtype=torch.float)
+        mask = F.convert_image_dtype(mask, dtype=torch.float)
 
-figure = matplotlib.pyplot.figure()
+        # We get the unique colors, as these would be the object ids.
+        obj_ids = torch.unique(mask)
 
-a = figure.add_subplot(121)
-b = figure.add_subplot(122)
+        # first id is the background, so remove it.
+        obj_ids = obj_ids[1:]
 
-labeled_image = torch.sum(masks, 0)
+        # split the color-encoded mask into a set of boolean masks.
+        masks = mask == obj_ids[:, None, None]
 
-a.imshow(labeled_image)
-b.imshow(labeled_image)
+        boxes = masks_to_boxes(masks)
 
-for bounding_box in bounding_boxes:
-    x0, y0, x1, y1 = bounding_box
+        # there is only one class
+        labels = torch.ones((masks.shape[0],), dtype=torch.int64)
 
-    rectangle = matplotlib.patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1, edgecolor="r", facecolor="none")
+        target = {}
+        target["boxes"] = boxes
+        target["labels"] = labels
 
-    b.add_patch(rectangle)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
 
-a.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
-b.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+        return img, target
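+
+
+####################################
+# The converted dataset can then be used like any detection dataset in
+# torchvision. Below is a minimal usage sketch; it assumes the
+# ``PennFudanPed`` directory layout (``PNGImages``/``PedMasks``) from the
+# tutorial linked above is available on disk, so it is shown as a
+# non-executed snippet:
+#
+# .. code-block:: python
+#
+#     dataset = SegmentationToDetectionDataset("PennFudanPed", transforms=None)
+#     img, target = dataset[0]
+#     print(img.size(), target["boxes"].size(), target["labels"])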