
Add SimpleCopyPaste augmentation #5825

Merged
merged 28 commits on Jun 15, 2022
Commits (28)
807f987
added simple POC
lezwon Apr 18, 2022
2fe16e8
added jitter and crop options
lezwon Apr 19, 2022
690f03f
added references
lezwon Apr 19, 2022
0055d83
moved simplecopypaste to detection module
lezwon Apr 21, 2022
5a6c263
working POC for simple copy paste in detection
lezwon Apr 22, 2022
7eefe7d
added comments
lezwon Apr 22, 2022
bdf20a0
remove transforms from class
lezwon May 4, 2022
f1ba6cf
removed loop for mask calculation
lezwon May 4, 2022
5b238cf
replaced Gaussian blur with functional api
lezwon May 10, 2022
7468480
added inplace operations
lezwon May 20, 2022
eb34465
added changes to accept tuples instead of tensors
lezwon May 20, 2022
7676203
- make copy paste functional
lezwon Jun 3, 2022
15bc8db
add inplace support within copy paste functional
lezwon Jun 3, 2022
d1f8361
Merge branch 'main' of github.com:pytorch/vision into transforms/simp…
vfdev-5 Jun 8, 2022
998034c
Merge branch 'main' of github.com:pytorch/vision into transforms/simp…
vfdev-5 Jun 10, 2022
c2a10a4
Updated code for copy-paste transform
vfdev-5 Jun 10, 2022
117c6da
Fixed code formatting
vfdev-5 Jun 10, 2022
21ac775
[skip ci] removed manual thresholding
vfdev-5 Jun 10, 2022
ad546fa
Replaced cropping by resizing data to paste
vfdev-5 Jun 10, 2022
064c44d
Removed inplace arg (as useless) and put a check on iscrowd target
vfdev-5 Jun 13, 2022
09b4db0
code-formatting
vfdev-5 Jun 13, 2022
43cae5f
Merge branch 'main' into transforms/simplecopypaste
vfdev-5 Jun 13, 2022
f1cc84b
Updated copypaste op to make it torch scriptable
vfdev-5 Jun 14, 2022
0eee0aa
Merge branch 'transforms/simplecopypaste' of github.com:lezwon/vision…
vfdev-5 Jun 14, 2022
a7b73b5
Merge branch 'main' into transforms/simplecopypaste
vfdev-5 Jun 14, 2022
956aa77
Fixed flake8
vfdev-5 Jun 14, 2022
7020eb8
Updates according to the review
vfdev-5 Jun 15, 2022
7f1de0d
Merge branch 'main' into transforms/simplecopypaste
datumbox Jun 15, 2022
23 changes: 22 additions & 1 deletion references/detection/train.py
@@ -31,6 +31,8 @@
from coco_utils import get_coco, get_coco_kp
from engine import train_one_epoch, evaluate
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from torchvision.transforms import InterpolationMode
from transforms import SimpleCopyPaste


def get_dataset(name, image_set, transform, data_path):
@@ -145,6 +147,13 @@ def get_args_parser(add_help=True):
    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # Use CopyPaste augmentation training parameter
    parser.add_argument(
        "--use-copypaste",
        action="store_true",
        help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
    )

    return parser


@@ -180,8 +189,20 @@ def main(args):
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)

    train_collate_fn = utils.collate_fn
    if args.use_copypaste:
        if args.data_augmentation != "lsj":
            raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")

        copypaste = SimpleCopyPaste(resize_interpolation=InterpolationMode.BILINEAR, blending=True)

        def copypaste_collate_fn(batch):
            return copypaste(*utils.collate_fn(batch))

        train_collate_fn = copypaste_collate_fn

    data_loader = torch.utils.data.DataLoader(
-        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
+        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn
    )

    data_loader_test = torch.utils.data.DataLoader(
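With these changes the augmentation is switched on from the command line by combining the new flag with the LSJ policy. An illustrative launch command (the dataset, model, and epoch settings below are placeholders, not prescribed by this PR):

torchrun --nproc_per_node=8 train.py --dataset coco --model maskrcnn_resnet50_fpn --epochs 26 --data-augmentation lsj --use-copypaste
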
155 changes: 155 additions & 0 deletions references/detection/transforms.py
@@ -3,6 +3,7 @@
import torch
import torchvision
from torch import nn, Tensor
from torchvision import ops
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T, InterpolationMode

@@ -437,3 +438,157 @@ def forward(
                )

        return image, target


def _copy_paste(
    image: torch.Tensor,
    target: Dict[str, Tensor],
    paste_image: torch.Tensor,
    paste_target: Dict[str, Tensor],
    blending: bool = True,
    resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[torch.Tensor, Dict[str, Tensor]]:

    # Random paste targets selection:
    num_masks = len(paste_target["masks"])

    if num_masks < 1:
        # Such a degenerate case with num_masks=0 can happen with LSJ
        # Let's just return (image, target)
        return image, target

    # We have to please torch script by explicitly specifying dtype as torch.long
    random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
    random_selection = torch.unique(random_selection).to(torch.long)

    paste_masks = paste_target["masks"][random_selection]
    paste_boxes = paste_target["boxes"][random_selection]
    paste_labels = paste_target["labels"][random_selection]

    masks = target["masks"]

    # We resize source and paste data if they have different sizes
    # This is something we introduced here as originally the algorithm works
    # on equal-sized data (for example, coming from LSJ data augmentations)
    size1 = image.shape[-2:]
    size2 = paste_image.shape[-2:]
    if size1 != size2:
        paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
        paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
        # resize bboxes:
        ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
        paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)

    paste_alpha_mask = paste_masks.sum(dim=0) > 0

    if blending:
        paste_alpha_mask = F.gaussian_blur(
            paste_alpha_mask.unsqueeze(0),
            kernel_size=(5, 5),
            sigma=[
                2.0,
            ],
        )

    # Copy-paste images:
    image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)

    # Copy-paste masks:
    masks = masks * (~paste_alpha_mask)
    non_all_zero_masks = masks.sum((-1, -2)) > 0
    masks = masks[non_all_zero_masks]

    # Do a shallow copy of the target dict
    out_target = {k: v for k, v in target.items()}

    out_target["masks"] = torch.cat([masks, paste_masks])

    # Copy-paste boxes and labels
    boxes = ops.masks_to_boxes(masks)
    out_target["boxes"] = torch.cat([boxes, paste_boxes])

    labels = target["labels"][non_all_zero_masks]
    out_target["labels"] = torch.cat([labels, paste_labels])

    # Update additional optional keys: area and iscrowd if exist
    if "area" in target:
        out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)

    if "iscrowd" in target and "iscrowd" in paste_target:
        # target['iscrowd'] size can differ from mask size (non_all_zero_masks)
        # For example, if a previous transform geometrically modifies masks/boxes/labels but
        # does not update "iscrowd"
        if len(target["iscrowd"]) == len(non_all_zero_masks):
            iscrowd = target["iscrowd"][non_all_zero_masks]
            paste_iscrowd = paste_target["iscrowd"][random_selection]
            out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])

    # Check for degenerated boxes and remove them
    boxes = out_target["boxes"]
    degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
    if degenerate_boxes.any():
        valid_targets = ~degenerate_boxes.any(dim=1)

        out_target["boxes"] = boxes[valid_targets]
        out_target["masks"] = out_target["masks"][valid_targets]
        out_target["labels"] = out_target["labels"][valid_targets]

        if "area" in out_target:
            out_target["area"] = out_target["area"][valid_targets]
        if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
            out_target["iscrowd"] = out_target["iscrowd"][valid_targets]

    return image, out_target


class SimpleCopyPaste(torch.nn.Module):
    def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.resize_interpolation = resize_interpolation
        self.blending = blending

    def forward(
        self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
        torch._assert(
            isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
            "images should be a list of tensors",
        )
        torch._assert(
            isinstance(targets, (list, tuple)) and len(images) == len(targets),
            "targets should be a list of the same size as images",
        )
        for target in targets:
            # Cannot check for an instance of type dict inside torch.jit.script
            # torch._assert(isinstance(target, dict), "targets item should be a dict")
            for k in ["masks", "boxes", "labels"]:
                torch._assert(k in target, f"Key {k} should be present in targets")
                torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")

        # images = [t1, t2, ..., tN]
        # Let's define paste_images as a shifted list of input images
        # paste_images = [t2, t3, ..., tN, t1]
        # FYI: in TF they mix data on the dataset level
        images_rolled = images[-1:] + images[:-1]
        targets_rolled = targets[-1:] + targets[:-1]

        output_images: List[torch.Tensor] = []
        output_targets: List[Dict[str, Tensor]] = []

        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
            output_image, output_data = _copy_paste(
                image,
                target,
                paste_image,
                paste_target,
                blending=self.blending,
                resize_interpolation=self.resize_interpolation,
            )
            output_images.append(output_image)
            output_targets.append(output_data)

        return output_images, output_targets

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
        return s
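
For readers who want to try the transform outside the reference training script, below is a minimal, self-contained sketch of how SimpleCopyPaste is intended to be called. The synthetic images, masks, boxes, and labels are made up purely for illustration, and the import assumes the working directory is references/detection:

import torch
from torchvision.transforms import InterpolationMode
from transforms import SimpleCopyPaste  # the class added by this PR


def make_sample(i):
    # One synthetic detection sample: a 3x96x96 image with a single 30x30 square instance.
    image = torch.rand(3, 96, 96)
    mask = torch.zeros(1, 96, 96, dtype=torch.uint8)
    x0, y0 = 10 + 20 * i, 15 + 10 * i
    mask[0, y0 : y0 + 30, x0 : x0 + 30] = 1
    target = {
        "masks": mask,
        "boxes": torch.tensor([[x0, y0, x0 + 30, y0 + 30]], dtype=torch.float32),
        "labels": torch.tensor([i + 1], dtype=torch.int64),
    }
    return image, target


images, targets = zip(*[make_sample(i) for i in range(2)])
copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)

# The transform works on the whole batch: each image receives the instances
# pasted from the next image in the rolled batch, and its target grows accordingly.
out_images, out_targets = copypaste(list(images), list(targets))
for t in out_targets:
    print(t["boxes"].shape, t["labels"])

In the training script this is exactly what copypaste_collate_fn does, after utils.collate_fn has grouped the batch into (images, targets) tuples.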