@@ -1,12 +1,14 @@
+import collections
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from contextlib import suppress
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
 
 import PIL.Image
 
 import torch
+from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from torchvision import transforms as _transforms
-from torchvision.ops import remove_small_boxes
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform
 
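The new `torch.utils._pytree` import is what lets `forward()` in the hunk below accept arbitrarily nested samples. As a quick reference (not part of the diff), a minimal sketch of the flatten/unflatten round trip it relies on:

```python
from torch.utils._pytree import tree_flatten, tree_unflatten

sample = {"image": "img", "target": {"boxes": "b", "labels": "l"}}
leaves, spec = tree_flatten(sample)  # leaves == ["img", "b", "l"]
# Map over the flat leaves, then rebuild the original nested structure:
rebuilt = tree_unflatten([x.upper() for x in leaves], spec)
# rebuilt == {"image": "IMG", "target": {"boxes": "B", "labels": "L"}}
```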
@@ -225,28 +227,113 @@ def _transform(
         return inpt.transpose(*dims)
 
 
-class RemoveSmallBoundingBoxes(Transform):
-    _transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel)
+class SanitizeBoundingBoxes(Transform):
+    # This removes boxes and their corresponding labels:
+    # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
+    # - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
 
-    def __init__(self, min_size: float = 1.0) -> None:
+    def __init__(
+        self,
+        min_size: float = 1.0,
+        labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
+    ) -> None:
         super().__init__()
+
+        if min_size < 1:
+            raise ValueError(f"min_size must be >= 1, got {min_size}.")
         self.min_size = min_size
 
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        bounding_box = query_bounding_box(flat_inputs)
-
-        # TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to
-        # be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH
-        # format, we need to convert first just to afterwards compute the width and height again, although they were
-        # there in the first place for these formats.
-        bounding_box = F.convert_format_bounding_box(
-            bounding_box.as_subclass(torch.Tensor),
-            old_format=bounding_box.format,
-            new_format=datapoints.BoundingBoxFormat.XYXY,
-        )
-        valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
+        self.labels_getter = labels_getter
+        self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
+        if labels_getter == "default":
+            self._labels_getter = self._find_labels_default_heuristic
+        elif callable(labels_getter):
+            self._labels_getter = labels_getter
+        elif isinstance(labels_getter, str):
+            self._labels_getter = lambda inputs: inputs[labels_getter]
+        elif labels_getter is None:
+            self._labels_getter = None
+        else:
+            raise ValueError(
+                "labels_getter should either be a str, callable, or 'default'. "
+                f"Got {labels_getter} of type {type(labels_getter)}."
+            )
+
+    @staticmethod
+    def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
+        # Tries to find a "labels" key, otherwise falls back to the first key that contains "label" - case insensitive
+        # Raises a ValueError if nothing is found
+        candidate_key = None
+        with suppress(StopIteration):
+            candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
+        if candidate_key is None:
+            with suppress(StopIteration):
+                candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
+        if candidate_key is None:
+            raise ValueError(
+                "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter? "
+                "If there are no labels and it is by design, pass labels_getter=None."
+            )
+        return inputs[candidate_key]
+
+    def forward(self, *inputs: Any) -> Any:
+        inputs = inputs if len(inputs) > 1 else inputs[0]
+
+        if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
+            raise ValueError(
+                f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
+                f"then the input to forward() must be a dict. Got {type(inputs)} instead."
+            )
+
+        if self._labels_getter is None:
+            labels = None
+        else:
+            labels = self._labels_getter(inputs)
+            if labels is not None and not isinstance(labels, torch.Tensor):
+                raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")
 
-        return dict(valid_indices=valid_indices)
+        flat_inputs, spec = tree_flatten(inputs)
+        # TODO: this enforces one single BoundingBox entry.
+        # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
+        # should we just enforce it for all transforms? What are the benefits of *not* enforcing this?
+        boxes = query_bounding_box(flat_inputs)
+
+        if boxes.ndim != 2:
+            raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
+
+        if labels is not None and boxes.shape[0] != labels.shape[0]:
+            raise ValueError(
+                f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
+            )
+
+        boxes = cast(
+            datapoints.BoundingBox,
+            F.convert_format_bounding_box(
+                boxes,
+                new_format=datapoints.BoundingBoxFormat.XYXY,
+            ),
+        )
+        ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
+        mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
+        # TODO: Do we really need to check for out of bounds here? All
+        # transforms should be clamping anyway, so this should never happen?
+        image_h, image_w = boxes.spatial_size
+        mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
+        mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
+
+        params = dict(mask=mask, labels=labels)
+        flat_outputs = [
+            # Even though it may look like we're transforming all inputs, we don't:
+            # _transform() will only care about BoundingBoxes and the labels
+            self._transform(inpt, params)
+            for inpt in flat_inputs
+        ]
+
+        return tree_unflatten(flat_outputs, spec)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return inpt.wrap_like(inpt, inpt[params["valid_indices"]])
+
+        if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
+            inpt = inpt[params["mask"]]
+
+        return inpt
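For reference, a minimal usage sketch of the new transform (not part of the diff; it assumes `SanitizeBoundingBoxes` is exported from `torchvision.prototype.transforms` alongside the other transforms in this file):

```python
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import SanitizeBoundingBoxes

# Three boxes: the first is valid, the second is degenerate (X2 <= X1),
# the third extends past the right edge of a 100x100 image.
boxes = datapoints.BoundingBox(
    [[10, 10, 40, 40], [50, 50, 45, 55], [90, 10, 120, 40]],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(100, 100),
)
sample = {
    "boxes": boxes,
    "labels": torch.tensor([1, 2, 3]),  # found by the default "labels" key heuristic
}

out = SanitizeBoundingBoxes()(sample)
# Only the first box and its label survive:
# out["boxes"].shape == (1, 4), out["labels"].tolist() == [1]
```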
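And a sketch of the `labels_getter` escape hatch for samples where the default key heuristic cannot find the labels (the nested `"target"`/`"categories"` layout here is hypothetical):

```python
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import SanitizeBoundingBoxes

boxes = datapoints.BoundingBox(
    [[10, 10, 40, 40], [0, 0, 5, 5]],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(100, 100),
)
# Hypothetical layout: the labels live under a nested dict with a key that
# does not contain "label", so we pass a callable instead of relying on the
# default heuristic (which only inspects top-level keys).
sample = {"target": {"boxes": boxes, "categories": torch.tensor([1, 2])}}

transform = SanitizeBoundingBoxes(
    min_size=10,
    labels_getter=lambda inputs: inputs["target"]["categories"],
)
out = transform(sample)  # drops the 5x5 box and its category
```

Because `_transform()` matches the labels tensor by identity (`inpt is params["labels"]`), the callable must return the tensor object that actually appears in the (flattened) sample, not a copy.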