Add SanitizeBoundingBoxes transform #7246

Merged · 17 commits · Feb 15, 2023
116 changes: 116 additions & 0 deletions test/test_prototype_transforms.py
@@ -1,5 +1,6 @@
import itertools
import pathlib
import random
import re
import warnings
from collections import defaultdict
@@ -2355,3 +2356,118 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):

out["label"] = torch.tensor(out["label"])
assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes


@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize(
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
)
def test_sanitize_bounding_boxes(min_size, labels_getter):
H, W = 256, 128

boxes_and_validity = [
([0, 1, 10, 1], False), # Y1 == Y2
([0, 1, 0, 20], False), # X1 == X2
([0, 0, min_size - 1, 10], False), # W < min_size
([0, 0, 10, min_size - 1], False), # H < min_size
([0, 0, 10, H + 1], False), # Y2 > H
([0, 0, W + 1, 10], False), # X2 > W
([-1, 1, 10, 20], False), # any < 0
([0, 0, -1, 20], False), # any < 0
([0, 0, -10, -1], False), # any < 0
([0, 0, min_size, 10], True), # W == min_size (exactly at the limit)
([0, 0, 10, min_size], True), # H == min_size
([0, 0, W, H], True), # TODO: Is that actually OK?? Should it be -1?
([1, 1, 30, 20], True),
([0, 0, 10, 10], True),
([1, 1, 30, 20], True),
]

random.shuffle(boxes_and_validity) # For test robustness: mix order of wrong and correct cases
boxes, is_valid_mask = zip(*boxes_and_validity)
valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]

boxes = torch.tensor(boxes)
labels = torch.arange(boxes.shape[-2])

boxes = datapoints.BoundingBox(
boxes,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(H, W),
)

sample = {
"image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
"labels": labels,
"boxes": boxes,
"whatever": torch.rand(10),
"None": None,
}

out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)

assert out["image"] is sample["image"]
assert out["whatever"] is sample["whatever"]

if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
assert out["labels"] is sample["labels"]
else:
assert isinstance(out["labels"], torch.Tensor)
assert out["boxes"].shape[:-1] == out["labels"].shape
# This works because we conveniently set labels to arange(num_boxes)
assert out["labels"].tolist() == valid_indices


@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
def test_sanitize_bounding_boxes_default_heuristic(key):
labels = torch.arange(10)
d = {key: labels}
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels

if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive),
# it takes precedence over other keys which would otherwise be a match
d = {key: "something_else", "labels": labels}
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels


def test_sanitize_bounding_boxes_errors():

good_bbox = datapoints.BoundingBox(
[[0, 0, 10, 10]],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(20, 20),
)

with pytest.raises(ValueError, match="min_size must be >= 1"):
transforms.SanitizeBoundingBoxes(min_size=0)
with pytest.raises(ValueError, match="labels_getter should either be a str"):
transforms.SanitizeBoundingBoxes(labels_getter=12)

with pytest.raises(ValueError, match="Could not infer where the labels are"):
bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(bad_labels_key)

with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
transforms.SanitizeBoundingBoxes()(not_a_dict)

with pytest.raises(ValueError, match="must be a tensor"):
not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
transforms.SanitizeBoundingBoxes()(not_a_tensor)

with pytest.raises(ValueError, match="Number of boxes"):
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
transforms.SanitizeBoundingBoxes()(different_sizes)

with pytest.raises(ValueError, match="boxes must be of shape"):
bad_bbox = datapoints.BoundingBox( # batch with 2 elements
[
[[0, 0, 10, 10]],
[[0, 0, 10, 10]],
],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(20, 20),
)
different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(different_sizes)
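
For orientation, a minimal usage sketch distilled from the tests above; the sample keys, shapes, and box values are illustrative and not part of the diff:

import torch
from torchvision.prototype import datapoints, transforms

H, W = 256, 128
boxes = datapoints.BoundingBox(
    [[0, 0, 10, 10], [0, 1, 0, 20]],  # the second box is degenerate (X1 == X2)
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(H, W),
)
sample = {
    "image": torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8),
    "boxes": boxes,
    "labels": torch.tensor([0, 1]),
}
out = transforms.SanitizeBoundingBoxes()(sample)
# The degenerate box and its matching label are dropped; the image passes through unchanged.
assert out["boxes"].shape[0] == out["labels"].shape[0] == 1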
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -49,7 +49,7 @@
LinearTransformation,
Normalize,
PermuteDimensions,
RemoveSmallBoundingBoxes,
SanitizeBoundingBoxes,
ToDtype,
TransposeDimensions,
)
127 changes: 107 additions & 20 deletions torchvision/prototype/transforms/_misc.py
@@ -1,12 +1,14 @@
import collections
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from contextlib import suppress
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union

import PIL.Image

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchvision import transforms as _transforms
from torchvision.ops import remove_small_boxes
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform

@@ -225,28 +227,113 @@ def _transform(
return inpt.transpose(*dims)


class RemoveSmallBoundingBoxes(Transform):
_transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel)
class SanitizeBoundingBoxes(Transform):
# This removes boxes and their corresponding labels:
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)

def __init__(self, min_size: float = 1.0) -> None:
def __init__(
self,
min_size: float = 1.0,
labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
) -> None:
super().__init__()

if min_size < 1:
raise ValueError(f"min_size must be >= 1, got {min_size}.")
self.min_size = min_size

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
bounding_box = query_bounding_box(flat_inputs)

# TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to
# be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH
format, we need to convert first just to compute the width and height again afterwards, although they were
# there in the first place for these formats.
bounding_box = F.convert_format_bounding_box(
bounding_box.as_subclass(torch.Tensor),
old_format=bounding_box.format,
new_format=datapoints.BoundingBoxFormat.XYXY,
)
valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
self.labels_getter = labels_getter
self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
if labels_getter == "default":
self._labels_getter = self._find_labels_default_heuristic
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: inputs[labels_getter]
elif labels_getter is None:
self._labels_getter = None
else:
raise ValueError(
"labels_getter should either be a str, callable, or 'default'. "
f"Got {labels_getter} of type {type(labels_getter)}."
)

@staticmethod
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "labels" key (case-insensitive exact match), otherwise falls back to the first key that contains "label" - case insensitive
# Raises a ValueError if nothing is found
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
if candidate_key is None:
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
if candidate_key is None:
raise ValueError(
"Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
"If there are no samples and it is by design, pass labels_getter=None."
)
return inputs[candidate_key]
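# Illustrative examples (not part of this diff) of what the heuristic above resolves to:
#   {"labels": t, "mask_labels": u}  -> t           (exact, case-insensitive "labels" match wins)
#   {"image": img, "BBoxLabels": t}  -> t           (substring fallback)
#   {"image": img, "boxes": b}       -> ValueError  (no candidate key)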

def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]

if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
f"then the input to forward() must be a dict. Got {type(inputs)} instead."
)

if self._labels_getter is None:
labels = None
else:
labels = self._labels_getter(inputs)
if labels is not None and not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")

return dict(valid_indices=valid_indices)
flat_inputs, spec = tree_flatten(inputs)
# TODO: this enforces one single BoundingBox entry.
# Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
# should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
boxes = query_bounding_box(flat_inputs)

if boxes.ndim != 2:
raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")

if labels is not None and boxes.shape[0] != labels.shape[0]:
raise ValueError(
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
)

boxes = cast(
datapoints.BoundingBox,
F.convert_format_bounding_box(
boxes,
new_format=datapoints.BoundingBoxFormat.XYXY,
),
)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
Comment on lines +318 to +319 (Collaborator): I would keep this for now until we are sure about this, i.e. we have tests that guarantee this. Happy to remove if it turns out we don't need it.
image_h, image_w = boxes.spatial_size
mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)

params = dict(mask=mask, labels=labels)
flat_outputs = [
# Even though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxes and the labels
self._transform(inpt, params)
for inpt in flat_inputs
]

return tree_unflatten(flat_outputs, spec)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.wrap_like(inpt, inpt[params["valid_indices"]])

if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
inpt = inpt[params["mask"]]

return inpt
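
As a closing illustration, the mask computation in forward() can be reproduced on plain tensors; a small standalone sketch with illustrative values (pure PyTorch, no datapoints wrapper):

import torch

min_size = 1.0
image_h, image_w = 256, 128

# Two XYXY boxes: the first is valid, the second extends past the right edge of the image.
boxes = torch.tensor([[1.0, 1.0, 30.0, 20.0], [0.0, 0.0, 129.0, 10.0]])
labels = torch.tensor([7, 9])

ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
mask = (ws >= min_size) & (hs >= min_size) & (boxes >= 0).all(dim=-1)
mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)

print(mask)          # tensor([ True, False])
print(boxes[mask])   # only the first box survives
print(labels[mask])  # the label of the dropped box is removed along with it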