diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index b881ebc502a..ff772c5151f 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1,5 +1,6 @@
 import itertools
 import pathlib
+import random
 import re
 import warnings
 from collections import defaultdict
@@ -2355,3 +2356,118 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
         out["label"] = torch.tensor(out["label"])
 
     assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
+
+
+@pytest.mark.parametrize("min_size", (1, 10))
+@pytest.mark.parametrize(
+    "labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
+)
+def test_sanitize_bounding_boxes(min_size, labels_getter):
+    H, W = 256, 128
+
+    boxes_and_validity = [
+        ([0, 1, 10, 1], False),  # Y1 == Y2
+        ([0, 1, 0, 20], False),  # X1 == X2
+        ([0, 0, min_size - 1, 10], False),  # W < min_size
+        ([0, 0, 10, min_size - 1], False),  # H < min_size
+        ([0, 0, 10, H + 1], False),  # Y2 > H
+        ([0, 0, W + 1, 10], False),  # X2 > W
+        ([-1, 1, 10, 20], False),  # any < 0
+        ([0, 0, -1, 20], False),  # any < 0
+        ([0, 0, -10, -1], False),  # any < 0
+        ([0, 0, min_size, 10], True),  # W == min_size
+        ([0, 0, 10, min_size], True),  # H == min_size
+        ([0, 0, W, H], True),  # TODO: Is that actually OK?? Should it be -1?
+        ([1, 1, 30, 20], True),
+        ([0, 0, 10, 10], True),
+        ([1, 1, 30, 20], True),
+    ]
+
+    random.shuffle(boxes_and_validity)  # For test robustness: mix order of wrong and correct cases
+    boxes, is_valid_mask = zip(*boxes_and_validity)
+    valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]
+
+    boxes = torch.tensor(boxes)
+    labels = torch.arange(boxes.shape[-2])
+
+    boxes = datapoints.BoundingBox(
+        boxes,
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=(H, W),
+    )
+
+    sample = {
+        "image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
+        "labels": labels,
+        "boxes": boxes,
+        "whatever": torch.rand(10),
+        "None": None,
+    }
+
+    out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
+
+    assert out["image"] is sample["image"]
+    assert out["whatever"] is sample["whatever"]
+
+    if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
+        assert out["labels"] is sample["labels"]
+    else:
+        assert isinstance(out["labels"], torch.Tensor)
+        assert out["boxes"].shape[:-1] == out["labels"].shape
+        # This works because we conveniently set labels to arange(num_boxes)
+        assert out["labels"].tolist() == valid_indices
+
+
+@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
+def test_sanitize_bounding_boxes_default_heuristic(key):
+    labels = torch.arange(10)
+    d = {key: labels}
+    assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
+
+    if key.lower() != "labels":
+        # If "labels" is in the dict (case-insensitive),
+        # it takes precedence over other keys which would otherwise be a match
+        d = {key: "something_else", "labels": labels}
+        assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
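+        # For instance, with key == "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT" the heuristic
+        # sees {"SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT": "something_else", "labels": labels}
+        # and picks the exact (case-insensitive) "labels" match over the substring match.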
+
+
+def test_sanitize_bounding_boxes_errors():
+
+    good_bbox = datapoints.BoundingBox(
+        [[0, 0, 10, 10]],
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=(20, 20),
+    )
+
+    with pytest.raises(ValueError, match="min_size must be >= 1"):
+        transforms.SanitizeBoundingBoxes(min_size=0)
+    with pytest.raises(ValueError, match="labels_getter should either be a str"):
+        transforms.SanitizeBoundingBoxes(labels_getter=12)
+
+    with pytest.raises(ValueError, match="Could not infer where the labels are"):
+        bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
+        transforms.SanitizeBoundingBoxes()(bad_labels_key)
+
+    with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
+        not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
+        transforms.SanitizeBoundingBoxes()(not_a_dict)
+
+    with pytest.raises(ValueError, match="must be a tensor"):
+        not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
+        transforms.SanitizeBoundingBoxes()(not_a_tensor)
+
+    with pytest.raises(ValueError, match="Number of boxes"):
+        different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
+        transforms.SanitizeBoundingBoxes()(different_sizes)
+
+    with pytest.raises(ValueError, match="boxes must be of shape"):
+        bad_bbox = datapoints.BoundingBox(  # batch with 2 elements
+            [
+                [[0, 0, 10, 10]],
+                [[0, 0, 10, 10]],
+            ],
+            format=datapoints.BoundingBoxFormat.XYXY,
+            spatial_size=(20, 20),
+        )
+        different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
+        transforms.SanitizeBoundingBoxes()(different_sizes)
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index a640d726cef..ff3b758454a 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -49,7 +49,7 @@
     LinearTransformation,
     Normalize,
     PermuteDimensions,
-    RemoveSmallBoundingBoxes,
+    SanitizeBoundingBoxes,
     ToDtype,
     TransposeDimensions,
 )
diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py
index b398227b480..caed3eec904 100644
--- a/torchvision/prototype/transforms/_misc.py
+++ b/torchvision/prototype/transforms/_misc.py
@@ -1,12 +1,14 @@
+import collections
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from contextlib import suppress
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
 
 import PIL.Image
 
 import torch
+from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from torchvision import transforms as _transforms
-from torchvision.ops import remove_small_boxes
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform
 
@@ -225,28 +227,113 @@ def _transform(
         return inpt.transpose(*dims)
 
 
-class RemoveSmallBoundingBoxes(Transform):
-    _transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel)
+class SanitizeBoundingBoxes(Transform):
+    # This removes boxes and their corresponding labels:
+    # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
+    # - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
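+    #
+    # Minimal usage sketch (hypothetical sample dict; the default labels_getter
+    # resolves the "labels" key via _find_labels_default_heuristic below):
+    #
+    #     boxes = datapoints.BoundingBox(
+    #         [[0, 0, 10, 10], [5, 5, 5, 5]],  # second box is degenerate (zero width and height)
+    #         format=datapoints.BoundingBoxFormat.XYXY,
+    #         spatial_size=(20, 20),
+    #     )
+    #     sample = {"image": image, "boxes": boxes, "labels": torch.tensor([7, 3])}
+    #     out = SanitizeBoundingBoxes()(sample)
+    #     # out["boxes"] and out["labels"] keep only the first entry; "image" passes through
+    #
+    # For samples where the labels don't live under a (case-insensitive) *label* key,
+    # e.g. a hypothetical nested {"annotations": {"labels": ...}} dict, pass a callable:
+    #     SanitizeBoundingBoxes(labels_getter=lambda sample: sample["annotations"]["labels"])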
 
-    def __init__(self, min_size: float = 1.0) -> None:
+    def __init__(
+        self,
+        min_size: float = 1.0,
+        labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
+    ) -> None:
         super().__init__()
+
+        if min_size < 1:
+            raise ValueError(f"min_size must be >= 1, got {min_size}.")
         self.min_size = min_size
 
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        bounding_box = query_bounding_box(flat_inputs)
-
-        # TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to
-        # be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH
-        # format,we need to convert first just to afterwards compute the width and height again, although they were
-        # there in the first place for these formats.
-        bounding_box = F.convert_format_bounding_box(
-            bounding_box.as_subclass(torch.Tensor),
-            old_format=bounding_box.format,
-            new_format=datapoints.BoundingBoxFormat.XYXY,
-        )
-        valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
+        self.labels_getter = labels_getter
+        self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
+        if labels_getter == "default":
+            self._labels_getter = self._find_labels_default_heuristic
+        elif callable(labels_getter):
+            self._labels_getter = labels_getter
+        elif isinstance(labels_getter, str):
+            self._labels_getter = lambda inputs: inputs[labels_getter]
+        elif labels_getter is None:
+            self._labels_getter = None
+        else:
+            raise ValueError(
+                "labels_getter should either be a str, callable, or 'default'. "
+                f"Got {labels_getter} of type {type(labels_getter)}."
+            )
+
+    @staticmethod
+    def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
+        # Tries to find a "labels" key (case-insensitive), otherwise falls back to the
+        # first key that contains "label". Raises if no candidate key is found.
+        candidate_key = None
+        with suppress(StopIteration):
+            candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
+        if candidate_key is None:
+            with suppress(StopIteration):
+                candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
+        if candidate_key is None:
+            raise ValueError(
+                "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter? "
+                "If there are no labels in the sample and it is by design, pass labels_getter=None."
+            )
+        return inputs[candidate_key]
+
+    def forward(self, *inputs: Any) -> Any:
+        inputs = inputs if len(inputs) > 1 else inputs[0]
+
+        if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
+            raise ValueError(
+                f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
+                f"then the input to forward() must be a dict. Got {type(inputs)} instead."
+            )
+
+        if self._labels_getter is None:
+            labels = None
+        else:
+            labels = self._labels_getter(inputs)
+            if labels is not None and not isinstance(labels, torch.Tensor):
+                raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")
 
-        return dict(valid_indices=valid_indices)
+        flat_inputs, spec = tree_flatten(inputs)
+        # TODO: this enforces one single BoundingBox entry.
+        # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
+        # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
+        boxes = query_bounding_box(flat_inputs)
+
+        if boxes.ndim != 2:
+            raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
+
+        if labels is not None and boxes.shape[0] != labels.shape[0]:
+            raise ValueError(
+                f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
+            )
+
+        boxes = cast(
+            datapoints.BoundingBox,
+            F.convert_format_bounding_box(
+                boxes,
+                new_format=datapoints.BoundingBoxFormat.XYXY,
+            ),
+        )
+        ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
+        mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
+        # TODO: Do we really need to check for out of bounds here? All
+        # transforms should be clamping anyway, so this should never happen?
+        image_h, image_w = boxes.spatial_size
+        mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
+        mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
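+        # For illustration (hypothetical values), with min_size=1, spatial_size=(10, 10)
+        # and XYXY boxes [[0, 0, 5, 5], [2, 2, 2, 8], [-1, 0, 4, 4], [0, 0, 4, 12]]:
+        #     ws = [5, 0, 5, 4], hs = [5, 6, 4, 12]
+        #     mask = [True, False, False, False]
+        # (zero width, negative coordinate, and Y2 > image_h for the last three, respectively)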
+
+        params = dict(mask=mask, labels=labels)
+        flat_outputs = [
+            # Even though it may look like we're transforming all inputs, we don't:
+            # _transform() will only care about BoundingBoxes and the labels
+            self._transform(inpt, params)
+            for inpt in flat_inputs
+        ]
+
+        return tree_unflatten(flat_outputs, spec)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return inpt.wrap_like(inpt, inpt[params["valid_indices"]])
+
+        if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
+            inpt = inpt[params["mask"]]
+
+        return inpt
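+
+    # Where this fits in practice (sketch; Compose / RandomIoUCrop / RandomHorizontalFlip
+    # are assumed to be the transforms from this namespace, the sample dict is hypothetical):
+    #
+    #     t = Compose([RandomIoUCrop(), RandomHorizontalFlip(p=0.5), SanitizeBoundingBoxes()])
+    #     out = t({"image": image, "boxes": boxes, "labels": labels})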