From ae2c1145f3760626bf7b9d18ea1e9573f65e15d1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 8 Mar 2024 15:57:48 +0000 Subject: [PATCH 1/8] WIP --- test/test_transforms_v2.py | 60 ++++++++++++----- torchvision/transforms/v2/_misc.py | 22 ++----- .../transforms/v2/functional/__init__.py | 1 + torchvision/transforms/v2/functional/_misc.py | 66 ++++++++++++++++++- 4 files changed, 115 insertions(+), 34 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 0fb3ee6c11f..a502afec56e 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5659,18 +5659,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): class TestSanitizeBoundingBoxes: - @pytest.mark.parametrize("min_size", (1, 10)) - @pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None)) - @pytest.mark.parametrize("sample_type", (tuple, dict)) - def test_transform(self, min_size, labels_getter, sample_type): - - if sample_type is tuple and not isinstance(labels_getter, str): - # The "lambda inputs: inputs["labels"]" labels_getter used in this test - # doesn't work if the input is a tuple. - return - - H, W = 256, 128 - + def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10): boxes_and_validity = [ ([0, 1, 10, 1], False), # Y1 == Y2 ([0, 1, 0, 20], False), # X1 == X2 @@ -5690,11 +5679,7 @@ def test_transform(self, min_size, labels_getter, sample_type): ] random.shuffle(boxes_and_validity) # For test robustness: mix order of wrong and correct cases - boxes, is_valid_mask = zip(*boxes_and_validity) - valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid] - - boxes = torch.tensor(boxes) - labels = torch.arange(boxes.shape[0]) + boxes, expected_valid_mask = zip(*boxes_and_validity) boxes = tv_tensors.BoundingBoxes( boxes, @@ -5702,6 +5687,23 @@ def test_transform(self, min_size, labels_getter, sample_type): canvas_size=(H, W), ) + return boxes, expected_valid_mask + + @pytest.mark.parametrize("min_size", (1, 10)) + @pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None)) + @pytest.mark.parametrize("sample_type", (tuple, dict)) + def test_transform(self, min_size, labels_getter, sample_type): + + if sample_type is tuple and not isinstance(labels_getter, str): + # The "lambda inputs: inputs["labels"]" labels_getter used in this test + # doesn't work if the input is a tuple. + return + + H, W = 256, 128 + boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size) + valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid] + + labels = torch.arange(boxes.shape[0]) masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W))) whatever = torch.rand(10) input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8) @@ -5747,6 +5749,30 @@ def test_transform(self, min_size, labels_getter, sample_type): # This works because we conveniently set labels to arange(num_boxes) assert out_labels.tolist() == valid_indices + # @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes)) + @pytest.mark.parametrize("input_type", (torch.Tensor, ))#tv_tensors.BoundingBoxes)) + def test_functional(self, input_type): + + H, W, min_size = 256, 128, 10 + + boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size) + + if input_type is tv_tensors.BoundingBoxes: + format = canvas_size = None + else: + format, canvas_size = boxes.format, boxes.canvas_size + boxes = boxes.as_subclass(torch.Tensor) + + # boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size) + assert type(boxes) == torch.Tensor + f = torch.jit.script(F.sanitize_bounding_boxes) + boxes, valid = f(boxes, format=format, canvas_size=canvas_size, min_size=min_size) + + assert_equal(valid, torch.tensor(expected_valid_mask)) + assert type(valid) == torch.Tensor + assert boxes.shape[0] == sum(valid) + assert isinstance(boxes, input_type) + def test_no_label(self): # Non-regression test for https://github.com/pytorch/vision/issues/7878 diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 6057e928115..4ed73f067c5 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -367,22 +367,14 @@ def forward(self, *inputs: Any) -> Any: f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." ) - boxes = cast( - tv_tensors.BoundingBoxes, - F.convert_bounding_box_format( - boxes, - new_format=tv_tensors.BoundingBoxFormat.XYXY, - ), + # TODO: or use boxes, valid = F.sanitize_bouding_boxes(...) and add both to the params dict??? + valid = F._misc._get_sanitize_bounding_boxes_mask( + boxes, + format=boxes.format, + canvas_size=boxes.canvas_size, + min_size=self.min_size, ) - ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] - valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1) - # TODO: Do we really need to check for out of bounds here? All - # transforms should be clamping anyway, so this should never happen? - image_h, image_w = boxes.canvas_size - valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) - valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) - - params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels) + params = dict(valid=valid, labels=labels) flat_outputs = [ # Even-though it may look like we're transforming all inputs, we don't: # _transform() will only care about BoundingBoxeses and the labels diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 81d5c1b9baf..8f71a7463a7 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -167,6 +167,7 @@ normalize, normalize_image, normalize_video, + sanitize_bounding_boxes, to_dtype, to_dtype_image, to_dtype_video, diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 6117aa33ea4..85705f2d1c6 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import List, Optional +from typing import List, Optional, Tuple import PIL.Image import torch @@ -11,7 +11,9 @@ from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal +from ._meta import _convert_bounding_box_format + +from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor def normalize( @@ -275,3 +277,63 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor: # We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type return inpt.to(dtype) + + +def sanitize_bounding_boxes( + bounding_boxes: torch.Tensor, + format: Optional[tv_tensors.BoundingBoxFormat] = None, + canvas_size: Optional[Tuple[int, int]] = None, + min_size: float = 1.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + # if torch.jit.is_scripting(): + # if format is None or canvas_size is None: + # raise ValueError( + # f"format and canvas_size cannot be None in scripting mode. Got {format=} and {canvas_size=}." + # ) + # return _sanitize_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size) + + if torch.jit.is_scripting() or is_pure_tensor(bounding_boxes): + if format is None or canvas_size is None: + raise ValueError( + "format and canvas_size cannot be None if bounding_boxes is a pure tensor. " + # f"Got {format=} and {canvas_size=}." + # "Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object." + ) + valid = _get_sanitize_bounding_boxes_mask(bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size) + bounding_boxes = bounding_boxes[valid] + else: + if not isinstance(bounding_boxes, tv_tensors.BoundingBoxes): + raise ValueError("") + if format is not None or canvas_size is not None: + raise ValueError( + "format and canvas_size must be None when bounding_boxes is a tv_tensors.BoundingBoxes instance. " + # f"Got {format=} and {canvas_size=}. " + # "Leave those to None or pass bouding_boxes as a pure tensor." + ) + valid = _get_sanitize_bounding_boxes_mask(bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size) + bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes) + + return bounding_boxes, valid + + +def _get_sanitize_bounding_boxes_mask( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + canvas_size: Tuple[int, int], + min_size: float = 1.0, +) -> torch.Tensor: + + bounding_boxes = _convert_bounding_box_format( + bounding_boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY, old_format=format + ) + + image_h, image_w = canvas_size + ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1] + valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1) + # TODO: Do we really need to check for out of bounds here? All + # transforms should be clamping anyway, so this should never happen? + image_h, image_w = canvas_size + valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w) + valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h) + #valid = valid.as_subclass(torch.Tensor) # TODO: remove this and see? + return valid From 3a3fbf757e668b260fb7a4a60e0fc5c3b9a3605f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 12 Mar 2024 11:30:46 +0000 Subject: [PATCH 2/8] More stuff --- docs/source/transforms.rst | 1 + test/test_transforms_v2.py | 43 +++++++++++--- torchvision/transforms/v2/_misc.py | 6 +- torchvision/transforms/v2/functional/_misc.py | 56 ++++++++++++++----- 4 files changed, 81 insertions(+), 25 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 056d1589e84..5c21897cdf8 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -414,6 +414,7 @@ Functionals v2.functional.normalize v2.functional.erase + v2.functional.sanitize_bounding_boxes v2.functional.clamp_bounding_boxes v2.functional.uniform_temporal_subsample diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index a502afec56e..9e2e5cc6d55 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5749,9 +5749,10 @@ def test_transform(self, min_size, labels_getter, sample_type): # This works because we conveniently set labels to arange(num_boxes) assert out_labels.tolist() == valid_indices - # @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes)) - @pytest.mark.parametrize("input_type", (torch.Tensor, ))#tv_tensors.BoundingBoxes)) + @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes)) def test_functional(self, input_type): + # Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some + # redundancy with test_transform() in terms of correctness checks. But that's OK. H, W, min_size = 256, 128, 10 @@ -5760,19 +5761,32 @@ def test_functional(self, input_type): if input_type is tv_tensors.BoundingBoxes: format = canvas_size = None else: - format, canvas_size = boxes.format, boxes.canvas_size + # just passing "XYXY" explicitly to make sure we support strings + format, canvas_size = "XYXY", boxes.canvas_size boxes = boxes.as_subclass(torch.Tensor) - # boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size) - assert type(boxes) == torch.Tensor - f = torch.jit.script(F.sanitize_bounding_boxes) - boxes, valid = f(boxes, format=format, canvas_size=canvas_size, min_size=min_size) + boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size) assert_equal(valid, torch.tensor(expected_valid_mask)) assert type(valid) == torch.Tensor assert boxes.shape[0] == sum(valid) assert isinstance(boxes, input_type) + def test_kernel(self): + H, W, min_size = 256, 128, 10 + boxes, _ = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size) + + format, canvas_size = boxes.format, boxes.canvas_size + boxes = boxes.as_subclass(torch.Tensor) + + check_kernel( + F.sanitize_bounding_boxes, + input=boxes, + format=format, + canvas_size=canvas_size, + check_batched_vs_unbatched=False, + ) + def test_no_label(self): # Non-regression test for https://github.com/pytorch/vision/issues/7878 @@ -5809,3 +5823,18 @@ def test_errors(self): with pytest.raises(ValueError, match="Number of boxes"): different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)} transforms.SanitizeBoundingBoxes()(different_sizes) + + with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"): + F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format="XYXY", canvas_size=None) + + with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"): + F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format=None, canvas_size=(10, 10)) + + with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"): + F.sanitize_bounding_boxes(good_bbox, format="XYXY", canvas_size=None) + + with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"): + F.sanitize_bounding_boxes(good_bbox, format="XYXY", canvas_size=None) + + with pytest.raises(ValueError, match="bouding_boxes must be a tv_tensors.BoundingBoxes instance or a"): + F.sanitize_bounding_boxes(good_bbox.tolist()) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 4ed73f067c5..d59ef03154a 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union import PIL.Image @@ -367,7 +367,7 @@ def forward(self, *inputs: Any) -> Any: f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." ) - # TODO: or use boxes, valid = F.sanitize_bouding_boxes(...) and add both to the params dict??? + # Alternatively we could use `boxes, valid = F.sanitize_bouding_boxes(...)` and pass both to the params dict? valid = F._misc._get_sanitize_bounding_boxes_mask( boxes, format=boxes.format, @@ -377,7 +377,7 @@ def forward(self, *inputs: Any) -> Any: params = dict(valid=valid, labels=labels) flat_outputs = [ # Even-though it may look like we're transforming all inputs, we don't: - # _transform() will only care about BoundingBoxeses and the labels + # _transform() will only care about BoundingBoxes and the labels self._transform(inpt, params) for inpt in flat_inputs ] diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 85705f2d1c6..c4e6cd6b9df 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -285,32 +285,59 @@ def sanitize_bounding_boxes( canvas_size: Optional[Tuple[int, int]] = None, min_size: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: - # if torch.jit.is_scripting(): - # if format is None or canvas_size is None: - # raise ValueError( - # f"format and canvas_size cannot be None in scripting mode. Got {format=} and {canvas_size=}." - # ) - # return _sanitize_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size) - + """Remove degenerate/invalid bounding boxes and return the corresponding indexing mask. + + This removes bounding boxes that: + + - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1. + - have any coordinate outside of their corresponding image. You may want to + call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals. + + It is recommended to call it at the end of a pipeline, before passing the + input to the models. It is critical to call this transform if + :class:`~torchvision.transforms.v2.RandomIoUCrop` was called. + If you want to be extra careful, you may call it after all transforms that + may modify bounding boxes but once at the end should be enough in most + cases. + + Args: + bounding_boxes (Tensor or :class:`~torchvision.tv_tensors.BoundingBoxes`): The bounding boxes to be sanitized. + format (str or :class:`~torchvision.tv_tensors.BoundingBoxFormat`, optional): The format of the bounding boxes. + Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object. + canvas_size (tuple of int, optional): The canvas_size of the bounding boxes + (size of the corresponding image/video). + Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object. + min_size (float, optional) The size below which bounding boxes are removed. Default is 1. + + Returns: + out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask. + The mask can then be used to subset other tensors (e.g. labels) that are associated with the bounding boxes. + """ if torch.jit.is_scripting() or is_pure_tensor(bounding_boxes): if format is None or canvas_size is None: raise ValueError( "format and canvas_size cannot be None if bounding_boxes is a pure tensor. " - # f"Got {format=} and {canvas_size=}." - # "Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object." + f"Got format={format} and canvas_size={canvas_size}." + "Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object." ) - valid = _get_sanitize_bounding_boxes_mask(bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size) + if isinstance(format, str): + format = tv_tensors.BoundingBoxFormat[format.upper()] + valid = _get_sanitize_bounding_boxes_mask( + bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size + ) bounding_boxes = bounding_boxes[valid] else: if not isinstance(bounding_boxes, tv_tensors.BoundingBoxes): - raise ValueError("") + raise ValueError("bouding_boxes must be a tv_tensors.BoundingBoxes instance or a pure tensor.") if format is not None or canvas_size is not None: raise ValueError( "format and canvas_size must be None when bounding_boxes is a tv_tensors.BoundingBoxes instance. " - # f"Got {format=} and {canvas_size=}. " - # "Leave those to None or pass bouding_boxes as a pure tensor." + f"Got format={format} and canvas_size={canvas_size}. " + "Leave those to None or pass bouding_boxes as a pure tensor." ) - valid = _get_sanitize_bounding_boxes_mask(bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size) + valid = _get_sanitize_bounding_boxes_mask( + bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size + ) bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes) return bounding_boxes, valid @@ -335,5 +362,4 @@ def _get_sanitize_bounding_boxes_mask( image_h, image_w = canvas_size valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w) valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h) - #valid = valid.as_subclass(torch.Tensor) # TODO: remove this and see? return valid From 5e0fa5f5af11882b58bbb737e0c0191b8ebc1a6c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 15 Mar 2024 11:49:21 +0000 Subject: [PATCH 3/8] Address comments --- test/test_transforms_v2.py | 10 +++++++++- torchvision/prototype/transforms/_augment.py | 2 +- torchvision/transforms/v2/_misc.py | 8 +------- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 9e2e5cc6d55..8a92752d739 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5800,7 +5800,7 @@ def test_no_label(self): assert isinstance(out_img, tv_tensors.Image) assert isinstance(out_boxes, tv_tensors.BoundingBoxes) - def test_errors(self): + def test_errors_transform(self): good_bbox = tv_tensors.BoundingBoxes( [[0, 0, 10, 10]], format=tv_tensors.BoundingBoxFormat.XYXY, @@ -5824,6 +5824,14 @@ def test_errors(self): different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)} transforms.SanitizeBoundingBoxes()(different_sizes) + def test_errors_functional(self): + + good_bbox = tv_tensors.BoundingBoxes( + [[0, 0, 10, 10]], + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=(20, 20), + ) + with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"): F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format="XYXY", canvas_size=None) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index f7e5a6be2dd..a592371171d 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -123,7 +123,7 @@ def _extract_image_targets( if not (len(images) == len(bboxes) == len(masks) == len(labels)): raise TypeError( f"{type(self).__name__}() requires input sample to contain equal sized list of Images, " - "BoundingBoxeses, Masks and Labels or OneHotLabels." + "BoundingBoxes, Masks and Labels or OneHotLabels." ) targets = [] diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index d59ef03154a..ebe11cf9534 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -367,7 +367,6 @@ def forward(self, *inputs: Any) -> Any: f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." ) - # Alternatively we could use `boxes, valid = F.sanitize_bouding_boxes(...)` and pass both to the params dict? valid = F._misc._get_sanitize_bounding_boxes_mask( boxes, format=boxes.format, @@ -375,12 +374,7 @@ def forward(self, *inputs: Any) -> Any: min_size=self.min_size, ) params = dict(valid=valid, labels=labels) - flat_outputs = [ - # Even-though it may look like we're transforming all inputs, we don't: - # _transform() will only care about BoundingBoxes and the labels - self._transform(inpt, params) - for inpt in flat_inputs - ] + flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs] return tree_unflatten(flat_outputs, spec) From 26ac1b08d2509f1f60a50f8bb5cdf625b13edee7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 15 Mar 2024 13:09:08 +0000 Subject: [PATCH 4/8] Allow SanitizeBoundingBoxes to sanitize more labels --- test/test_transforms_v2.py | 25 ++++++++++++++-- torchvision/transforms/v2/_misc.py | 48 +++++++++++++++++++++--------- 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 8a92752d739..c45d12abeab 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5690,7 +5690,16 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10): return boxes, expected_valid_mask @pytest.mark.parametrize("min_size", (1, 10)) - @pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None)) + @pytest.mark.parametrize( + "labels_getter", + ( + "default", + lambda inputs: inputs["labels"], + lambda inputs: (inputs["labels"], inputs["other_labels"]), + None, + lambda inputs: None, + ), + ) @pytest.mark.parametrize("sample_type", (tuple, dict)) def test_transform(self, min_size, labels_getter, sample_type): @@ -5705,12 +5714,16 @@ def test_transform(self, min_size, labels_getter, sample_type): labels = torch.arange(boxes.shape[0]) masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W))) + # other_labels corresponds to properties from COCO like iscrowd, area... + # We only sanitize it when labels_getter returns a tuple + other_labels = torch.arange(boxes.shape[0]) whatever = torch.rand(10) input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8) sample = { "image": input_img, "labels": labels, "boxes": boxes, + "other_labels": other_labels, "whatever": whatever, "None": None, "masks": masks, @@ -5725,12 +5738,14 @@ def test_transform(self, min_size, labels_getter, sample_type): if sample_type is tuple: out_image = out[0] out_labels = out[1]["labels"] + out_other_labels = out[1]["other_labels"] out_boxes = out[1]["boxes"] out_masks = out[1]["masks"] out_whatever = out[1]["whatever"] else: out_image = out["image"] out_labels = out["labels"] + out_other_labels = out["other_labels"] out_boxes = out["boxes"] out_masks = out["masks"] out_whatever = out["whatever"] @@ -5741,14 +5756,20 @@ def test_transform(self, min_size, labels_getter, sample_type): assert isinstance(out_boxes, tv_tensors.BoundingBoxes) assert isinstance(out_masks, tv_tensors.Mask) - if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None): + if labels_getter is None or (callable(labels_getter) and labels_getter(sample) is None): assert out_labels is labels + assert out_other_labels is other_labels else: assert isinstance(out_labels, torch.Tensor) assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0] # This works because we conveniently set labels to arange(num_boxes) assert out_labels.tolist() == valid_indices + if callable(labels_getter) and type(labels_getter(sample)) is tuple: + assert_equal(out_other_labels, out_labels) + else: + assert_equal(out_other_labels, other_labels) + @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes)) def test_functional(self, input_type): # Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index ebe11cf9534..b8bbed38908 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -319,6 +319,9 @@ class SanitizeBoundingBoxes(Transform): - have any coordinate outside of their corresponding image. You may want to call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals. + It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO + (see ``labels_getter`` parameter). + It is recommended to call it at the end of a pipeline, before passing the input to the models. It is critical to call this transform if :class:`~torchvision.transforms.v2.RandomIoUCrop` was called. @@ -328,12 +331,18 @@ class SanitizeBoundingBoxes(Transform): Args: min_size (float, optional) The size below which bounding boxes are removed. Default is 1. - labels_getter (callable or str or None, optional): indicates how to identify the labels in the input. + labels_getter (callable or str or None, optional): indicates how to identify the labels in the input + (or anything else that needs to be sanitized along with the bounding boxes). By default, this will try to find a "labels" key in the input (case-insensitive), if the input is a dict or it is a tuple whose second element is a dict. This heuristic should work well with a lot of datasets, including the built-in torchvision datasets. - It can also be a callable that takes the same input - as the transform, and returns the labels. + + It can also be a callable that takes the same input as the transform, and returns either: + + - A single tensor (the labels) + - A tuple/list of tensors, each of which will be subject to the same sanitization as the bounding boxes. + This is useful to sanitize multiple tensors like the labels, and the "iscrowd" or "area" properties + from COCO. """ def __init__( @@ -354,18 +363,29 @@ def forward(self, *inputs: Any) -> Any: inputs = inputs if len(inputs) > 1 else inputs[0] labels = self._labels_getter(inputs) - if labels is not None and not isinstance(labels, torch.Tensor): - raise ValueError( - f"The labels in the input to forward() must be a tensor or None, got {type(labels)} instead." - ) + if labels is not None: + msg = "The labels in the input to forward() must be a tensor or None, got {type} instead." + if isinstance(labels, torch.Tensor): + labels = (labels,) + elif isinstance(labels, (tuple, list)): + labels = tuple(labels) + for entry in labels: + if not isinstance(entry, torch.Tensor): + # TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask] + raise ValueError(msg.format(type=type(entry))) + else: + raise ValueError(msg.format(type=type(labels))) flat_inputs, spec = tree_flatten(inputs) boxes = get_bounding_boxes(flat_inputs) - if labels is not None and boxes.shape[0] != labels.shape[0]: - raise ValueError( - f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." - ) + if labels is not None: + for label in labels: + if boxes.shape[0] != label.shape[0]: + raise ValueError( + f"Number of boxes (shape={boxes.shape}) and must match the number of labels." + f"Found labels with shape={label.shape})." + ) valid = F._misc._get_sanitize_bounding_boxes_mask( boxes, @@ -379,7 +399,7 @@ def forward(self, *inputs: Any) -> Any: return tree_unflatten(flat_outputs, spec) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - is_label = inpt is not None and inpt is params["labels"] + is_label = params["labels"] is not None and any(inpt is label for label in params["labels"]) is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)) if not (is_label or is_bounding_boxes_or_mask): @@ -389,5 +409,5 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if is_label: return output - - return tv_tensors.wrap(output, like=inpt) + else: + return tv_tensors.wrap(output, like=inpt) From 41a35a4cafbecf24503323dcb31817f1366d2c70 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 15 Mar 2024 13:45:37 +0000 Subject: [PATCH 5/8] mypy? --- torchvision/transforms/v2/_misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 00bad2be253..798d515da59 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, Tuple import PIL.Image @@ -350,7 +350,7 @@ class SanitizeBoundingBoxes(Transform): def __init__( self, min_size: float = 1.0, - labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default", + labels_getter: Union[Callable[[Any], Optional[Union[torch.Tensor, Tuple[torch.tensor]]]], str, None] = "default", ) -> None: super().__init__() From 0f2e09ea1de35c23d5c68b4689dc3cfabc74e09e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 15 Mar 2024 13:53:37 +0000 Subject: [PATCH 6/8] Address comments --- test/test_transforms_v2.py | 3 ++- torchvision/transforms/v2/_misc.py | 9 ++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index da4ba3141d5..49855400e85 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5712,6 +5712,7 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10): "default", lambda inputs: inputs["labels"], lambda inputs: (inputs["labels"], inputs["other_labels"]), + lambda inputs: [inputs["labels"], inputs["other_labels"]], None, lambda inputs: None, ), @@ -5781,7 +5782,7 @@ def test_transform(self, min_size, labels_getter, sample_type): # This works because we conveniently set labels to arange(num_boxes) assert out_labels.tolist() == valid_indices - if callable(labels_getter) and type(labels_getter(sample)) is tuple: + if callable(labels_getter) and isinstance(labels_getter(sample), (tuple, list)): assert_equal(out_other_labels, out_labels) else: assert_equal(out_other_labels, other_labels) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 798d515da59..5034825eb32 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import PIL.Image @@ -345,12 +345,16 @@ class SanitizeBoundingBoxes(Transform): - A tuple/list of tensors, each of which will be subject to the same sanitization as the bounding boxes. This is useful to sanitize multiple tensors like the labels, and the "iscrowd" or "area" properties from COCO. + + If ``labels_getter`` is None then only bounding boxes are sanitized. """ def __init__( self, min_size: float = 1.0, - labels_getter: Union[Callable[[Any], Optional[Union[torch.Tensor, Tuple[torch.tensor]]]], str, None] = "default", + labels_getter: Union[ + Callable[[Any], Optional[Union[torch.Tensor, Tuple[torch.tensor]]]], str, None + ] = "default", ) -> None: super().__init__() @@ -370,7 +374,6 @@ def forward(self, *inputs: Any) -> Any: if isinstance(labels, torch.Tensor): labels = (labels,) elif isinstance(labels, (tuple, list)): - labels = tuple(labels) for entry in labels: if not isinstance(entry, torch.Tensor): # TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask] From f96dbeae1f6b3229cc4bcac739e34eb9ef3a0cf4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 15 Mar 2024 14:08:19 +0000 Subject: [PATCH 7/8] mypy --- torchvision/transforms/v2/_misc.py | 6 ++---- torchvision/transforms/v2/_utils.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 5034825eb32..ad2c08150cc 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union import PIL.Image @@ -352,9 +352,7 @@ class SanitizeBoundingBoxes(Transform): def __init__( self, min_size: float = 1.0, - labels_getter: Union[ - Callable[[Any], Optional[Union[torch.Tensor, Tuple[torch.tensor]]]], str, None - ] = "default", + labels_getter: Union[Callable[[Any], Any], str, None] = "default", ) -> None: super().__init__() diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 6147180a986..6b69a398ccf 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -4,7 +4,7 @@ import numbers from contextlib import suppress -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Literal, Sequence, Tuple, Type, Union import PIL.Image import torch @@ -140,8 +140,8 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor: def _parse_labels_getter( - labels_getter: Union[str, Callable[[Any], Optional[torch.Tensor]], None] -) -> Callable[[Any], Optional[torch.Tensor]]: + labels_getter: Union[str, Callable[[Any], Any], None] +) -> Callable[[Any], Any]: if labels_getter == "default": return _find_labels_default_heuristic elif callable(labels_getter): From c9bd6c87d10a9b3eb9f141460a710b71c9b948f1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 15 Mar 2024 14:15:37 +0000 Subject: [PATCH 8/8] lint --- torchvision/transforms/v2/_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 6b69a398ccf..e7cde4c5c33 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -139,9 +139,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor: return inputs[candidate_key] -def _parse_labels_getter( - labels_getter: Union[str, Callable[[Any], Any], None] -) -> Callable[[Any], Any]: +def _parse_labels_getter(labels_getter: Union[str, Callable[[Any], Any], None]) -> Callable[[Any], Any]: if labels_getter == "default": return _find_labels_default_heuristic elif callable(labels_getter):