Skip to content

Allow SanitizeBoundingBoxes to sanitize more labels #8319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5706,7 +5706,16 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
return boxes, expected_valid_mask

@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
@pytest.mark.parametrize(
"labels_getter",
(
"default",
lambda inputs: inputs["labels"],
lambda inputs: (inputs["labels"], inputs["other_labels"]),
None,
lambda inputs: None,
),
)
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_transform(self, min_size, labels_getter, sample_type):

Expand All @@ -5721,12 +5730,16 @@ def test_transform(self, min_size, labels_getter, sample_type):

labels = torch.arange(boxes.shape[0])
masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
# other_labels corresponds to properties from COCO like iscrowd, area...
# We only sanitize it when labels_getter returns a tuple
other_labels = torch.arange(boxes.shape[0])
whatever = torch.rand(10)
input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
sample = {
"image": input_img,
"labels": labels,
"boxes": boxes,
"other_labels": other_labels,
"whatever": whatever,
"None": None,
"masks": masks,
Expand All @@ -5741,12 +5754,14 @@ def test_transform(self, min_size, labels_getter, sample_type):
if sample_type is tuple:
out_image = out[0]
out_labels = out[1]["labels"]
out_other_labels = out[1]["other_labels"]
out_boxes = out[1]["boxes"]
out_masks = out[1]["masks"]
out_whatever = out[1]["whatever"]
else:
out_image = out["image"]
out_labels = out["labels"]
out_other_labels = out["other_labels"]
out_boxes = out["boxes"]
out_masks = out["masks"]
out_whatever = out["whatever"]
Expand All @@ -5757,14 +5772,20 @@ def test_transform(self, min_size, labels_getter, sample_type):
assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
assert isinstance(out_masks, tv_tensors.Mask)

if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
if labels_getter is None or (callable(labels_getter) and labels_getter(sample) is None):
assert out_labels is labels
assert out_other_labels is other_labels
else:
assert isinstance(out_labels, torch.Tensor)
assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0]
# This works because we conveniently set labels to arange(num_boxes)
assert out_labels.tolist() == valid_indices

if callable(labels_getter) and type(labels_getter(sample)) is tuple:
assert_equal(out_other_labels, out_labels)
else:
assert_equal(out_other_labels, other_labels)

@pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes))
def test_functional(self, input_type):
# Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some
Expand Down
48 changes: 34 additions & 14 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ class SanitizeBoundingBoxes(Transform):
- have any coordinate outside of their corresponding image. You may want to
call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.

It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO
(see ``labels_getter`` parameter).

It is recommended to call it at the end of a pipeline, before passing the
input to the models. It is critical to call this transform if
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
Expand All @@ -330,12 +333,18 @@ class SanitizeBoundingBoxes(Transform):

Args:
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
(or anything else that needs to be sanitized along with the bounding boxes).
By default, this will try to find a "labels" key in the input (case-insensitive), if
the input is a dict or it is a tuple whose second element is a dict.
This heuristic should work well with a lot of datasets, including the built-in torchvision datasets.
It can also be a callable that takes the same input
as the transform, and returns the labels.

It can also be a callable that takes the same input as the transform, and returns either:

- A single tensor (the labels)
- A tuple/list of tensors, each of which will be subject to the same sanitization as the bounding boxes.
This is useful to sanitize multiple tensors like the labels, and the "iscrowd" or "area" properties
from COCO.
"""

def __init__(
Expand All @@ -356,18 +365,29 @@ def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]

labels = self._labels_getter(inputs)
if labels is not None and not isinstance(labels, torch.Tensor):
raise ValueError(
f"The labels in the input to forward() must be a tensor or None, got {type(labels)} instead."
)
if labels is not None:
msg = "The labels in the input to forward() must be a tensor or None, got {type} instead."
if isinstance(labels, torch.Tensor):
labels = (labels,)
elif isinstance(labels, (tuple, list)):
labels = tuple(labels)
for entry in labels:
if not isinstance(entry, torch.Tensor):
# TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checking: was this a requirement from the issue or a nice to have?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a nice to have

raise ValueError(msg.format(type=type(entry)))
else:
raise ValueError(msg.format(type=type(labels)))

flat_inputs, spec = tree_flatten(inputs)
boxes = get_bounding_boxes(flat_inputs)

if labels is not None and boxes.shape[0] != labels.shape[0]:
raise ValueError(
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
)
if labels is not None:
for label in labels:
if boxes.shape[0] != label.shape[0]:
raise ValueError(
f"Number of boxes (shape={boxes.shape}) and must match the number of labels."
f"Found labels with shape={label.shape})."
)

valid = F._misc._get_sanitize_bounding_boxes_mask(
boxes,
Expand All @@ -381,7 +401,7 @@ def forward(self, *inputs: Any) -> Any:
return tree_unflatten(flat_outputs, spec)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
is_label = inpt is not None and inpt is params["labels"]
is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))

if not (is_label or is_bounding_boxes_or_mask):
Expand All @@ -391,5 +411,5 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

if is_label:
return output

return tv_tensors.wrap(output, like=inpt)
else:
return tv_tensors.wrap(output, like=inpt)