Skip to content

Add sanitize_bounding_boxes kernel/functional #8308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ Functionals

v2.functional.normalize
v2.functional.erase
v2.functional.sanitize_bounding_boxes
v2.functional.clamp_bounding_boxes
v2.functional.uniform_temporal_subsample

Expand Down
99 changes: 81 additions & 18 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5675,18 +5675,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):


class TestSanitizeBoundingBoxes:
@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_transform(self, min_size, labels_getter, sample_type):

if sample_type is tuple and not isinstance(labels_getter, str):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
# doesn't work if the input is a tuple.
return

H, W = 256, 128

def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
boxes_and_validity = [
([0, 1, 10, 1], False), # Y1 == Y2
([0, 1, 0, 20], False), # X1 == X2
Expand All @@ -5706,18 +5695,31 @@ def test_transform(self, min_size, labels_getter, sample_type):
]

random.shuffle(boxes_and_validity) # For test robustness: mix order of wrong and correct cases
boxes, is_valid_mask = zip(*boxes_and_validity)
valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]

boxes = torch.tensor(boxes)
labels = torch.arange(boxes.shape[0])
boxes, expected_valid_mask = zip(*boxes_and_validity)

boxes = tv_tensors.BoundingBoxes(
boxes,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(H, W),
)

return boxes, expected_valid_mask

@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_transform(self, min_size, labels_getter, sample_type):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That test did not change, I just moved the input generation above in a function so that it can be reused in the newly added tests below


if sample_type is tuple and not isinstance(labels_getter, str):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
# doesn't work if the input is a tuple.
return

H, W = 256, 128
boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]

labels = torch.arange(boxes.shape[0])
masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
whatever = torch.rand(10)
input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
Expand Down Expand Up @@ -5763,6 +5765,44 @@ def test_transform(self, min_size, labels_getter, sample_type):
# This works because we conveniently set labels to arange(num_boxes)
assert out_labels.tolist() == valid_indices

@pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes))
def test_functional(self, input_type):
# Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some
# redundancy with test_transform() in terms of correctness checks. But that's OK.

H, W, min_size = 256, 128, 10

boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)

if input_type is tv_tensors.BoundingBoxes:
format = canvas_size = None
else:
# just passing "XYXY" explicitly to make sure we support strings
format, canvas_size = "XYXY", boxes.canvas_size
boxes = boxes.as_subclass(torch.Tensor)

boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size)

assert_equal(valid, torch.tensor(expected_valid_mask))
assert type(valid) == torch.Tensor
assert boxes.shape[0] == sum(valid)
assert isinstance(boxes, input_type)

def test_kernel(self):
H, W, min_size = 256, 128, 10
boxes, _ = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)

format, canvas_size = boxes.format, boxes.canvas_size
boxes = boxes.as_subclass(torch.Tensor)

check_kernel(
F.sanitize_bounding_boxes,
input=boxes,
format=format,
canvas_size=canvas_size,
check_batched_vs_unbatched=False,
)

def test_no_label(self):
# Non-regression test for https://github.com/pytorch/vision/issues/7878

Expand All @@ -5776,7 +5816,7 @@ def test_no_label(self):
assert isinstance(out_img, tv_tensors.Image)
assert isinstance(out_boxes, tv_tensors.BoundingBoxes)

def test_errors(self):
def test_errors_transform(self):
good_bbox = tv_tensors.BoundingBoxes(
[[0, 0, 10, 10]],
format=tv_tensors.BoundingBoxFormat.XYXY,
Expand All @@ -5799,3 +5839,26 @@ def test_errors(self):
with pytest.raises(ValueError, match="Number of boxes"):
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
transforms.SanitizeBoundingBoxes()(different_sizes)

def test_errors_functional(self):

good_bbox = tv_tensors.BoundingBoxes(
[[0, 0, 10, 10]],
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(20, 20),
)

with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format="XYXY", canvas_size=None)

with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format=None, canvas_size=(10, 10))

with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"):
F.sanitize_bounding_boxes(good_bbox, format="XYXY", canvas_size=None)

with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"):
F.sanitize_bounding_boxes(good_bbox, format="XYXY", canvas_size=None)

with pytest.raises(ValueError, match="bouding_boxes must be a tv_tensors.BoundingBoxes instance or a"):
F.sanitize_bounding_boxes(good_bbox.tolist())
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _extract_image_targets(
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
"BoundingBoxeses, Masks and Labels or OneHotLabels."
"BoundingBoxes, Masks and Labels or OneHotLabels."
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Driveby

)

targets = []
Expand Down
30 changes: 8 additions & 22 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import PIL.Image

Expand Down Expand Up @@ -369,28 +369,14 @@ def forward(self, *inputs: Any) -> Any:
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
)

boxes = cast(
tv_tensors.BoundingBoxes,
F.convert_bounding_box_format(
boxes,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
),
valid = F._misc._get_sanitize_bounding_boxes_mask(
boxes,
format=boxes.format,
canvas_size=boxes.canvas_size,
min_size=self.min_size,
)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = boxes.canvas_size
valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)

params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels)
flat_outputs = [
# Even-though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxeses and the labels
self._transform(inpt, params)
for inpt in flat_inputs
]
params = dict(valid=valid, labels=labels)
flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs]

return tree_unflatten(flat_outputs, spec)

Expand Down
1 change: 1 addition & 0 deletions torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
normalize,
normalize_image,
normalize_video,
sanitize_bounding_boxes,
to_dtype,
to_dtype_image,
to_dtype_video,
Expand Down
92 changes: 90 additions & 2 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import List, Optional
from typing import List, Optional, Tuple

import PIL.Image
import torch
Expand All @@ -11,7 +11,9 @@

from torchvision.utils import _log_api_usage_once

from ._utils import _get_kernel, _register_kernel_internal
from ._meta import _convert_bounding_box_format

from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor


def normalize(
Expand Down Expand Up @@ -275,3 +277,89 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale:
def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
# We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type
return inpt.to(dtype)


def sanitize_bounding_boxes(
bounding_boxes: torch.Tensor,
format: Optional[tv_tensors.BoundingBoxFormat] = None,
canvas_size: Optional[Tuple[int, int]] = None,
min_size: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.

This removes bounding boxes that:

- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals.

It is recommended to call it at the end of a pipeline, before passing the
input to the models. It is critical to call this transform if
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
If you want to be extra careful, you may call it after all transforms that
may modify bounding boxes but once at the end should be enough in most
cases.

Args:
bounding_boxes (Tensor or :class:`~torchvision.tv_tensors.BoundingBoxes`): The bounding boxes to be sanitized.
format (str or :class:`~torchvision.tv_tensors.BoundingBoxFormat`, optional): The format of the bounding boxes.
Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
canvas_size (tuple of int, optional): The canvas_size of the bounding boxes
(size of the corresponding image/video).
Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.

Returns:
out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
The mask can then be used to subset other tensors (e.g. labels) that are associated with the bounding boxes.
"""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ss

if torch.jit.is_scripting() or is_pure_tensor(bounding_boxes):
if format is None or canvas_size is None:
raise ValueError(
"format and canvas_size cannot be None if bounding_boxes is a pure tensor. "
f"Got format={format} and canvas_size={canvas_size}."
"Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object."
)
if isinstance(format, str):
format = tv_tensors.BoundingBoxFormat[format.upper()]
valid = _get_sanitize_bounding_boxes_mask(
bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size
)
bounding_boxes = bounding_boxes[valid]
Comment on lines +325 to +328
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit sad that these 2 lines are somewhat duplicated below in the else: block. I couldn't find a decent way to make all that work without upsetting torchscript.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, let's not bother.

else:
if not isinstance(bounding_boxes, tv_tensors.BoundingBoxes):
raise ValueError("bouding_boxes must be a tv_tensors.BoundingBoxes instance or a pure tensor.")
if format is not None or canvas_size is not None:
raise ValueError(
"format and canvas_size must be None when bounding_boxes is a tv_tensors.BoundingBoxes instance. "
f"Got format={format} and canvas_size={canvas_size}. "
"Leave those to None or pass bouding_boxes as a pure tensor."
)
valid = _get_sanitize_bounding_boxes_mask(
bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size
)
bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes)

return bounding_boxes, valid


def _get_sanitize_bounding_boxes_mask(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic of this function is completely unchanged from the existing logic in the class.

bounding_boxes: torch.Tensor,
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
min_size: float = 1.0,
) -> torch.Tensor:

bounding_boxes = _convert_bounding_box_format(
bounding_boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY, old_format=format
)

image_h, image_w = canvas_size
ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1]
valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = canvas_size
valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w)
valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h)
return valid