
Commit 1e19d73

NicolasHug and pmeier authored
Add SanitizeBoundingBoxes transform (#7246)
Co-authored-by: Philip Meier <[email protected]>
1 parent c5e9a10 commit 1e19d73

File tree: 3 files changed (+224, -21 lines)


test/test_prototype_transforms.py

Lines changed: 116 additions & 0 deletions
@@ -1,5 +1,6 @@
 import itertools
 import pathlib
+import random
 import re
 import warnings
 from collections import defaultdict
@@ -2355,3 +2356,118 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):

     out["label"] = torch.tensor(out["label"])
     assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
+
+
+@pytest.mark.parametrize("min_size", (1, 10))
+@pytest.mark.parametrize(
+    "labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
+)
+def test_sanitize_bounding_boxes(min_size, labels_getter):
+    H, W = 256, 128
+
+    boxes_and_validity = [
+        ([0, 1, 10, 1], False),  # Y1 == Y2
+        ([0, 1, 0, 20], False),  # X1 == X2
+        ([0, 0, min_size - 1, 10], False),  # H < min_size
+        ([0, 0, 10, min_size - 1], False),  # W < min_size
+        ([0, 0, 10, H + 1], False),  # Y2 > H
+        ([0, 0, W + 1, 10], False),  # X2 > W
+        ([-1, 1, 10, 20], False),  # any < 0
+        ([0, 0, -1, 20], False),  # any < 0
+        ([0, 0, -10, -1], False),  # any < 0
+        ([0, 0, min_size, 10], True),  # H < min_size
+        ([0, 0, 10, min_size], True),  # W < min_size
+        ([0, 0, W, H], True),  # TODO: Is that actually OK?? Should it be -1?
+        ([1, 1, 30, 20], True),
+        ([0, 0, 10, 10], True),
+        ([1, 1, 30, 20], True),
+    ]
+
+    random.shuffle(boxes_and_validity)  # For test robustness: mix order of wrong and correct cases
+    boxes, is_valid_mask = zip(*boxes_and_validity)
+    valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]
+
+    boxes = torch.tensor(boxes)
+    labels = torch.arange(boxes.shape[-2])
+
+    boxes = datapoints.BoundingBox(
+        boxes,
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=(H, W),
+    )
+
+    sample = {
+        "image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
+        "labels": labels,
+        "boxes": boxes,
+        "whatever": torch.rand(10),
+        "None": None,
+    }
+
+    out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
+
+    assert out["image"] is sample["image"]
+    assert out["whatever"] is sample["whatever"]
+
+    if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
+        assert out["labels"] is sample["labels"]
+    else:
+        assert isinstance(out["labels"], torch.Tensor)
+        assert out["boxes"].shape[:-1] == out["labels"].shape
+        # This works because we conveniently set labels to arange(num_boxes)
+        assert out["labels"].tolist() == valid_indices
+
+
+@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
+def test_sanitize_bounding_boxes_default_heuristic(key):
+    labels = torch.arange(10)
+    d = {key: labels}
+    assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
+
+    if key.lower() != "labels":
+        # If "labels" is in the dict (case-insensitive),
+        # it takes precedence over other keys which would otherwise be a match
+        d = {key: "something_else", "labels": labels}
+        assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
+
+
+def test_sanitize_bounding_boxes_errors():
+
+    good_bbox = datapoints.BoundingBox(
+        [[0, 0, 10, 10]],
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=(20, 20),
+    )
+
+    with pytest.raises(ValueError, match="min_size must be >= 1"):
+        transforms.SanitizeBoundingBoxes(min_size=0)
+    with pytest.raises(ValueError, match="labels_getter should either be a str"):
+        transforms.SanitizeBoundingBoxes(labels_getter=12)
+
+    with pytest.raises(ValueError, match="Could not infer where the labels are"):
+        bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
+        transforms.SanitizeBoundingBoxes()(bad_labels_key)
+
+    with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
+        not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
+        transforms.SanitizeBoundingBoxes()(not_a_dict)
+
+    with pytest.raises(ValueError, match="must be a tensor"):
+        not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
+        transforms.SanitizeBoundingBoxes()(not_a_tensor)
+
+    with pytest.raises(ValueError, match="Number of boxes"):
+        different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
+        transforms.SanitizeBoundingBoxes()(different_sizes)
+
+    with pytest.raises(ValueError, match="boxes must be of shape"):
+        bad_bbox = datapoints.BoundingBox(  # batch with 2 elements
+            [
+                [[0, 0, 10, 10]],
+                [[0, 0, 10, 10]],
+            ],
+            format=datapoints.BoundingBoxFormat.XYXY,
+            spatial_size=(20, 20),
+        )
+        different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
+        transforms.SanitizeBoundingBoxes()(different_sizes)
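
For illustration only (this snippet is not part of the diff): a minimal sketch of the labels_getter contract that the new tests above exercise, assuming the prototype API as it stands after this commit; the "class_ids" key is a made-up example used to show the non-default getters.

# Sketch only, not part of the commit: labels_getter options of SanitizeBoundingBoxes.
import torch
from torchvision.prototype import datapoints, transforms

boxes = datapoints.BoundingBox(
    [[0, 1, 10, 1],    # degenerate: Y1 == Y2, will be dropped
     [1, 1, 30, 20]],  # valid, will be kept
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(256, 128),
)
sample = {"boxes": boxes, "class_ids": torch.tensor([3, 7])}

# labels_getter="default" looks for a "labels" key, then for any key containing
# "label" (case-insensitive), and raises if none is found. Here no such key exists,
# so we point the transform at "class_ids" explicitly ...
out = transforms.SanitizeBoundingBoxes(labels_getter="class_ids")(sample)
# out["boxes"] keeps one box and out["class_ids"] keeps the matching label (7).

# ... or skip label filtering entirely:
out_no_labels = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample)

A callable labels_getter works the same way: it receives the whole input and must return the labels tensor (or None to skip label filtering), as the parametrized test shows.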

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@
     LinearTransformation,
     Normalize,
     PermuteDimensions,
-    RemoveSmallBoundingBoxes,
+    SanitizeBoundingBoxes,
     ToDtype,
     TransposeDimensions,
 )

torchvision/prototype/transforms/_misc.py

Lines changed: 107 additions & 20 deletions
@@ -1,12 +1,14 @@
+import collections
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from contextlib import suppress
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union

 import PIL.Image

 import torch
+from torch.utils._pytree import tree_flatten, tree_unflatten

 from torchvision import transforms as _transforms
-from torchvision.ops import remove_small_boxes
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform

@@ -225,28 +227,113 @@ def _transform(
         return inpt.transpose(*dims)


-class RemoveSmallBoundingBoxes(Transform):
-    _transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel)
+class SanitizeBoundingBoxes(Transform):
+    # This removes boxes and their corresponding labels:
+    # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
+    # - boxes with any coordinate outside the range of the image (negative, or > spatial_size)

-    def __init__(self, min_size: float = 1.0) -> None:
+    def __init__(
+        self,
+        min_size: float = 1.0,
+        labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
+    ) -> None:
         super().__init__()
+
+        if min_size < 1:
+            raise ValueError(f"min_size must be >= 1, got {min_size}.")
         self.min_size = min_size

-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        bounding_box = query_bounding_box(flat_inputs)
-
-        # TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to
-        # be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH
-        # format,we need to convert first just to afterwards compute the width and height again, although they were
-        # there in the first place for these formats.
-        bounding_box = F.convert_format_bounding_box(
-            bounding_box.as_subclass(torch.Tensor),
-            old_format=bounding_box.format,
-            new_format=datapoints.BoundingBoxFormat.XYXY,
-        )
-        valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
+        self.labels_getter = labels_getter
+        self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
+        if labels_getter == "default":
+            self._labels_getter = self._find_labels_default_heuristic
+        elif callable(labels_getter):
+            self._labels_getter = labels_getter
+        elif isinstance(labels_getter, str):
+            self._labels_getter = lambda inputs: inputs[labels_getter]
+        elif labels_getter is None:
+            self._labels_getter = None
+        else:
+            raise ValueError(
+                "labels_getter should either be a str, callable, or 'default'. "
+                f"Got {labels_getter} of type {type(labels_getter)}."
+            )
+
+    @staticmethod
+    def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
+        # Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
+        # Raises if nothing is found
+        candidate_key = None
+        with suppress(StopIteration):
+            candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
+        if candidate_key is None:
+            with suppress(StopIteration):
+                candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
+        if candidate_key is None:
+            raise ValueError(
+                "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
+                "If there are no samples and it is by design, pass labels_getter=None."
+            )
+        return inputs[candidate_key]
+
+    def forward(self, *inputs: Any) -> Any:
+        inputs = inputs if len(inputs) > 1 else inputs[0]
+
+        if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
+            raise ValueError(
+                f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
+                f"then the input to forward() must be a dict. Got {type(inputs)} instead."
+            )
+
+        if self._labels_getter is None:
+            labels = None
+        else:
+            labels = self._labels_getter(inputs)
+            if labels is not None and not isinstance(labels, torch.Tensor):
+                raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")

-        return dict(valid_indices=valid_indices)
+        flat_inputs, spec = tree_flatten(inputs)
+        # TODO: this enforces one single BoundingBox entry.
+        # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
+        # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
+        boxes = query_bounding_box(flat_inputs)
+
+        if boxes.ndim != 2:
+            raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
+
+        if labels is not None and boxes.shape[0] != labels.shape[0]:
+            raise ValueError(
+                f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
+            )
+
+        boxes = cast(
+            datapoints.BoundingBox,
+            F.convert_format_bounding_box(
+                boxes,
+                new_format=datapoints.BoundingBoxFormat.XYXY,
+            ),
+        )
+        ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
+        mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
+        # TODO: Do we really need to check for out of bounds here? All
+        # transforms should be clamping anyway, so this should never happen?
+        image_h, image_w = boxes.spatial_size
+        mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
+        mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
+
+        params = dict(mask=mask, labels=labels)
+        flat_outputs = [
+            # Even-though it may look like we're transforming all inputs, we don't:
+            # _transform() will only care about BoundingBoxes and the labels
+            self._transform(inpt, params)
+            for inpt in flat_inputs
+        ]
+
+        return tree_unflatten(flat_outputs, spec)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return inpt.wrap_like(inpt, inpt[params["valid_indices"]])
+
+        if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
+            inpt = inpt[params["mask"]]
+
+        return inpt
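
For illustration only (this snippet is not part of the diff): the mask computed in forward() boils down to plain tensor ops. A minimal sketch assuming XYXY boxes, min_size=10, and an image with spatial_size=(256, 128):

# Standalone sketch of the keep/drop mask computed in SanitizeBoundingBoxes.forward()
# (pure torch, no torchvision needed); assumes XYXY boxes and min_size=10.
import torch

min_size = 10
image_h, image_w = 256, 128  # spatial_size=(H, W)
boxes = torch.tensor(
    [
        [0, 0, 5, 50],     # too narrow: width 5 < min_size -> dropped
        [0, 0, 50, 300],   # Y2 > image height -> dropped
        [-1, 0, 50, 50],   # negative coordinate -> dropped
        [10, 10, 60, 40],  # kept
    ],
    dtype=torch.float32,
)

ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
mask = (ws >= min_size) & (hs >= min_size) & (boxes >= 0).all(dim=-1)
mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
print(mask)  # tensor([False, False, False,  True])

Inside the transform, _transform() applies this mask only to the BoundingBox entry and to the labels tensor located by labels_getter; every other input passes through unchanged.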

0 commit comments