Add SanitizeBoundingBoxes transform #7246

Merged (17 commits, Feb 15, 2023)
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -49,7 +49,7 @@
LinearTransformation,
Normalize,
PermuteDimensions,
RemoveSmallBoundingBoxes,
SanitizeBoundingBoxes,
ToDtype,
TransposeDimensions,
)
92 changes: 74 additions & 18 deletions torchvision/prototype/transforms/_misc.py
@@ -1,12 +1,13 @@
import warnings
from contextlib import suppress
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import PIL.Image

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchvision import transforms as _transforms
from torchvision.ops import remove_small_boxes
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform

@@ -225,28 +226,83 @@ def _transform(
return inpt.transpose(*dims)


class RemoveSmallBoundingBoxes(Transform):
_transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel)
class SanitizeBoundingBoxes(Transform):
# This removes boxes and their corresponding labels:
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
_transformed_types = (datapoints.BoundingBox, datapoints.Mask)

def __init__(self, min_size: float = 1.0) -> None:
def __init__(self, min_size: float = 1.0, labels="default") -> None:
super().__init__()
self.min_size = min_size
self.labels = labels

def _find_label_default_heuristic(self, inputs):
# Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "label")
if candidate_key is None:
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
labels = inputs.get(candidate_key)
return labels
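The key-matching heuristic above can be sketched standalone. This is a minimal sketch assuming plain dict inputs; `find_labels` is a hypothetical helper name, not part of the PR:

```python
from contextlib import suppress

def find_labels(sample: dict):
    # Mirrors the heuristic above: prefer an exact "label" key (case-insensitive),
    # then fall back to the first key merely containing "label"; None if neither exists.
    candidate_key = None
    with suppress(StopIteration):
        candidate_key = next(key for key in sample if key.lower() == "label")
    if candidate_key is None:
        with suppress(StopIteration):
            candidate_key = next(key for key in sample if "label" in key.lower())
    return sample.get(candidate_key)

print(find_labels({"image": "img", "Label": [1, 2]}))      # [1, 2]
print(find_labels({"image": "img", "class_labels": [3]}))  # [3]
print(find_labels({"image": "img"}))                       # None
```

Note that an exact (case-insensitive) `"label"` key always wins over keys that merely contain the substring, so a dict with both `"Label"` and `"class_labels"` resolves to `"Label"`.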

def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]
if isinstance(self.labels, str) and not isinstance(inputs, dict):
raise ValueError(
f"If labels is a str or 'default' (got {self.labels}), then the input to forward() must be a dict. "
f"Got {type(inputs)} instead."
)

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
bounding_box = query_bounding_box(flat_inputs)

# TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to
# be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH
# format, we need to convert first just to compute the width and height again afterwards, although they were
# there in the first place for these formats.
bounding_box = F.convert_format_bounding_box(
bounding_box.as_subclass(torch.Tensor),
old_format=bounding_box.format,
labels = None
if self.labels == "default":
labels = self._find_label_default_heuristic(inputs)
elif callable(self.labels):
labels = self.labels(inputs)
elif isinstance(self.labels, str):
labels = inputs[self.labels]
else:
raise ValueError(
"labels parameter should either be a str, callable, or 'default'. "
f"Got {self.labels} of type {type(self.labels)}."
)

flat_inputs, spec = tree_flatten(inputs)
# TODO: this enforces one single BoundingBox entry.
# Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
# should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
boxes = query_bounding_box(flat_inputs)

boxes = F.convert_format_bounding_box(
boxes,
new_format=datapoints.BoundingBoxFormat.XYXY,
)
valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)

return dict(valid_indices=valid_indices)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
keep = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(axis=1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
Comment on lines +318 to +319

Collaborator:
I would keep this for now until we are sure about this, i.e. we have tests that guarantee this. Happy to remove if it turns out we don't need it.

# TODO: Also... should this be <= instead of < ???
image_h, image_w = boxes.spatial_size
keep &= (boxes[:, 0] < image_w) & (boxes[:, 2] < image_w)
keep &= (boxes[:, 1] < image_h) & (boxes[:, 3] < image_h)
valid_indices = torch.where(keep)[0]
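The filtering rule built up above (minimum size, degenerate boxes, negative or out-of-bounds coordinates) can be illustrated without torch. `sanitize` is a hypothetical standalone helper operating on XYXY tuples, not code from this PR:

```python
def sanitize(boxes, labels, min_size=1.0, image_w=100, image_h=100):
    # boxes: list of (x1, y1, x2, y2) tuples in XYXY format.
    # A box is kept only if it passes all of the checks below.
    keep_boxes, keep_labels = [], []
    for (x1, y1, x2, y2), lbl in zip(boxes, labels):
        w, h = x2 - x1, y2 - y1
        valid = w >= min_size and h >= min_size   # drops small and degenerate boxes (X2 <= X1 or Y2 <= Y1)
        valid = valid and min(x1, y1, x2, y2) >= 0  # no negative coordinates
        valid = valid and x2 < image_w and y2 < image_h  # inside the image
        if valid:
            keep_boxes.append((x1, y1, x2, y2))
            keep_labels.append(lbl)
    return keep_boxes, keep_labels

boxes = [(0, 0, 10, 10), (5, 5, 5, 20), (-1, 0, 10, 10)]
labels = [1, 2, 3]
print(sanitize(boxes, labels))  # ([(0, 0, 10, 10)], [1])
```

The second box is degenerate (zero width) and the third has a negative coordinate, so only the first box and its label survive. The tensorized version in the diff computes the same predicate vectorized and collects the surviving row indices with `torch.where`.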

params = dict(valid_indices=valid_indices, labels=labels)
flat_outputs = [
# Even though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxes and the labels
self._transform(inpt, params)
for inpt in flat_inputs
]
Collaborator:
Not sure if we can do better without other changes, but this looks pretty weird. I mean, we have the bounding box and labels here. All we need to do is to put it at the right place in flat_inputs and we should be good to go without going the extra mile through self._transform.

Member Author:
"All we need to do is to put it at the right place in flat_inputs"

yup... and I don't know how to do that easily :)
But if you can find a way to bypass _transforms(), I'm all ears

Collaborator:

I mean for boxes, we can change query_bounding_box to whatever we like. Meaning, we could return the index from there. For the labels the story is different. We can't pass the flat_inputs because we rely on the dict keys. Meaning, users would need to return a spec similar to what tree_flatten produces, but that is bad UX. 🤷


return tree_unflatten(flat_outputs, spec)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.wrap_like(inpt, inpt[params["valid_indices"]])

if inpt is params["labels"] or isinstance(inpt, datapoints.BoundingBox):
inpt = inpt[params["valid_indices"]]

return inpt
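The identity-based dispatch in `_transform` above can be seen in isolation: only the object that *is* the labels (matched by identity, not equality) and bounding boxes are indexed; everything else passes through untouched. A hypothetical pure-Python sketch, with `is_bbox` standing in for the `isinstance(inpt, datapoints.BoundingBox)` check:

```python
def transform(inpt, params, is_bbox=False):
    # Identity check ("is", not "=="): only the exact labels object found in
    # forward() gets filtered, plus any bounding box input.
    if inpt is params["labels"] or is_bbox:
        return [inpt[i] for i in params["valid_indices"]]
    return inpt

labels = ["cat", "dog", "bird"]
params = {"labels": labels, "valid_indices": [0, 2]}
print(transform(labels, params))            # ['cat', 'bird']
print(transform("image data", params))      # 'image data' (untouched)
print(transform(["cat", "dog", "bird"], params))  # untouched: equal but not identical
```

The third call shows why identity matters: a list that merely *equals* the labels is not filtered, which is what lets the transform walk every leaf of the flattened pytree and still touch only the one labels entry it located in `forward()`.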