-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add SanitizeBoundingBoxes transform #7246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
7c5ab88
26929c0
9ae43b2
2ac342a
87a849e
7839dd8
57aaab3
e52ebb2
3a9619a
b093987
96ded4c
dbbebb7
4621097
ed030a5
85b96c4
66dcfc8
dda8810
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
import warnings | ||
from contextlib import suppress | ||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union | ||
|
||
import PIL.Image | ||
|
||
import torch | ||
from torch.utils._pytree import tree_flatten, tree_unflatten | ||
|
||
from torchvision import transforms as _transforms | ||
from torchvision.ops import remove_small_boxes | ||
from torchvision.prototype import datapoints | ||
from torchvision.prototype.transforms import functional as F, Transform | ||
|
||
|
@@ -225,28 +226,83 @@ def _transform( | |
return inpt.transpose(*dims) | ||
|
||
|
||
class RemoveSmallBoundingBoxes(Transform): | ||
_transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel) | ||
class SanitizeBoundingBoxes(Transform): | ||
# This removes boxes and their corresponding labels: | ||
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) | ||
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size) | ||
_transformed_types = (datapoints.BoundingBox, datapoints.Mas) | ||
|
||
def __init__(self, min_size: float = 1.0) -> None: | ||
def __init__(self, min_size: float = 1.0, labels="default") -> None: | ||
super().__init__() | ||
self.min_size = min_size | ||
self.labels = labels | ||
|
||
def _find_label_default_heuristic(self, inputs): | ||
# Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive | ||
# Returns None if nothing is found | ||
candidate_key = None | ||
with suppress(StopIteration): | ||
candidate_key = next(key for key in inputs.keys() if key.lower() == "label") | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if candidate_key is None: | ||
with suppress(StopIteration): | ||
candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) | ||
labels = inputs.get(candidate_key) | ||
return labels | ||
|
||
def forward(self, *inputs: Any) -> Any: | ||
inputs = inputs if len(inputs) > 1 else inputs[0] | ||
if isinstance(labels, str) and not isinstance(inputs, dict): | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError( | ||
f"If labels is a str or 'default' (got {labels}), then the input to forward() must be a dict. " | ||
f"Got {type(inputs)} instead" | ||
) | ||
|
||
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: | ||
bounding_box = query_bounding_box(flat_inputs) | ||
|
||
# TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to | ||
# be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH | ||
# format,we need to convert first just to afterwards compute the width and height again, although they were | ||
# there in the first place for these formats. | ||
bounding_box = F.convert_format_bounding_box( | ||
bounding_box.as_subclass(torch.Tensor), | ||
old_format=bounding_box.format, | ||
labels = None | ||
if self.labels == "default": | ||
labels = self._find_label_default_heuristic(inputs) | ||
elif callable(self.labels): | ||
labels = self.labels(inputs) | ||
elif isinstance(self.labels, str): | ||
labels = inputs[self.labels] | ||
else: | ||
raise ValueError( | ||
"labels parameter should either be a str, callable, or 'default'. " | ||
f"Got {labels} of type {type(labels)}." | ||
) | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
flat_inputs, spec = tree_flatten(inputs) | ||
# TODO: this enforces one single BoundingBox entry. | ||
# Assuming this transform needs to be called at the end of *any* pipeline that has bboxes... | ||
# should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? | ||
boxes = query_bounding_box(flat_inputs) | ||
|
||
boxes = F.convert_format_bounding_box( | ||
boxes, | ||
new_format=datapoints.BoundingBoxFormat.XYXY, | ||
) | ||
valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size) | ||
|
||
return dict(valid_indices=valid_indices) | ||
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
keep = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(axis=1) | ||
# TODO: Do we really need to check for out of bounds here? All | ||
# transforms should be clamping anyway, so this should never happen? | ||
Comment on lines
+318
to
+319
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would keep this for now until we are sure about this, i.e. we have tests that guarantee this. Happy to remove if it turns out we don't need it. |
||
# TODO: Also... should this is <= instead of < ??? | ||
image_h, image_w = boxes.spatial_size | ||
keep &= (boxes[:, 0] < image_w).all() & (boxes[:, 2] < image_w).all() | ||
keep &= (boxes[:, 1] < image_h).all() & (boxes[:, 3] < image_h).all() | ||
valid_indices = torch.where(keep)[0] | ||
|
||
params = dict(valid_indices=valid_indices, labels=labels) | ||
flat_outputs = [ | ||
# Even-though it may look like we're transforming all inputs, we don't: | ||
# _transform() will only care about BoundingBoxes and the labels | ||
self._transform(inpt, params) | ||
for inpt in flat_inputs | ||
] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if we can do better without other changes, but this looks pretty weird. I mean, we have the bounding box and labels here. All we need to do is to put it at the right place in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
yup... and I don't know how to do that easily :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean for boxes, we can change |
||
|
||
return tree_unflatten(flat_outputs, spec) | ||
|
||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: | ||
return inpt.wrap_like(inpt, inpt[params["valid_indices"]]) | ||
|
||
if inpt is params["labels"] or isinstance(inpt, datapoints.BoundingBox): | ||
inpt = inpt[params["valid_indices"]] | ||
|
||
return inpt |
Uh oh!
There was an error while loading. Please reload this page.