
Commit 5bb997c

Rewrite draw_segmentation_masks and update gallery example to illustrate both instance and semantic segmentation models (#3824)
1 parent 32bccc5 commit 5bb997c

File tree: 5 files changed, +376 -102 lines changed

gallery/plot_visualization_utils.py

Lines changed: 245 additions & 24 deletions
@@ -24,7 +24,8 @@ def show(imgs):
         imgs = [imgs]
     fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
     for i, img in enumerate(imgs):
-        img = F.to_pil_image(img.to('cpu'))
+        img = img.detach()
+        img = F.to_pil_image(img)
         axs[0, i].imshow(np.asarray(img))
         axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
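For reference, a minimal sketch (not part of the commit) of why the new ``detach()`` call matters: ``to_pil_image`` converts the tensor to a NumPy array internally, and ``.numpy()`` raises on tensors that are still attached to the autograd graph.

    import torch
    import torchvision.transforms.functional as F

    img = torch.rand(3, 4, 4, requires_grad=True)  # e.g. a model output
    # F.to_pil_image(img)  # would raise: can't call numpy() on a tensor that requires grad
    pil_img = F.to_pil_image(img.detach())  # works once detached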

@@ -50,9 +51,8 @@ def show(imgs):
 # Visualizing bounding boxes
 # --------------------------
 # We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an
-# image. We can set the colors, labels, width as well as font and font size !
-# The boxes are in ``(xmin, ymin, xmax, ymax)`` format
-# from torchvision.utils import draw_bounding_boxes
+# image. We can set the colors, labels, width as well as font and font size.
+# The boxes are in ``(xmin, ymin, xmax, ymax)`` format.
 
 from torchvision.utils import draw_bounding_boxes
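A hedged sketch of the API the new comment describes; the image and box coordinates here are made up for illustration.

    import torch
    from torchvision.utils import draw_bounding_boxes

    img = torch.randint(0, 256, (3, 100, 100), dtype=torch.uint8)  # dummy uint8 image
    # boxes in (xmin, ymin, xmax, ymax) format
    boxes = torch.tensor([[10, 10, 50, 50], [40, 20, 90, 80]], dtype=torch.float)
    result = draw_bounding_boxes(img, boxes, labels=['a', 'b'], colors=['red', 'blue'], width=4)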

@@ -74,9 +74,8 @@ def show(imgs):
 from torchvision.transforms.functional import convert_image_dtype
 
 
-dog1_float = convert_image_dtype(dog1_int, dtype=torch.float)
-dog2_float = convert_image_dtype(dog2_int, dtype=torch.float)
-batch = torch.stack([dog1_float, dog2_float])
+batch_int = torch.stack([dog1_int, dog2_int])
+batch = convert_image_dtype(batch_int, dtype=torch.float)
 
 model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
 model = model.eval()
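The rewritten lines rely on ``convert_image_dtype`` rescaling values when it changes dtype; a small sketch of that behavior:

    import torch
    from torchvision.transforms.functional import convert_image_dtype

    u8 = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)
    print(convert_image_dtype(u8, dtype=torch.float))
    # tensor([[[0.0000, 0.5020, 1.0000]]]) -- uint8 [0, 255] maps to float [0, 1]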
@@ -91,41 +90,263 @@ def show(imgs):
 threshold = .8
 dogs_with_boxes = [
     draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > threshold], width=4)
-    for dog_int, output in zip((dog1_int, dog2_int), outputs)
+    for dog_int, output in zip(batch_int, outputs)
 ]
 show(dogs_with_boxes)
 
 #####################################
 # Visualizing segmentation masks
 # ------------------------------
 # The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
-# draw segmentation amasks on images. We can set the colors as well as
-# transparency of masks.
+# draw segmentation masks on images. Semantic segmentation and instance
+# segmentation models have different outputs, so we will treat each
+# independently.
 #
-# Here is demo with torchvision's FCN Resnet-50, loaded with
-# :func:`~torchvision.models.segmentation.fcn_resnet50`.
-# You can also try using
-# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`)
-# or lraspp mobilenet models
+# Semantic segmentation models
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# We will see how to use it with torchvision's FCN Resnet-50, loaded with
+# :func:`~torchvision.models.segmentation.fcn_resnet50`. You can also try using
+# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`) or
+# lraspp mobilenet models
 # (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`).
 #
-# Like :func:`~torchvision.utils.draw_bounding_boxes`,
-# :func:`~torchvision.utils.draw_segmentation_masks` requires a single RGB image
-# of dtype `uint8`.
+# Let's start by looking at the output of the model. Remember that in general,
+# images must be normalized before they're passed to a semantic segmentation
+# model.
 
 from torchvision.models.segmentation import fcn_resnet50
-from torchvision.utils import draw_segmentation_masks
 
 
 model = fcn_resnet50(pretrained=True, progress=False)
 model = model.eval()
 
-# The model expects the batch to be normalized
-batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
-outputs = model(batch)
+normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+output = model(normalized_batch)['out']
+print(output.shape, output.min().item(), output.max().item())
+
+#####################################
+# As we can see above, the output of the segmentation model is a tensor of shape
+# ``(batch_size, num_classes, H, W)``. Each value is a non-normalized score, and
+# we can normalize them into ``[0, 1]`` by using a softmax. After the softmax,
+# we can interpret each value as a probability indicating how likely a given
+# pixel is to belong to a given class.
+#
+# Let's plot the masks that have been detected for the dog class and for the
+# boat class:
+
+sem_classes = [
+    '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
+    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
+]
+sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
+
+normalized_masks = torch.nn.functional.softmax(output, dim=1)
+
+dog_and_boat_masks = [
+    normalized_masks[img_idx, sem_class_to_idx[cls]]
+    for img_idx in range(batch.shape[0])
+    for cls in ('dog', 'boat')
+]
+
+show(dog_and_boat_masks)
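A quick sanity check (not in the commit) that the softmax over the class dimension really yields per-pixel probabilities, reusing ``normalized_masks`` from the hunk above:

    pixel_sums = normalized_masks.sum(dim=1)
    print(torch.allclose(pixel_sums, torch.ones_like(pixel_sums)))  # expect True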
+
+#####################################
+# As expected, the model is confident about the dog class, but not so much for
+# the boat class.
+#
+# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
+# plot those masks on top of the original image. This function expects the
+# masks to be boolean masks, but our masks above contain probabilities in ``[0,
+# 1]``. To get boolean masks, we can do the following:
+
+class_dim = 1
+boolean_dog_masks = (normalized_masks.argmax(class_dim) == sem_class_to_idx['dog'])
+print(f"shape = {boolean_dog_masks.shape}, dtype = {boolean_dog_masks.dtype}")
+show([m.float() for m in boolean_dog_masks])
+
+
+#####################################
+# The line above where we define ``boolean_dog_masks`` is a bit cryptic, but you
+# can read it as the following query: "For which pixels is 'dog' the most likely
+# class?"
+#
+# .. note::
+#   While we're using the ``normalized_masks`` here, we would have
+#   gotten the same result by using the non-normalized scores of the model
+#   directly (as the softmax operation preserves the order).
+#
+# Now that we have boolean masks, we can use them with
+# :func:`~torchvision.utils.draw_segmentation_masks` to plot them on top of the
+# original images:
+
+from torchvision.utils import draw_segmentation_masks
+
+dogs_with_masks = [
+    draw_segmentation_masks(img, masks=mask, alpha=0.7)
+    for img, mask in zip(batch_int, boolean_dog_masks)
+]
+show(dogs_with_masks)
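To illustrate the note above about softmax preserving order, one could check (assuming the ``output`` and ``normalized_masks`` tensors from the hunk):

    # argmax over raw scores and over softmax probabilities agree
    print(torch.equal(output.argmax(dim=1), normalized_masks.argmax(dim=1)))  # expect True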
+
+#####################################
+# We can plot more than one mask per image! Remember that the model returned as
+# many masks as there are classes. Let's ask the same query as above, but this
+# time for *all* classes, not just the dog class: "For each pixel and each class
+# C, is class C the most likely class?"
+#
+# This one is a bit more involved, so we'll first show how to do it with a
+# single image, and then we'll generalize to the batch.
+
+num_classes = normalized_masks.shape[1]
+dog1_masks = normalized_masks[0]
+class_dim = 0
+dog1_all_classes_masks = dog1_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None]
+
+print(f"dog1_masks shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}")
+print(f"dog1_all_classes_masks = {dog1_all_classes_masks.shape}, dtype = {dog1_all_classes_masks.dtype}")
+
+dog_with_all_masks = draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks, alpha=.6)
+show(dog_with_all_masks)
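The broadcast in ``dog1_all_classes_masks`` can be unpacked with a toy example (illustrative shapes only): ``argmax`` gives an ``(H, W)`` map of class indices, and comparing it against a ``(num_classes, 1, 1)`` column yields one boolean ``(H, W)`` mask per class.

    import torch

    class_map = torch.tensor([[0, 2], [1, 2]])   # pretend argmax output, H = W = 2
    classes = torch.arange(3)[:, None, None]     # shape (3, 1, 1)
    masks = class_map == classes                 # broadcasts to shape (3, 2, 2)
    print(masks.shape)  # torch.Size([3, 2, 2])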
+
+#####################################
+# We can see in the image above that only 2 masks were drawn: the mask for the
+# background and the mask for the dog. This is because the model thinks that
+# only these 2 classes are the most likely ones across all the pixels. If the
+# model had detected another class as the most likely among other pixels, we
+# would have seen its mask above.
+#
+# Removing the background mask is as simple as passing
+# ``masks=dog1_all_classes_masks[1:]``, because the background class is the
+# class with index 0.
+#
+# Let's now do the same but for an entire batch of images. The code is similar
+# but involves a bit more juggling with the dimensions.
+
+class_dim = 1
+all_classes_masks = normalized_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None, None]
+print(f"shape = {all_classes_masks.shape}, dtype = {all_classes_masks.dtype}")
+# The first dimension is the classes now, so we need to swap it
+all_classes_masks = all_classes_masks.swapaxes(0, 1)
 
 dogs_with_masks = [
-    draw_segmentation_masks(dog_int, masks=masks, alpha=0.6)
-    for dog_int, masks in zip((dog1_int, dog2_int), outputs['out'])
+    draw_segmentation_masks(img, masks=mask, alpha=.6)
+    for img, mask in zip(batch_int, all_classes_masks)
 ]
 show(dogs_with_masks)
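Assuming the same tensors as in the hunk above, an equivalent way to build the batched boolean masks is ``torch.nn.functional.one_hot``, which avoids the manual ``swapaxes``:

    one_hot_masks = torch.nn.functional.one_hot(
        normalized_masks.argmax(dim=1), num_classes=num_classes
    ).permute(0, 3, 1, 2).bool()  # (batch_size, num_classes, H, W)
    print(torch.equal(one_hot_masks, all_classes_masks))  # expect True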
+
+
+#####################################
+# Instance segmentation models
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Instance segmentation models have a significantly different output from the
+# semantic segmentation models. We will see here how to plot the masks for such
+# models. Let's start by analyzing the output of a Mask-RCNN model. Note that
+# these models don't require the images to be normalized, so we don't need to
+# use the normalized batch.
+
+from torchvision.models.detection import maskrcnn_resnet50_fpn
+model = maskrcnn_resnet50_fpn(pretrained=True, progress=False)
+model = model.eval()
+
+output = model(batch)
+print(output)
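A sketch (not in the commit) of inspecting the structure that the next paragraph describes, one dict of per-instance tensors per image:

    for img_idx, pred in enumerate(output):
        print(f'image {img_idx}:')
        for key, value in pred.items():
            print(f'  {key}: {tuple(value.shape)}')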
+
+#####################################
+# Let's break this down. For each image in the batch, the model outputs some
+# detections (or instances). The number of detections varies for each input
+# image. Each instance is described by its bounding box, its label, its score
+# and its mask.
+#
+# The way the output is organized is as follows: the output is a list of length
+# ``batch_size``. Each entry in the list corresponds to an input image, and it
+# is a dict with keys 'boxes', 'labels', 'scores', and 'masks'. Each value
+# associated with those keys has ``num_instances`` elements in it. In our case
+# above there are 3 instances detected in the first image, and 2 instances in
+# the second one.
+#
+# The boxes can be plotted with :func:`~torchvision.utils.draw_bounding_boxes`
+# as above, but here we're more interested in the masks. These masks are quite
+# different from the masks that we saw above for the semantic segmentation
+# models.
+
+dog1_output = output[0]
+dog1_masks = dog1_output['masks']
+print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, "
+      f"min = {dog1_masks.min()}, max = {dog1_masks.max()}")
+
+#####################################
+# Here the masks correspond to probabilities indicating, for each pixel, how
+# likely it is to belong to the predicted label of that instance. Those
+# predicted labels correspond to the 'labels' element in the same output dict.
+# Let's see which labels were predicted for the instances of the first image.
+
+inst_classes = [
+    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+]
+
+inst_class_to_idx = {cls: idx for (idx, cls) in enumerate(inst_classes)}
+
+print("For the first dog, the following instances were detected:")
+print([inst_classes[label] for label in dog1_output['labels']])
+
+#####################################
+# Interestingly, the model detects two persons in the image. Let's go ahead and
+# plot those masks. Since :func:`~torchvision.utils.draw_segmentation_masks`
+# expects boolean masks, we need to convert those probabilities into boolean
+# values. Remember that the semantics of those masks is "How likely is this pixel
+# to belong to the predicted class?". As a result, a natural way of converting
+# those masks into boolean values is to threshold them at the 0.5 probability
+# (one could also choose a different threshold).
+
+proba_threshold = 0.5
+dog1_bool_masks = dog1_output['masks'] > proba_threshold
+print(f"shape = {dog1_bool_masks.shape}, dtype = {dog1_bool_masks.dtype}")
+
+# There's an extra dimension (1) to the masks. We need to remove it
+dog1_bool_masks = dog1_bool_masks.squeeze(1)
+
+show(draw_segmentation_masks(dog1_int, dog1_bool_masks, alpha=0.9))
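One way (an assumption, not part of the commit) to see how much of each instance survives the 0.5 threshold is to count the surviving pixels per mask:

    print((dog1_output['masks'] > proba_threshold).sum(dim=(1, 2, 3)))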
+
+#####################################
+# The model seems to have properly detected the dog, but it also confused trees
+# with people. Looking more closely at the scores will help us plot more
+# relevant masks:
+
+print(dog1_output['scores'])
+
+#####################################
+# Clearly the model is less confident about the dog detection than it is about
+# the people detections. That's good news. When plotting the masks, we can ask
+# for only those that have a good score. Let's use a score threshold of .75
+# here, and also plot the masks of the second dog.
+
+score_threshold = .75
+
+boolean_masks = [
+    out['masks'][out['scores'] > score_threshold] > proba_threshold
+    for out in output
+]
+
+dogs_with_masks = [
+    draw_segmentation_masks(img, mask.squeeze(1))
+    for img, mask in zip(batch_int, boolean_masks)
+]
+show(dogs_with_masks)
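A possible extension, not in the commit: since ``draw_segmentation_masks`` returns a uint8 image tensor, it can be chained with ``draw_bounding_boxes`` to show boxes and masks for the same high-score instances:

    dogs_with_boxes_and_masks = [
        draw_bounding_boxes(
            draw_segmentation_masks(img, mask.squeeze(1)),
            boxes=out['boxes'][out['scores'] > score_threshold],
            width=4,
        )
        for img, mask, out in zip(batch_int, boolean_masks, output)
    ]
    show(dogs_with_boxes_and_masks)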
+
+#####################################
+# The two 'people' masks in the first image were not selected because they have
+# a lower score than the score threshold. Similarly in the second image, the
+# instance with class 15 (which corresponds to 'bench') was not selected.
(2 binary files changed, not shown)
