-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add sanitize_bounding_boxes
kernel/functional
#8308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -123,7 +123,7 @@ def _extract_image_targets( | |
if not (len(images) == len(bboxes) == len(masks) == len(labels)): | ||
raise TypeError( | ||
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, " | ||
"BoundingBoxeses, Masks and Labels or OneHotLabels." | ||
"BoundingBoxes, Masks and Labels or OneHotLabels." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Driveby |
||
) | ||
|
||
targets = [] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import math | ||
from typing import List, Optional | ||
from typing import List, Optional, Tuple | ||
|
||
import PIL.Image | ||
import torch | ||
|
@@ -11,7 +11,9 @@ | |
|
||
from torchvision.utils import _log_api_usage_once | ||
|
||
from ._utils import _get_kernel, _register_kernel_internal | ||
from ._meta import _convert_bounding_box_format | ||
|
||
from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor | ||
|
||
|
||
def normalize( | ||
|
@@ -275,3 +277,89 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: | |
def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor: | ||
# We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type | ||
return inpt.to(dtype) | ||
|
||
|
||
def sanitize_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: Optional[tv_tensors.BoundingBoxFormat] = None,
    canvas_size: Optional[Tuple[int, int]] = None,
    min_size: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.

    This removes bounding boxes that:

    - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
    - have any coordinate outside of their corresponding image. You may want to
      call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals.

    It is recommended to call it at the end of a pipeline, before passing the
    input to the models. It is critical to call this transform if
    :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
    If you want to be extra careful, you may call it after all transforms that
    may modify bounding boxes but once at the end should be enough in most
    cases.

    Args:
        bounding_boxes (Tensor or :class:`~torchvision.tv_tensors.BoundingBoxes`): The bounding boxes to be sanitized.
        format (str or :class:`~torchvision.tv_tensors.BoundingBoxFormat`, optional): The format of the bounding boxes.
            Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
        canvas_size (tuple of int, optional): The canvas_size of the bounding boxes
            (size of the corresponding image/video).
            Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
        min_size (float, optional): The size below which bounding boxes are removed. Default is 1.

    Returns:
        out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
        The mask can then be used to subset other tensors (e.g. labels) that are associated with the bounding boxes.

    Raises:
        ValueError: If ``bounding_boxes`` is a pure tensor and ``format`` or ``canvas_size`` is ``None``;
            if ``bounding_boxes`` is neither a pure tensor nor a :class:`~torchvision.tv_tensors.BoundingBoxes`;
            or if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` and ``format`` or
            ``canvas_size`` is not ``None``.
    """
    # When scripting, this branch is always taken: the metadata must then be passed explicitly.
    if torch.jit.is_scripting() or is_pure_tensor(bounding_boxes):
        if format is None or canvas_size is None:
            raise ValueError(
                "format and canvas_size cannot be None if bounding_boxes is a pure tensor. "
                f"Got format={format} and canvas_size={canvas_size}. "
                "Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object."
            )
        # Accept the format given by name (case-insensitive) and resolve it to the enum member.
        if isinstance(format, str):
            format = tv_tensors.BoundingBoxFormat[format.upper()]
        # The mask/index steps below are mirrored in the else-branch; the two only differ in where
        # format/canvas_size come from and in whether the result is re-wrapped.
        valid = _get_sanitize_bounding_boxes_mask(
            bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size
        )
        bounding_boxes = bounding_boxes[valid]
    else:
        if not isinstance(bounding_boxes, tv_tensors.BoundingBoxes):
            raise ValueError("bounding_boxes must be a tv_tensors.BoundingBoxes instance or a pure tensor.")
        if format is not None or canvas_size is not None:
            raise ValueError(
                "format and canvas_size must be None when bounding_boxes is a tv_tensors.BoundingBoxes instance. "
                f"Got format={format} and canvas_size={canvas_size}. "
                "Leave those to None or pass bounding_boxes as a pure tensor."
            )
        # The metadata is read from the BoundingBoxes object itself.
        valid = _get_sanitize_bounding_boxes_mask(
            bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size
        )
        # Re-wrap the indexed result, using the input as the template for the wrapper.
        bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes)

    return bounding_boxes, valid
|
||
|
||
def _get_sanitize_bounding_boxes_mask(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    min_size: float = 1.0,
) -> torch.Tensor:
    """Compute the boolean mask selecting the bounding boxes that pass sanitization.

    The boxes are first converted to the XYXY format. A box is then marked valid when:

    - its width (X2 - X1) and height (Y2 - Y1) are both at least ``min_size``,
    - all of its coordinates are non-negative,
    - its X coordinates do not exceed the canvas width and its Y coordinates
      do not exceed the canvas height.

    Args:
        bounding_boxes (Tensor): The bounding boxes to check. They are indexed as ``[:, k]``,
            so presumably of shape ``(N, 4)`` -- TODO confirm against callers.
        format (:class:`~torchvision.tv_tensors.BoundingBoxFormat`): The format ``bounding_boxes`` is given in.
        canvas_size (tuple of int): The ``(height, width)`` of the corresponding canvas.
        min_size (float, optional): The width/height below which a box is invalid. Default is 1.

    Returns:
        Tensor: A boolean mask with one entry per box, ``True`` for the boxes to keep.
    """
    bounding_boxes = _convert_bounding_box_format(
        bounding_boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY, old_format=format
    )

    # canvas_size is (height, width); unpack it once for the out-of-bounds checks below.
    image_h, image_w = canvas_size
    ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1]
    valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1)
    # TODO: Do we really need to check for out of bounds here? All
    # transforms should be clamping anyway, so this should never happen?
    valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w)
    valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h)
    return valid
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That test did not change, I just moved the input generation above in a function so that it can be reused in the newly added tests below