-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Allow SanitizeBoundingBoxes to sanitize more labels #8319
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
ae2c114
WIP
NicolasHug 3a3fbf7
More stuff
NicolasHug 5e0fa5f
Address comments
NicolasHug 26ac1b0
Allow SanitizeBoundingBoxes to sanitize more labels
NicolasHug 8af950f
Merge branch 'main' of github.com:pytorch/vision into sanitize_labels…
NicolasHug 41a35a4
mypy?
NicolasHug 0f2e09e
Address comments
NicolasHug f96dbea
mypy
NicolasHug c9bd6c8
lint
NicolasHug File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -321,6 +321,9 @@ class SanitizeBoundingBoxes(Transform): | |
- have any coordinate outside of their corresponding image. You may want to | ||
call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals. | ||
|
||
It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO | ||
(see ``labels_getter`` parameter). | ||
|
||
It is recommended to call it at the end of a pipeline, before passing the | ||
input to the models. It is critical to call this transform if | ||
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called. | ||
|
@@ -330,12 +333,18 @@ class SanitizeBoundingBoxes(Transform): | |
|
||
Args: | ||
min_size (float, optional) The size below which bounding boxes are removed. Default is 1. | ||
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input. | ||
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input | ||
(or anything else that needs to be sanitized along with the bounding boxes). | ||
By default, this will try to find a "labels" key in the input (case-insensitive), if | ||
the input is a dict or it is a tuple whose second element is a dict. | ||
This heuristic should work well with a lot of datasets, including the built-in torchvision datasets. | ||
It can also be a callable that takes the same input | ||
as the transform, and returns the labels. | ||
|
||
It can also be a callable that takes the same input as the transform, and returns either: | ||
|
||
- A single tensor (the labels) | ||
- A tuple/list of tensors, each of which will be subject to the same sanitization as the bounding boxes. | ||
This is useful to sanitize multiple tensors like the labels, and the "iscrowd" or "area" properties | ||
from COCO. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -356,18 +365,29 @@ def forward(self, *inputs: Any) -> Any: | |
inputs = inputs if len(inputs) > 1 else inputs[0] | ||
|
||
labels = self._labels_getter(inputs) | ||
if labels is not None and not isinstance(labels, torch.Tensor): | ||
raise ValueError( | ||
f"The labels in the input to forward() must be a tensor or None, got {type(labels)} instead." | ||
) | ||
if labels is not None: | ||
msg = "The labels in the input to forward() must be a tensor or None, got {type} instead." | ||
if isinstance(labels, torch.Tensor): | ||
labels = (labels,) | ||
elif isinstance(labels, (tuple, list)): | ||
labels = tuple(labels) | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for entry in labels: | ||
if not isinstance(entry, torch.Tensor): | ||
# TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just checking: was this a requirement from the issue or a nice to have? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a nice to have |
||
raise ValueError(msg.format(type=type(entry))) | ||
else: | ||
raise ValueError(msg.format(type=type(labels))) | ||
|
||
flat_inputs, spec = tree_flatten(inputs) | ||
boxes = get_bounding_boxes(flat_inputs) | ||
|
||
if labels is not None and boxes.shape[0] != labels.shape[0]: | ||
raise ValueError( | ||
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." | ||
) | ||
if labels is not None: | ||
for label in labels: | ||
if boxes.shape[0] != label.shape[0]: | ||
raise ValueError( | ||
f"Number of boxes (shape={boxes.shape}) and must match the number of labels." | ||
f"Found labels with shape={label.shape})." | ||
) | ||
|
||
valid = F._misc._get_sanitize_bounding_boxes_mask( | ||
boxes, | ||
|
@@ -381,7 +401,7 @@ def forward(self, *inputs: Any) -> Any: | |
return tree_unflatten(flat_outputs, spec) | ||
|
||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: | ||
is_label = inpt is not None and inpt is params["labels"] | ||
is_label = params["labels"] is not None and any(inpt is label for label in params["labels"]) | ||
is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)) | ||
|
||
if not (is_label or is_bounding_boxes_or_mask): | ||
|
@@ -391,5 +411,5 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: | |
|
||
if is_label: | ||
return output | ||
|
||
return tv_tensors.wrap(output, like=inpt) | ||
else: | ||
return tv_tensors.wrap(output, like=inpt) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.