-
Notifications
You must be signed in to change notification settings - Fork 7.1k
NEW Feature: Mixup transform for Object Detection #6721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 33 commits
676a3ba
60cdf3b
fd922ca
728c7ca
3f204ac
f1b70b9
cdda41b
2d0765c
7e82ff2
b83aedf
50bea74
90799b8
248737d
26316a4
d7e08d2
04c80d7
5667c91
e0724a3
10c9033
6177057
2b67017
fa0f54e
cae66d9
ae9908b
c2e2757
6b58135
5398c73
044ba0d
884ace1
99de232
4ceef89
a6b9ae0
fce49b8
05c0491
d995471
914a9ee
cbf09c2
685d042
3319215
02214b6
4486e78
1b6dbe1
8a912ba
ebd6bfd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,14 +4,13 @@ | |
from typing import Any, cast, Dict, List, Optional, Tuple, Union | ||
|
||
import PIL.Image | ||
import torch | ||
from torch.utils._pytree import tree_flatten, tree_unflatten | ||
|
||
import torch | ||
from torchvision.ops import masks_to_boxes | ||
from torchvision.prototype import datapoints | ||
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform | ||
|
||
from ._transform import _RandomApplyTransform | ||
from ._transform import _DetectionBatchTransform, _RandomApplyTransform | ||
from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size | ||
|
||
|
||
|
@@ -214,7 +213,6 @@ def _copy_paste( | |
resize_interpolation: F.InterpolationMode, | ||
antialias: Optional[bool], | ||
) -> Tuple[datapoints.TensorImageType, Dict[str, Any]]: | ||
|
||
paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection]) | ||
paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection]) | ||
paste_labels = paste_target["labels"].wrap_like( | ||
|
@@ -241,7 +239,7 @@ def _copy_paste( | |
|
||
inverse_paste_alpha_mask = paste_alpha_mask.logical_not() | ||
# Copy-paste images: | ||
image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask)) | ||
out_image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask)) | ||
|
||
# Copy-paste masks: | ||
masks = masks * inverse_paste_alpha_mask | ||
|
@@ -281,69 +279,15 @@ def _copy_paste( | |
out_target["masks"] = out_target["masks"][valid_targets] | ||
out_target["labels"] = out_target["labels"][valid_targets] | ||
|
||
return image, out_target | ||
|
||
def _extract_image_targets( | ||
self, flat_sample: List[Any] | ||
) -> Tuple[List[datapoints.TensorImageType], List[Dict[str, Any]]]: | ||
# fetch all images, bboxes, masks and labels from unstructured input | ||
# with List[image], List[BoundingBox], List[Mask], List[Label] | ||
images, bboxes, masks, labels = [], [], [], [] | ||
for obj in flat_sample: | ||
if isinstance(obj, datapoints.Image) or is_simple_tensor(obj): | ||
images.append(obj) | ||
elif isinstance(obj, PIL.Image.Image): | ||
images.append(F.to_image_tensor(obj)) | ||
elif isinstance(obj, datapoints.BoundingBox): | ||
bboxes.append(obj) | ||
elif isinstance(obj, datapoints.Mask): | ||
masks.append(obj) | ||
elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)): | ||
labels.append(obj) | ||
|
||
if not (len(images) == len(bboxes) == len(masks) == len(labels)): | ||
raise TypeError( | ||
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, " | ||
"BoundingBoxes, Masks and Labels or OneHotLabels." | ||
) | ||
|
||
targets = [] | ||
for bbox, mask, label in zip(bboxes, masks, labels): | ||
targets.append({"boxes": bbox, "masks": mask, "labels": label}) | ||
|
||
return images, targets | ||
|
||
def _insert_outputs( | ||
self, | ||
flat_sample: List[Any], | ||
output_images: List[datapoints.TensorImageType], | ||
output_targets: List[Dict[str, Any]], | ||
) -> None: | ||
c0, c1, c2, c3 = 0, 0, 0, 0 | ||
for i, obj in enumerate(flat_sample): | ||
if isinstance(obj, datapoints.Image): | ||
flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0]) | ||
c0 += 1 | ||
elif isinstance(obj, PIL.Image.Image): | ||
flat_sample[i] = F.to_image_pil(output_images[c0]) | ||
c0 += 1 | ||
elif is_simple_tensor(obj): | ||
flat_sample[i] = output_images[c0] | ||
c0 += 1 | ||
elif isinstance(obj, datapoints.BoundingBox): | ||
flat_sample[i] = datapoints.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"]) | ||
c1 += 1 | ||
elif isinstance(obj, datapoints.Mask): | ||
flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"]) | ||
c2 += 1 | ||
elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)): | ||
flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type] | ||
c3 += 1 | ||
return out_image, out_target | ||
|
||
def forward(self, *inputs: Any) -> Any: | ||
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) | ||
|
||
images, targets = self._extract_image_targets(flat_inputs) | ||
flat_batch_with_spec, images, targets = flatten_and_extract_data( | ||
inputs, | ||
boxes=(datapoints.BoundingBox,), | ||
masks=(datapoints.Mask,), | ||
labels=(datapoints.Label, datapoints.OneHotLabel), | ||
) | ||
|
||
# images = [t1, t2, ..., tN] | ||
# Let's define paste_images as shifted list of input images | ||
|
@@ -380,7 +324,73 @@ def forward(self, *inputs: Any) -> Any: | |
output_images.append(output_image) | ||
output_targets.append(output_target) | ||
|
||
# Insert updated images and targets into input flat_sample | ||
self._insert_outputs(flat_inputs, output_images, output_targets) | ||
return unflatten_and_insert_data(flat_batch_with_spec, output_images, output_targets) | ||
|
||
|
||
class MixupDetection(_DetectionBatchTransform): | ||
_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image) | ||
|
||
ambujpawar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__( | ||
self, | ||
*, | ||
alpha: float = 1.5, | ||
) -> None: | ||
super().__init__() | ||
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) | ||
|
||
def _check_inputs(self, flat_inputs: List[Any]) -> None: | ||
if has_any(flat_inputs, datapoints.Mask, datapoints.Video): | ||
raise TypeError(f"{type(self).__name__}() is only supported for images and bounding boxes.") | ||
|
||
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: | ||
return dict(ratio=float(self._dist.sample())) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've opted to sample the ratio in the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, it looks much tidier this way! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would keep it for consistency with the other transformations. This is the basic protocol for all There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just realized we are actually not passing anything here. Let's just pass the flat inputs for completeness. |
||
|
||
def forward(self, *inputs: Any) -> Any: | ||
flat_batch_with_spec, batch = self._flatten_and_extract_data( | ||
inputs, | ||
image=(datapoints.Image, PIL.Image.Image, is_simple_tensor), | ||
boxes=(datapoints.BoundingBox,), | ||
labels=(datapoints.Label, datapoints.OneHotLabel), | ||
) | ||
self._check_inputs(flat_batch_with_spec[0]) | ||
|
||
batch = self._to_image_tensor(batch) | ||
ambujpawar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
batch_output = [ | ||
self._mixup(sample, sample_rolled, self._get_params([])["ratio"]) | ||
for sample, sample_rolled in zip(batch, batch[-1:] + batch[:-1]) | ||
] | ||
|
||
return self._unflatten_and_insert_data(flat_batch_with_spec, batch_output) | ||
|
||
def _mixup(self, sample_1: Dict[str, Any], sample_2: Dict[str, Any], ratio: float) -> Dict[str, Any]: | ||
if ratio >= 1.0: | ||
return sample_1 | ||
elif ratio == 0.0: | ||
return sample_2 | ||
|
||
h_1, w_1 = sample_1["image"].shape[-2:] | ||
h_2, w_2 = sample_2["image"].shape[-2:] | ||
h_mixup = max(h_1, h_2) | ||
w_mixup = max(w_1, w_2) | ||
|
||
# TODO: add the option to fill this with something else than 0 | ||
dtype = sample_1["image"].dtype if sample_1["image"].is_floating_point() else torch.float32 | ||
mix_image = F.pad_image_tensor( | ||
sample_1["image"].to(dtype), padding=[0, 0, w_mixup - w_1, h_mixup - h_1], fill=None | ||
).mul_(ratio) | ||
mix_image[..., :h_2, :w_2] = sample_2["image"] * (1.0 - ratio) | ||
ambujpawar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
mix_image = mix_image.to(sample_1["image"]) | ||
ambujpawar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
mix_boxes = datapoints.BoundingBox.wrap_like( | ||
sample_1["boxes"], | ||
torch.cat([sample_1["boxes"], sample_2["boxes"]], dim=-2), | ||
spatial_size=(h_mixup, w_mixup), | ||
) | ||
|
||
mix_labels = datapoints.Label.wrap_like( | ||
sample_1["labels"], | ||
torch.cat([sample_1["labels"], sample_2["labels"]], dim=-1), | ||
) | ||
|
||
return tree_unflatten(flat_inputs, spec) | ||
return dict(image=mix_image, boxes=mix_boxes, labels=mix_labels) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you delete this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was using it to test
_extract_image_targets
function. However, since we removed those functions I removed them from here as wellThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh sorry! Realized this was for TestSimpleCopyPaste. Undoing the changes, sorry