From ae2c1145f3760626bf7b9d18ea1e9573f65e15d1 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Fri, 8 Mar 2024 15:57:48 +0000
Subject: [PATCH 1/3] WIP

---
 test/test_transforms_v2.py                    | 60 ++++++++++++-----
 torchvision/transforms/v2/_misc.py            | 22 ++-----
 .../transforms/v2/functional/__init__.py      |  1 +
 torchvision/transforms/v2/functional/_misc.py | 66 ++++++++++++++++++-
 4 files changed, 115 insertions(+), 34 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 0fb3ee6c11f..a502afec56e 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -5659,18 +5659,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):


 class TestSanitizeBoundingBoxes:
-    @pytest.mark.parametrize("min_size", (1, 10))
-    @pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
-    @pytest.mark.parametrize("sample_type", (tuple, dict))
-    def test_transform(self, min_size, labels_getter, sample_type):
-
-        if sample_type is tuple and not isinstance(labels_getter, str):
-            # The "lambda inputs: inputs["labels"]" labels_getter used in this test
-            # doesn't work if the input is a tuple.
-            return
-
-        H, W = 256, 128
-
+    def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
         boxes_and_validity = [
             ([0, 1, 10, 1], False),  # Y1 == Y2
             ([0, 1, 0, 20], False),  # X1 == X2
@@ -5690,11 +5679,7 @@ def test_transform(self, min_size, labels_getter, sample_type):
         ]

         random.shuffle(boxes_and_validity)  # For test robustness: mix order of wrong and correct cases
-        boxes, is_valid_mask = zip(*boxes_and_validity)
-        valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]
-
-        boxes = torch.tensor(boxes)
-        labels = torch.arange(boxes.shape[0])
+        boxes, expected_valid_mask = zip(*boxes_and_validity)

         boxes = tv_tensors.BoundingBoxes(
             boxes,
@@ -5702,6 +5687,23 @@ def test_transform(self, min_size, labels_getter, sample_type):
             canvas_size=(H, W),
         )

+        return boxes, expected_valid_mask
+
+    @pytest.mark.parametrize("min_size", (1, 10))
+    @pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
+    @pytest.mark.parametrize("sample_type", (tuple, dict))
+    def test_transform(self, min_size, labels_getter, sample_type):
+
+        if sample_type is tuple and not isinstance(labels_getter, str):
+            # The "lambda inputs: inputs["labels"]" labels_getter used in this test
+            # doesn't work if the input is a tuple.
+            return
+
+        H, W = 256, 128
+        boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
+        valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]
+
+        labels = torch.arange(boxes.shape[0])
         masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
         whatever = torch.rand(10)
         input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
@@ -5747,6 +5749,30 @@ def test_transform(self, min_size, labels_getter, sample_type):
         # This works because we conveniently set labels to arange(num_boxes)
         assert out_labels.tolist() == valid_indices

+    # @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes))
+    @pytest.mark.parametrize("input_type", (torch.Tensor, ))#tv_tensors.BoundingBoxes))
+    def test_functional(self, input_type):
+
+        H, W, min_size = 256, 128, 10
+
+        boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
+
+        if input_type is tv_tensors.BoundingBoxes:
+            format = canvas_size = None
+        else:
+            format, canvas_size = boxes.format, boxes.canvas_size
+            boxes = boxes.as_subclass(torch.Tensor)
+
+        # boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size)
+        assert type(boxes) == torch.Tensor
+        f = torch.jit.script(F.sanitize_bounding_boxes)
+        boxes, valid = f(boxes, format=format, canvas_size=canvas_size, min_size=min_size)
+
+        assert_equal(valid, torch.tensor(expected_valid_mask))
+        assert type(valid) == torch.Tensor
+        assert boxes.shape[0] == sum(valid)
+        assert isinstance(boxes, input_type)
+
     def test_no_label(self):
         # Non-regression test for https://github.com/pytorch/vision/issues/7878

diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index 6057e928115..4ed73f067c5 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -367,22 +367,14 @@ def forward(self, *inputs: Any) -> Any:
                 f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
             )

-        boxes = cast(
-            tv_tensors.BoundingBoxes,
-            F.convert_bounding_box_format(
-                boxes,
-                new_format=tv_tensors.BoundingBoxFormat.XYXY,
-            ),
+        # TODO: or use boxes, valid = F.sanitize_bouding_boxes(...) and add both to the params dict???
+        valid = F._misc._get_sanitize_bounding_boxes_mask(
+            boxes,
+            format=boxes.format,
+            canvas_size=boxes.canvas_size,
+            min_size=self.min_size,
         )
-        ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
-        valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
-        # TODO: Do we really need to check for out of bounds here? All
-        # transforms should be clamping anyway, so this should never happen?
-        image_h, image_w = boxes.canvas_size
-        valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
-        valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
-
-        params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels)
+        params = dict(valid=valid, labels=labels)
         flat_outputs = [
             # Even-though it may look like we're transforming all inputs, we don't:
             # _transform() will only care about BoundingBoxeses and the labels
diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py
index 81d5c1b9baf..8f71a7463a7 100644
--- a/torchvision/transforms/v2/functional/__init__.py
+++ b/torchvision/transforms/v2/functional/__init__.py
@@ -167,6 +167,7 @@
     normalize,
     normalize_image,
     normalize_video,
+    sanitize_bounding_boxes,
     to_dtype,
     to_dtype_image,
     to_dtype_video,
diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py
index 6117aa33ea4..85705f2d1c6 100644
--- a/torchvision/transforms/v2/functional/_misc.py
+++ b/torchvision/transforms/v2/functional/_misc.py
@@ -1,5 +1,5 @@
 import math
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import PIL.Image
 import torch
@@ -11,7 +11,9 @@

 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_kernel_internal
+from ._meta import _convert_bounding_box_format
+
+from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor


 def normalize(
@@ -275,3 +277,63 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale:
 def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
     # We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type
     return inpt.to(dtype)
+
+
+def sanitize_bounding_boxes(
+    bounding_boxes: torch.Tensor,
+    format: Optional[tv_tensors.BoundingBoxFormat] = None,
+    canvas_size: Optional[Tuple[int, int]] = None,
+    min_size: float = 1.0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # if torch.jit.is_scripting():
+    #     if format is None or canvas_size is None:
+    #         raise ValueError(
+    #             f"format and canvas_size cannot be None in scripting mode. Got {format=} and {canvas_size=}."
+    #         )
+    #     return _sanitize_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
+
+    if torch.jit.is_scripting() or is_pure_tensor(bounding_boxes):
+        if format is None or canvas_size is None:
+            raise ValueError(
+                "format and canvas_size cannot be None if bounding_boxes is a pure tensor. "
+                # f"Got {format=} and {canvas_size=}."
+                # "Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object."
+            )
+        valid = _get_sanitize_bounding_boxes_mask(bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size)
+        bounding_boxes = bounding_boxes[valid]
+    else:
+        if not isinstance(bounding_boxes, tv_tensors.BoundingBoxes):
+            raise ValueError("")
+        if format is not None or canvas_size is not None:
+            raise ValueError(
+                "format and canvas_size must be None when bounding_boxes is a tv_tensors.BoundingBoxes instance. "
+                # f"Got {format=} and {canvas_size=}. "
+                # "Leave those to None or pass bouding_boxes as a pure tensor."
+ ) + valid = _get_sanitize_bounding_boxes_mask(bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size) + bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes) + + return bounding_boxes, valid + + +def _get_sanitize_bounding_boxes_mask( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + canvas_size: Tuple[int, int], + min_size: float = 1.0, +) -> torch.Tensor: + + bounding_boxes = _convert_bounding_box_format( + bounding_boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY, old_format=format + ) + + image_h, image_w = canvas_size + ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1] + valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1) + # TODO: Do we really need to check for out of bounds here? All + # transforms should be clamping anyway, so this should never happen? + image_h, image_w = canvas_size + valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w) + valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h) + #valid = valid.as_subclass(torch.Tensor) # TODO: remove this and see? + return valid From 3a3fbf757e668b260fb7a4a60e0fc5c3b9a3605f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 12 Mar 2024 11:30:46 +0000 Subject: [PATCH 2/3] More stuff --- docs/source/transforms.rst | 1 + test/test_transforms_v2.py | 43 +++++++++++--- torchvision/transforms/v2/_misc.py | 6 +- torchvision/transforms/v2/functional/_misc.py | 56 ++++++++++++++----- 4 files changed, 81 insertions(+), 25 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 056d1589e84..5c21897cdf8 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -414,6 +414,7 @@ Functionals v2.functional.normalize v2.functional.erase + v2.functional.sanitize_bounding_boxes v2.functional.clamp_bounding_boxes v2.functional.uniform_temporal_subsample diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index a502afec56e..9e2e5cc6d55 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5749,9 +5749,10 @@ def test_transform(self, min_size, labels_getter, sample_type): # This works because we conveniently set labels to arange(num_boxes) assert out_labels.tolist() == valid_indices - # @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes)) - @pytest.mark.parametrize("input_type", (torch.Tensor, ))#tv_tensors.BoundingBoxes)) + @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes)) def test_functional(self, input_type): + # Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some + # redundancy with test_transform() in terms of correctness checks. But that's OK. 
         H, W, min_size = 256, 128, 10

@@ -5760,19 +5761,32 @@ def test_functional(self, input_type):
         if input_type is tv_tensors.BoundingBoxes:
             format = canvas_size = None
         else:
-            format, canvas_size = boxes.format, boxes.canvas_size
+            # just passing "XYXY" explicitly to make sure we support strings
+            format, canvas_size = "XYXY", boxes.canvas_size
             boxes = boxes.as_subclass(torch.Tensor)

-        # boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size)
-        assert type(boxes) == torch.Tensor
-        f = torch.jit.script(F.sanitize_bounding_boxes)
-        boxes, valid = f(boxes, format=format, canvas_size=canvas_size, min_size=min_size)
+        boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size)

         assert_equal(valid, torch.tensor(expected_valid_mask))
         assert type(valid) == torch.Tensor
         assert boxes.shape[0] == sum(valid)
         assert isinstance(boxes, input_type)

+    def test_kernel(self):
+        H, W, min_size = 256, 128, 10
+        boxes, _ = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
+
+        format, canvas_size = boxes.format, boxes.canvas_size
+        boxes = boxes.as_subclass(torch.Tensor)
+
+        check_kernel(
+            F.sanitize_bounding_boxes,
+            input=boxes,
+            format=format,
+            canvas_size=canvas_size,
+            check_batched_vs_unbatched=False,
+        )
+
     def test_no_label(self):
         # Non-regression test for https://github.com/pytorch/vision/issues/7878

@@ -5809,3 +5823,18 @@ def test_errors(self):
         with pytest.raises(ValueError, match="Number of boxes"):
             different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
             transforms.SanitizeBoundingBoxes()(different_sizes)
+
+        with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
+            F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format="XYXY", canvas_size=None)
+
+        with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
+            F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format=None, canvas_size=(10, 10))
+
+        with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"):
+            F.sanitize_bounding_boxes(good_bbox, format="XYXY", canvas_size=None)
+
+        with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"):
+            F.sanitize_bounding_boxes(good_bbox, format=None, canvas_size=(10, 10))
+
+        with pytest.raises(ValueError, match="bounding_boxes must be a tv_tensors.BoundingBoxes instance or a"):
+            F.sanitize_bounding_boxes(good_bbox.tolist())
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index 4ed73f067c5..d59ef03154a 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

 import PIL.Image

@@ -367,7 +367,7 @@ def forward(self, *inputs: Any) -> Any:
                 f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
             )

-        # TODO: or use boxes, valid = F.sanitize_bouding_boxes(...) and add both to the params dict???
+        # Alternatively we could use `boxes, valid = F.sanitize_bouding_boxes(...)` and pass both to the params dict?
         valid = F._misc._get_sanitize_bounding_boxes_mask(
             boxes,
             format=boxes.format,
@@ -377,7 +377,7 @@ def forward(self, *inputs: Any) -> Any:
         params = dict(valid=valid, labels=labels)
         flat_outputs = [
             # Even-though it may look like we're transforming all inputs, we don't:
-            # _transform() will only care about BoundingBoxeses and the labels
+            # _transform() will only care about BoundingBoxes and the labels
             self._transform(inpt, params)
             for inpt in flat_inputs
         ]
diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py
index 85705f2d1c6..c4e6cd6b9df 100644
--- a/torchvision/transforms/v2/functional/_misc.py
+++ b/torchvision/transforms/v2/functional/_misc.py
@@ -285,32 +285,59 @@ def sanitize_bounding_boxes(
     canvas_size: Optional[Tuple[int, int]] = None,
     min_size: float = 1.0,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    # if torch.jit.is_scripting():
-    #     if format is None or canvas_size is None:
-    #         raise ValueError(
-    #             f"format and canvas_size cannot be None in scripting mode. Got {format=} and {canvas_size=}."
-    #         )
-    #     return _sanitize_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
-
+    """Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.
+
+    This removes bounding boxes that:
+
+    - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
+    - have any coordinate outside of their corresponding image. You may want to
+      call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals.
+
+    It is recommended to call it at the end of a pipeline, before passing the
+    input to the models. It is critical to call this transform if
+    :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
+    If you want to be extra careful, you may call it after all transforms that
+    may modify bounding boxes, but once at the end should be enough in most
+    cases.
+
+    Args:
+        bounding_boxes (Tensor or :class:`~torchvision.tv_tensors.BoundingBoxes`): The bounding boxes to be sanitized.
+        format (str or :class:`~torchvision.tv_tensors.BoundingBoxFormat`, optional): The format of the bounding boxes.
+            Must be left to None if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
+        canvas_size (tuple of int, optional): The canvas_size of the bounding boxes
+            (size of the corresponding image/video).
+            Must be left to None if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
+        min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
+
+    Returns:
+        out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
+        The mask can then be used to subset other tensors (e.g. labels) that are associated with the bounding boxes.
+    """
     if torch.jit.is_scripting() or is_pure_tensor(bounding_boxes):
         if format is None or canvas_size is None:
             raise ValueError(
                 "format and canvas_size cannot be None if bounding_boxes is a pure tensor. "
-                # f"Got {format=} and {canvas_size=}."
-                # "Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object."
+                f"Got format={format} and canvas_size={canvas_size}. "
+                "Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object."
             )
-        valid = _get_sanitize_bounding_boxes_mask(bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size)
+        if isinstance(format, str):
+            format = tv_tensors.BoundingBoxFormat[format.upper()]
+        valid = _get_sanitize_bounding_boxes_mask(
+            bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size
+        )
         bounding_boxes = bounding_boxes[valid]
     else:
         if not isinstance(bounding_boxes, tv_tensors.BoundingBoxes):
-            raise ValueError("")
+            raise ValueError("bounding_boxes must be a tv_tensors.BoundingBoxes instance or a pure tensor.")
         if format is not None or canvas_size is not None:
             raise ValueError(
                 "format and canvas_size must be None when bounding_boxes is a tv_tensors.BoundingBoxes instance. "
-                # f"Got {format=} and {canvas_size=}. "
-                # "Leave those to None or pass bouding_boxes as a pure tensor."
+                f"Got format={format} and canvas_size={canvas_size}. "
+                "Leave those to None or pass bounding_boxes as a pure tensor."
             )
-        valid = _get_sanitize_bounding_boxes_mask(bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size)
+        valid = _get_sanitize_bounding_boxes_mask(
+            bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size
+        )
         bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes)

     return bounding_boxes, valid
@@ -335,5 +362,4 @@ def _get_sanitize_bounding_boxes_mask(
     image_h, image_w = canvas_size
     valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w)
     valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h)
-    #valid = valid.as_subclass(torch.Tensor) # TODO: remove this and see?
     return valid

From 5e0fa5f5af11882b58bbb737e0c0191b8ebc1a6c Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Fri, 15 Mar 2024 11:49:21 +0000
Subject: [PATCH 3/3] Address comments

---
 test/test_transforms_v2.py                   | 10 +++++++++-
 torchvision/prototype/transforms/_augment.py |  2 +-
 torchvision/transforms/v2/_misc.py           |  8 +-------
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 9e2e5cc6d55..8a92752d739 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -5800,7 +5800,7 @@ def test_no_label(self):
         assert isinstance(out_img, tv_tensors.Image)
         assert isinstance(out_boxes, tv_tensors.BoundingBoxes)

-    def test_errors(self):
+    def test_errors_transform(self):
         good_bbox = tv_tensors.BoundingBoxes(
             [[0, 0, 10, 10]],
             format=tv_tensors.BoundingBoxFormat.XYXY,
@@ -5824,6 +5824,14 @@ def test_errors(self):
             different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
             transforms.SanitizeBoundingBoxes()(different_sizes)

+    def test_errors_functional(self):
+
+        good_bbox = tv_tensors.BoundingBoxes(
+            [[0, 0, 10, 10]],
+            format=tv_tensors.BoundingBoxFormat.XYXY,
+            canvas_size=(20, 20),
+        )
+
         with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
             F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format="XYXY", canvas_size=None)

diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index f7e5a6be2dd..a592371171d 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -123,7 +123,7 @@ def _extract_image_targets(
         if not (len(images) == len(bboxes) == len(masks) == len(labels)):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
-                "BoundingBoxeses, Masks and Labels or OneHotLabels."
+                "BoundingBoxes, Masks and Labels or OneHotLabels."
             )

         targets = []
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index d59ef03154a..ebe11cf9534 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -367,7 +367,6 @@ def forward(self, *inputs: Any) -> Any:
                 f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
             )

-        # Alternatively we could use `boxes, valid = F.sanitize_bouding_boxes(...)` and pass both to the params dict?
         valid = F._misc._get_sanitize_bounding_boxes_mask(
             boxes,
             format=boxes.format,
@@ -375,12 +374,7 @@ def forward(self, *inputs: Any) -> Any:
             min_size=self.min_size,
         )
         params = dict(valid=valid, labels=labels)
-        flat_outputs = [
-            # Even-though it may look like we're transforming all inputs, we don't:
-            # _transform() will only care about BoundingBoxes and the labels
-            self._transform(inpt, params)
-            for inpt in flat_inputs
-        ]
+        flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs]

         return tree_unflatten(flat_outputs, spec)
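
For reviewers, a minimal usage sketch of the new functional in both supported
modes, assuming the series is applied as-is. The toy boxes, canvas size, and
labels below are illustrative values, not taken from the patch:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    # tv_tensors.BoundingBoxes input: format and canvas_size must be left to None.
    boxes = tv_tensors.BoundingBoxes(
        [[0, 0, 10, 10], [5, 5, 5, 5]],  # second box is degenerate (zero width/height)
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(20, 20),
    )
    labels = torch.tensor([0, 1])
    boxes, valid = F.sanitize_bounding_boxes(boxes)
    labels = labels[valid]  # the returned mask subsets tensors associated with the boxes

    # Pure tensor input: format and canvas_size must be passed explicitly
    # (format strings like "XYXY" are accepted as of the second patch).
    raw = torch.tensor([[0, 0, 10, 10], [5, 5, 5, 5]])
    raw, valid = F.sanitize_bounding_boxes(raw, format="XYXY", canvas_size=(20, 20))

The SanitizeBoundingBoxes transform computes the same validity mask internally
and additionally subsets the labels found via its labels_getter.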