Skip to content

NEW Feature: Mixup transform for Object Detection #6721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 44 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
676a3ba
ADD: Empty file mixup.py for dummy PR
Oct 7, 2022
60cdf3b
ADD: Empty transform class
Oct 7, 2022
fd922ca
Merge branch 'main' into 6720_add_mixup_transform
ambujpawar Oct 22, 2022
728c7ca
WIP: Random Mixup for detection
Oct 28, 2022
3f204ac
Merge branch 'main' into 6720_add_mixup_transform
Nov 4, 2022
f1b70b9
First draft: Mixup detections
Nov 5, 2022
cdda41b
Fix: precommit issues
Nov 5, 2022
2d0765c
Fix: failing CI issues
Nov 6, 2022
7e82ff2
Fix: Tests and ADD: get_params and check_inputs functions
Nov 6, 2022
b83aedf
Fix: Remove usage of soon to be deprecated to_tensor function
Nov 6, 2022
50bea74
Merge branch 'main' into 6720_add_mixup_transform
ambujpawar Nov 6, 2022
90799b8
Remove: get params for mixup
Nov 6, 2022
248737d
Update _mixup_detection.py
ambujpawar Nov 7, 2022
26316a4
Remove unused type: ignore due to failing CI test
ambujpawar Nov 7, 2022
d7e08d2
Merge branch 'main' into 6720_add_mixup_transform
pmeier Nov 8, 2022
04c80d7
add batch detection helpers
pmeier Nov 8, 2022
5667c91
use helpers in detection mixup
pmeier Nov 8, 2022
e0724a3
refactor helpers
pmeier Nov 8, 2022
10c9033
Merge branch 'main' into 6720_add_mixup_transform
pmeier Dec 1, 2022
6177057
revert accidental COCO change
pmeier Dec 1, 2022
2b67017
Move: mixup detection to _augment.py
Dec 4, 2022
fa0f54e
Merge branch 'main' into 6720_add_mixup_transform
ambujpawar Dec 4, 2022
cae66d9
Merge branch 'main' into 6720_add_mixup_transform
pmeier Dec 6, 2022
ae9908b
refactor extraction and insertion
pmeier Dec 6, 2022
c2e2757
Fix: Failing SimpleCopyPaste and MixupDetection Failing tests
Dec 17, 2022
6b58135
Merge branch 'main' into 6720_add_mixup_transform
pmeier Dec 19, 2022
5398c73
sample ratio in get_params
pmeier Dec 19, 2022
044ba0d
fix padding
pmeier Dec 19, 2022
884ace1
perform image conversion upfront
pmeier Dec 19, 2022
99de232
create base class
pmeier Dec 19, 2022
4ceef89
Merge branch 'main' into 6720_add_mixup_transform
pmeier Dec 19, 2022
a6b9ae0
add shortcut for ratio==0
pmeier Dec 19, 2022
fce49b8
fix dtype
pmeier Dec 19, 2022
05c0491
Merge branch 'main' into 6720_add_mixup_transform
ambujpawar Jan 17, 2023
d995471
Apply suggestions from code review
ambujpawar Jan 21, 2023
914a9ee
Merge branch 'main' into 6720_add_mixup_transform
ambujpawar Jan 21, 2023
cbf09c2
Undo removing test_extract_image_target of TestSimpleCopyPaste
ambujpawar Jan 21, 2023
685d042
ADD: Test cases when mixup ratio is 0, 0.5, 1
ambujpawar Jan 22, 2023
3319215
Fix: was doing wrong asserts. Corrected it
ambujpawar Jan 22, 2023
02214b6
fix mixing
pmeier Jan 23, 2023
4486e78
pass flat_inputs to get_params
pmeier Jan 23, 2023
1b6dbe1
Update torchvision/prototype/transforms/_augment.py
ambujpawar Jan 23, 2023
8a912ba
refactor SimpleCopyPaste
pmeier Jan 23, 2023
ebd6bfd
Merge branch '6720_add_mixup_transform' of https://github.com/ambujpa…
pmeier Jan 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms._utils import _isinstance
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image, to_tensor

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]

Expand Down Expand Up @@ -1918,3 +1918,80 @@ def test__transform(self, inpt):
assert type(output) is type(inpt)
assert output.shape[-4] == num_samples
assert output.dtype == inpt.dtype


class TestMixupDetection:
    """Tests for ``transforms.MixupDetection``: input extraction and the core mixup op."""

    def create_fake_image(self, mocker, image_type):
        # PIL inputs need a real image object; everything else can be a spec'd mock.
        if image_type is not PIL.Image.Image:
            return mocker.MagicMock(spec=image_type)
        return PIL.Image.new("RGB", (32, 32), 123)

    def test__extract_image_targets_assertion(self, mocker):
        transform = transforms.MixupDetection()

        # Deliberately unbalanced sample: one image but two bounding boxes.
        unbalanced_sample = [
            self.create_fake_image(mocker, features.Image),
            mocker.MagicMock(spec=features.Label),
            mocker.MagicMock(spec=features.BoundingBox),
            mocker.MagicMock(spec=features.BoundingBox),
        ]

        with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"):
            transform._extract_image_targets(unbalanced_sample)

    @pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
    def test__extract_image_targets(self, image_type, mocker):
        transform = transforms.MixupDetection()

        # Balanced sample: two images, each followed by a (label, bbox) pair.
        sample = [
            self.create_fake_image(mocker, image_type),
            self.create_fake_image(mocker, image_type),
        ]
        for _ in range(2):
            sample.append(mocker.MagicMock(spec=features.Label))
            sample.append(mocker.MagicMock(spec=features.BoundingBox))

        images, targets = transform._extract_image_targets(sample)

        assert len(images) == len(targets) == 2
        if image_type is PIL.Image.Image:
            # PIL images are converted to tensors during extraction.
            for idx in range(2):
                torch.testing.assert_close(images[idx], to_tensor(sample[idx]))
        else:
            # Tensor-like images are passed through unchanged.
            for idx in range(2):
                assert images[idx] == sample[idx]

    def test__mixup(self):
        # Two constant images with transposed spatial sizes, each carrying two boxes.
        image_a = 2 * torch.ones(3, 32, 64)
        target_a = {
            "boxes": features.BoundingBox(
                torch.tensor([[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]]),
                format="XYXY",
                spatial_size=(32, 64),
            ),
            "labels": features.Label(torch.tensor([1, 2])),
        }

        image_b = 10 * torch.ones(3, 64, 32)
        target_b = {
            "boxes": features.BoundingBox(
                torch.tensor([[10.0, 0.0, 20.0, 20.0], [10.0, 20.0, 30.0, 30.0]]),
                format="XYXY",
                spatial_size=(64, 32),
            ),
            "labels": features.Label(torch.tensor([2, 3])),
        }

        mixed_image, mixed_target = transforms.MixupDetection()._mixup(image_a, target_a, image_b, target_b)

        # The mixed canvas spans the element-wise maximum of both spatial sizes ...
        assert mixed_image.shape == (3, 64, 64)
        assert mixed_target["boxes"].spatial_size == (64, 64)
        # ... and annotations from both inputs are concatenated.
        assert len(mixed_target["boxes"]) == 4
        assert len(mixed_target["labels"]) == 4
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
ToDtype,
TransposeDimensions,
)
from ._mixup_detection import MixupDetection
from ._temporal import UniformTemporalSubsample
from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

Expand Down
142 changes: 72 additions & 70 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, InterpolationMode

from ._transform import _RandomApplyTransform
from ._utils import has_any, query_chw, query_spatial_size
from ._utils import _isinstance, has_any, query_chw, query_spatial_size


class RandomErasing(_RandomApplyTransform):
Expand Down Expand Up @@ -190,6 +190,40 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt


def flatten_and_extract(
    inputs: Any, **types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]
) -> Tuple[Tuple[List[Any], TreeSpec, Dict[str, List[int]]], Dict[str, List[Any]]]:
    """Flatten ``inputs`` into a pytree and pull out the entries matching each category.

    Args:
        inputs: Arbitrarily nested sample. A single positional argument is unwrapped
            before flattening; multiple arguments are flattened as a tuple.
        **types_or_checks: Mapping of category name to a tuple of types and/or
            predicates. Each flattened item is assigned to the first category whose
            check matches it (dispatched through ``_isinstance``); unmatched items
            stay in place.

    Returns:
        ``((flat_inputs, spec, idcs), extracted)`` where ``idcs`` records, per
        category, the flat indices of the extracted items — the inverse operation
        is ``unflatten_and_insert``.

    Raises:
        TypeError: If the categories do not all contain the same number of items.
    """
    flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

    idcs: Dict[str, List[int]] = {key: [] for key in types_or_checks}
    # Use a distinct name instead of shadowing the `inputs` parameter.
    extracted: Dict[str, List[Any]] = {key: [] for key in types_or_checks}
    for idx, inpt in enumerate(flat_inputs):
        for key, types_or_checks_ in types_or_checks.items():
            if _isinstance(inpt, types_or_checks_):
                extracted[key].append(inpt)
                idcs[key].append(idx)
                # First matching category wins; don't assign the item twice.
                break

    # Downstream code zips the categories together, so they must be equal-sized.
    counts = {key: len(values) for key, values in extracted.items()}
    if len(set(counts.values())) > 1:
        raise TypeError(
            f"Input sample must contain an equal number of items for each of "
            f"{sorted(types_or_checks)}, but got {counts}."
        )

    return (flat_inputs, spec, idcs), extracted


def unflatten_and_insert(
    flat_inputs_with_spec: Tuple[List[Any], TreeSpec, Dict[str, List[int]]],
    outputs: Dict[str, List[Any]],
) -> Any:
    """Write per-category outputs back into the flattened sample and rebuild its structure.

    Args:
        flat_inputs_with_spec: The ``(flat_inputs, spec, idcs)`` triple produced by
            ``flatten_and_extract``.
        outputs: Replacement items per category, in the same order the originals
            were extracted.

    Returns:
        The sample in its original nested structure with the replacements applied.
    """
    flat_inputs, spec, idcs = flat_inputs_with_spec

    for category, positions in idcs.items():
        replacements = outputs[category]
        for position, replacement in zip(positions, replacements):
            flat_inputs[position] = replacement

    return tree_unflatten(flat_inputs, spec)


class SimpleCopyPaste(_RandomApplyTransform):
def __init__(
self,
Expand All @@ -205,15 +239,23 @@ def __init__(

def _copy_paste(
self,
image: features.TensorImageType,
image: features.ImageType,
target: Dict[str, Any],
paste_image: features.TensorImageType,
paste_image: features.ImageType,
paste_target: Dict[str, Any],
random_selection: torch.Tensor,
blending: bool,
resize_interpolation: F.InterpolationMode,
antialias: Optional[bool],
) -> Tuple[features.TensorImageType, Dict[str, Any]]:
) -> Tuple[features.ImageType, Dict[str, Any]]:
if isinstance(image, features.Image):
out_image = image.as_subclass(torch.Tensor)
paste_image = paste_image.as_subclass(torch.Tensor)
elif isinstance(image, PIL.Image.Image):
out_image = F.pil_to_tensor(image)
paste_image = F.pil_to_tensor(paste_image)
else: # features.is_simple_tensor(image)
out_image = image

paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
Expand All @@ -227,7 +269,7 @@ def _copy_paste(
# This is something different to TF implementation we introduced here as
# originally the algorithm works on equal-sized data
# (for example, coming from LSJ data augmentations)
size1 = cast(List[int], image.shape[-2:])
size1 = cast(List[int], out_image.shape[-2:])
size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
Expand All @@ -241,7 +283,7 @@ def _copy_paste(

inverse_paste_alpha_mask = paste_alpha_mask.logical_not()
# Copy-paste images:
image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))
out_image = out_image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))

# Copy-paste masks:
masks = masks * inverse_paste_alpha_mask
Expand Down Expand Up @@ -281,69 +323,28 @@ def _copy_paste(
out_target["masks"] = out_target["masks"][valid_targets]
out_target["labels"] = out_target["labels"][valid_targets]

return image, out_target

def _extract_image_targets(
self, flat_sample: List[Any]
) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBox], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image_tensor(obj))
elif isinstance(obj, features.BoundingBox):
bboxes.append(obj)
elif isinstance(obj, features.Mask):
masks.append(obj)
elif isinstance(obj, (features.Label, features.OneHotLabel)):
labels.append(obj)

if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
"BoundingBoxes, Masks and Labels or OneHotLabels."
)
if isinstance(image, features.Image):
out_image = features.Image.wrap_like(image, out_image)
elif isinstance(image, PIL.Image.Image):
out_image = F.to_image_pil(out_image)

targets = []
for bbox, mask, label in zip(bboxes, masks, labels):
targets.append({"boxes": bbox, "masks": mask, "labels": label})
out_target["boxes"] = features.BoundingBox.wrap_like(target["boxes"], out_target["boxes"])
out_target["masks"] = features.Mask.wrap_like(target["masks"], out_target["masks"])
out_target["labels"] = features.Label.wrap_like(target["labels"], out_target["labels"])

return images, targets

def _insert_outputs(
self,
flat_sample: List[Any],
output_images: List[features.TensorImageType],
output_targets: List[Dict[str, Any]],
) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
if isinstance(obj, features.Image):
flat_sample[i] = features.Image.wrap_like(obj, output_images[c0])
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0])
c0 += 1
elif features.is_simple_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, features.BoundingBox):
flat_sample[i] = features.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"])
c1 += 1
elif isinstance(obj, features.Mask):
flat_sample[i] = features.Mask.wrap_like(obj, output_targets[c2]["masks"])
c2 += 1
elif isinstance(obj, (features.Label, features.OneHotLabel)):
flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type]
c3 += 1
return out_image, out_target

def forward(self, *inputs: Any) -> Any:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
flat_inputs_with_spec, inputs = flatten_and_extract(
inputs,
images=(features.Image, PIL.Image.Image, features.is_simple_tensor),
boxes=(features.BoundingBox,),
masks=(features.Mask,),
labels=(features.Label, features.OneHotLabel),
)

images, targets = self._extract_image_targets(flat_inputs)
images = inputs.pop("images")
targets = [dict(zip(inputs.keys(), target)) for target in zip(*inputs.values())]

# images = [t1, t2, ..., tN]
# Let's define paste_images as shifted list of input images
Expand Down Expand Up @@ -380,7 +381,8 @@ def forward(self, *inputs: Any) -> Any:
output_images.append(output_image)
output_targets.append(output_target)

# Insert updated images and targets into input flat_sample
self._insert_outputs(flat_inputs, output_images, output_targets)

return tree_unflatten(flat_inputs, spec)
outputs = dict(
dict(zip(output_targets[0].keys(), zip(*(list(target.values()) for target in output_targets)))),
images=images,
)
return unflatten_and_insert(flat_inputs_with_spec, outputs)
Loading