From c19f838a9d855fe3eadb6b36aef2387343c65466 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 17 Aug 2022 10:32:26 +0200 Subject: [PATCH 01/49] use prototype transforms in classification reference --- references/classification/presets.py | 14 +- references/classification/train.py | 31 +++- references/classification/transforms.py | 183 ------------------- torchvision/prototype/transforms/__init__.py | 2 +- 4 files changed, 30 insertions(+), 200 deletions(-) delete mode 100644 references/classification/transforms.py diff --git a/references/classification/presets.py b/references/classification/presets.py index 6bc38e72953..3a495763f9e 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -1,5 +1,5 @@ import torch -from torchvision.transforms import autoaugment, transforms +from torchvision.prototype import transforms from torchvision.transforms.functional import InterpolationMode @@ -17,17 +17,17 @@ def __init__( ): trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] if hflip_prob > 0: - trans.append(transforms.RandomHorizontalFlip(hflip_prob)) + trans.append(transforms.RandomHorizontalFlip(p=hflip_prob)) if auto_augment_policy is not None: if auto_augment_policy == "ra": - trans.append(autoaugment.RandAugment(interpolation=interpolation)) + trans.append(transforms.RandAugment(interpolation=interpolation)) elif auto_augment_policy == "ta_wide": - trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) + trans.append(transforms.TrivialAugmentWide(interpolation=interpolation)) elif auto_augment_policy == "augmix": - trans.append(autoaugment.AugMix(interpolation=interpolation)) + trans.append(transforms.AugMix(interpolation=interpolation)) else: - aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) - trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) + aa_policy = transforms.AutoAugmentPolicy(auto_augment_policy) + trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.extend( [ transforms.PILToTensor(), diff --git a/references/classification/train.py b/references/classification/train.py index 14360b042ed..ad3db9bafa0 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -7,14 +7,19 @@ import torch import torch.utils.data import torchvision -import transforms import utils from sampler import RASampler from torch import nn from torch.utils.data.dataloader import default_collate +from torchvision.prototype import features, transforms from torchvision.transforms.functional import InterpolationMode +class WrapTarget(nn.Module): + def forward(self, input, target): + return input, features.Label(target) + + def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") @@ -128,12 +133,13 @@ def load_data(traindir, valdir, args): random_erase_prob = getattr(args, "random_erase", 0.0) dataset = torchvision.datasets.ImageFolder( traindir, - presets.ClassificationPresetTrain( + transform=presets.ClassificationPresetTrain( crop_size=train_crop_size, interpolation=interpolation, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob, ), + target_transform=lambda target: features.Label(target), ) if args.cache_dataset: print(f"Saving dataset_train to {cache_path}") @@ -158,7 +164,8 @@ def load_data(traindir, valdir, args): dataset_test = 
torchvision.datasets.ImageFolder( valdir, - preprocessing, + transform=preprocessing, + target_transform=lambda target: features.Label(target), ) if args.cache_dataset: print(f"Saving dataset_test to {cache_path}") @@ -200,14 +207,20 @@ def main(args): collate_fn = None num_classes = len(dataset.classes) - mixup_transforms = [] + mixup_or_cutmix = [] if args.mixup_alpha > 0.0: - mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) + mixup_or_cutmix.append(transforms.RandomMixup(alpha=args.mixup_alpha)) if args.cutmix_alpha > 0.0: - mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) - if mixup_transforms: - mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) - collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731 + mixup_or_cutmix.append(transforms.RandomCutmix(alpha=args.cutmix_alpha)) + if mixup_or_cutmix: + batch_transform = transforms.Compose( + [ + WrapTarget(), + transforms.LabelToOneHot(num_categories=num_classes), + transforms.RandomChoice(*mixup_or_cutmix), + ] + ) + collate_fn = lambda batch: batch_transform(default_collate(batch)) # noqa: E731 data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, diff --git a/references/classification/transforms.py b/references/classification/transforms.py deleted file mode 100644 index 9a8ef7877d6..00000000000 --- a/references/classification/transforms.py +++ /dev/null @@ -1,183 +0,0 @@ -import math -from typing import Tuple - -import torch -from torch import Tensor -from torchvision.transforms import functional as F - - -class RandomMixup(torch.nn.Module): - """Randomly apply Mixup to the provided batch and targets. - The class implements the data augmentations as described in the paper - `"mixup: Beyond Empirical Risk Minimization" `_. - - Args: - num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 0.5. - alpha (float): hyperparameter of the Beta distribution used for mixup. - Default value is 1.0. - inplace (bool): boolean to make this transform inplace. Default set to False. - """ - - def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: - super().__init__() - - if num_classes < 1: - raise ValueError( - f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}" - ) - - if alpha <= 0: - raise ValueError("Alpha param can't be zero.") - - self.num_classes = num_classes - self.p = p - self.alpha = alpha - self.inplace = inplace - - def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """ - Args: - batch (Tensor): Float tensor of size (B, C, H, W) - target (Tensor): Integer tensor of size (B, ) - - Returns: - Tensor: Randomly transformed batch. - """ - if batch.ndim != 4: - raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") - if target.ndim != 1: - raise ValueError(f"Target ndim should be 1. Got {target.ndim}") - if not batch.is_floating_point(): - raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") - if target.dtype != torch.int64: - raise TypeError(f"Target dtype should be torch.int64. 
Got {target.dtype}") - - if not self.inplace: - batch = batch.clone() - target = target.clone() - - if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) - - if torch.rand(1).item() >= self.p: - return batch, target - - # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1, 0) - - # Implemented as on mixup paper, page 3. - lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) - batch_rolled.mul_(1.0 - lambda_param) - batch.mul_(lambda_param).add_(batch_rolled) - - target_rolled.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_rolled) - - return batch, target - - def __repr__(self) -> str: - s = ( - f"{self.__class__.__name__}(" - f"num_classes={self.num_classes}" - f", p={self.p}" - f", alpha={self.alpha}" - f", inplace={self.inplace}" - f")" - ) - return s - - -class RandomCutmix(torch.nn.Module): - """Randomly apply Cutmix to the provided batch and targets. - The class implements the data augmentations as described in the paper - `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" - `_. - - Args: - num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 0.5. - alpha (float): hyperparameter of the Beta distribution used for cutmix. - Default value is 1.0. - inplace (bool): boolean to make this transform inplace. Default set to False. - """ - - def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: - super().__init__() - if num_classes < 1: - raise ValueError("Please provide a valid positive value for the num_classes.") - if alpha <= 0: - raise ValueError("Alpha param can't be zero.") - - self.num_classes = num_classes - self.p = p - self.alpha = alpha - self.inplace = inplace - - def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """ - Args: - batch (Tensor): Float tensor of size (B, C, H, W) - target (Tensor): Integer tensor of size (B, ) - - Returns: - Tensor: Randomly transformed batch. - """ - if batch.ndim != 4: - raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") - if target.ndim != 1: - raise ValueError(f"Target ndim should be 1. Got {target.ndim}") - if not batch.is_floating_point(): - raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") - if target.dtype != torch.int64: - raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") - - if not self.inplace: - batch = batch.clone() - target = target.clone() - - if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) - - if torch.rand(1).item() >= self.p: - return batch, target - - # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1, 0) - - # Implemented as on cutmix paper, page 12 (with minor corrections on typos). 
- lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) - _, H, W = F.get_dimensions(batch) - - r_x = torch.randint(W, (1,)) - r_y = torch.randint(H, (1,)) - - r = 0.5 * math.sqrt(1.0 - lambda_param) - r_w_half = int(r * W) - r_h_half = int(r * H) - - x1 = int(torch.clamp(r_x - r_w_half, min=0)) - y1 = int(torch.clamp(r_y - r_h_half, min=0)) - x2 = int(torch.clamp(r_x + r_w_half, max=W)) - y2 = int(torch.clamp(r_y + r_h_half, max=H)) - - batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] - lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) - - target_rolled.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_rolled) - - return batch, target - - def __repr__(self) -> str: - s = ( - f"{self.__class__.__name__}(" - f"num_classes={self.num_classes}" - f", p={self.p}" - f", alpha={self.alpha}" - f", inplace={self.inplace}" - f")" - ) - return s diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index dc6476ab4b5..ca89fee918a 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -3,7 +3,7 @@ from ._transform import Transform # usort: skip from ._augment import RandomCutmix, RandomErasing, RandomMixup -from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide +from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide from ._color import ( ColorJitter, RandomAdjustSharpness, From 7b7602e670d5488298bde024df7f97c6ea99d859 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 17 Aug 2022 10:50:43 +0200 Subject: [PATCH 02/49] cleanup --- references/classification/presets.py | 4 ++-- references/classification/train.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 3a495763f9e..816f033479d 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -30,7 +30,7 @@ def __init__( trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.extend( [ - transforms.PILToTensor(), + transforms.ToImageTensor(), transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] @@ -59,7 +59,7 @@ def __init__( [ transforms.Resize(resize_size, interpolation=interpolation), transforms.CenterCrop(crop_size), - transforms.PILToTensor(), + transforms.ToImageTensor(), transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] diff --git a/references/classification/train.py b/references/classification/train.py index ad3db9bafa0..c0d6a91364b 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -15,9 +15,10 @@ from torchvision.transforms.functional import InterpolationMode -class WrapTarget(nn.Module): - def forward(self, input, target): - return input, features.Label(target) +class WrapIntoFeatures(nn.Module): + def forward(self, sample): + input, target = sample + return features.Image(input), features.Label(target) def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): @@ -215,8 +216,9 @@ def main(args): if mixup_or_cutmix: batch_transform = transforms.Compose( [ - WrapTarget(), + WrapIntoFeatures(), transforms.LabelToOneHot(num_categories=num_classes), + transforms.ToDtype(torch.float, features.OneHotLabel), transforms.RandomChoice(*mixup_or_cutmix), ] ) From 
4990b894a32261165d594f8fd927bca2ada714ee Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Aug 2022 09:57:28 +0200 Subject: [PATCH 03/49] move WrapIntoFeatures into transforms module --- references/classification/train.py | 13 +++++-------- references/classification/transforms.py | 8 ++++++++ 2 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 references/classification/transforms.py diff --git a/references/classification/train.py b/references/classification/train.py index c0d6a91364b..69e38b7d8a8 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -4,23 +4,20 @@ import warnings import presets +from sampler import RASampler +from transforms import WrapIntoFeatures +import utils # usort: skip + import torch import torch.utils.data import torchvision -import utils -from sampler import RASampler + from torch import nn from torch.utils.data.dataloader import default_collate from torchvision.prototype import features, transforms from torchvision.transforms.functional import InterpolationMode -class WrapIntoFeatures(nn.Module): - def forward(self, sample): - input, target = sample - return features.Image(input), features.Label(target) - - def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") diff --git a/references/classification/transforms.py b/references/classification/transforms.py new file mode 100644 index 00000000000..2438bc45730 --- /dev/null +++ b/references/classification/transforms.py @@ -0,0 +1,8 @@ +from torch import nn +from torchvision.prototype import features + + +class WrapIntoFeatures(nn.Module): + def forward(self, sample): + input, target = sample + return features.Image(input), features.Label(target) From ca4c5a7dd8e23d2bc6fac42f3c4ff7b397255a52 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Aug 2022 15:49:29 +0200 Subject: [PATCH 04/49] [skip ci] add p=1.0 to CutMix and MixUp --- references/classification/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 69e38b7d8a8..8745c9222d0 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -207,9 +207,9 @@ def main(args): num_classes = len(dataset.classes) mixup_or_cutmix = [] if args.mixup_alpha > 0.0: - mixup_or_cutmix.append(transforms.RandomMixup(alpha=args.mixup_alpha)) + mixup_or_cutmix.append(transforms.RandomMixup(alpha=args.mixup_alpha, p=1.0)) if args.cutmix_alpha > 0.0: - mixup_or_cutmix.append(transforms.RandomCutmix(alpha=args.cutmix_alpha)) + mixup_or_cutmix.append(transforms.RandomCutmix(alpha=args.cutmix_alpha, p=1.0)) if mixup_or_cutmix: batch_transform = transforms.Compose( [ From 693795e31f1b8a554408c31285bca4a74c79d0d0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 23 Aug 2022 16:06:27 +0200 Subject: [PATCH 05/49] [skip ci] From fe96a5439e4cd4f7327f44a1f5c62c9c9c2ba951 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 24 Aug 2022 14:13:53 +0200 Subject: [PATCH 06/49] use prototype transforms in detection references --- references/detection/coco_utils.py | 54 +-- references/detection/presets.py | 92 ++--- references/detection/train.py | 4 +- references/detection/transforms.py | 593 ----------------------------- 4 files changed, 69 insertions(+), 674 deletions(-) delete mode 100644 references/detection/transforms.py diff --git a/references/detection/coco_utils.py 
b/references/detection/coco_utils.py index 396de63297b..ee0a5fdd99a 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -1,30 +1,12 @@ -import copy import os import torch import torch.utils.data import torchvision -import transforms as T + from pycocotools import mask as coco_mask from pycocotools.coco import COCO - - -class FilterAndRemapCocoCategories: - def __init__(self, categories, remap=True): - self.categories = categories - self.remap = remap - - def __call__(self, image, target): - anno = target["annotations"] - anno = [obj for obj in anno if obj["category_id"] in self.categories] - if not self.remap: - target["annotations"] = anno - return image, target - anno = copy.deepcopy(anno) - for obj in anno: - obj["category_id"] = self.categories.index(obj["category_id"]) - target["annotations"] = anno - return image, target +from torchvision.prototype import features, transforms as T def convert_coco_poly_to_mask(segmentations, height, width): @@ -45,7 +27,8 @@ def convert_coco_poly_to_mask(segmentations, height, width): class ConvertCocoPolysToMask: - def __call__(self, image, target): + def __call__(self, sample): + image, target = sample w, h = image.size image_id = target["image_id"] @@ -100,6 +83,27 @@ def __call__(self, image, target): return image, target +class WrapIntoFeatures: + def __call__(self, sample): + image, target = sample + + wrapped_target = dict( + boxes=features.BoundingBox( + target["boxes"], + format=features.BoundingBoxFormat.XYXY, + image_size=(image.height, image.width), + ), + # TODO: add categories + labels=features.Label(target["labels"], categories=None), + masks=features.SegmentationMask(target["masks"]), + image_id=int(target["image_id"]), + area=target["area"].tolist(), + iscrowd=target["iscrowd"].to(torch.bool).tolist(), + ) + + return image, wrapped_target + + def _coco_remove_images_without_annotations(dataset, cat_list=None): def _has_only_empty_bbox(anno): return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) @@ -225,10 +229,12 @@ def get_coco(root, image_set, transforms, mode="instances"): PATHS = { "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))), - # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))) } - t = [ConvertCocoPolysToMask()] + t = [ + ConvertCocoPolysToMask(), + WrapIntoFeatures(), + ] if transforms is not None: t.append(transforms) @@ -243,8 +249,6 @@ def get_coco(root, image_set, transforms, mode="instances"): if image_set == "train": dataset = _coco_remove_images_without_annotations(dataset) - # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) - return dataset diff --git a/references/detection/presets.py b/references/detection/presets.py index 779f3f218ca..af4dcf72bfe 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -1,73 +1,57 @@ import torch -import transforms as T +from torchvision.prototype import transforms as T -class DetectionPresetTrain: +class DetectionPresetTrain(T.Compose): def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)): if data_augmentation == "hflip": - self.transforms = T.Compose( - [ - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms = [ + T.RandomHorizontalFlip(p=hflip_prob), + T.ToImageTensor(), + T.ConvertImageDtype(torch.float), + ] elif 
data_augmentation == "lsj": - self.transforms = T.Compose( - [ - T.ScaleJitter(target_size=(1024, 1024)), - T.FixedSizeCrop(size=(1024, 1024), fill=mean), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms = [ + T.ScaleJitter(target_size=(1024, 1024)), + T.FixedSizeCrop(size=(1024, 1024), fill=mean), + T.RandomHorizontalFlip(p=hflip_prob), + T.ToImageTensor(), + T.ConvertImageDtype(torch.float), + ] elif data_augmentation == "multiscale": - self.transforms = T.Compose( - [ - T.RandomShortestSize( - min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 - ), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms = [ + T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333), + T.RandomHorizontalFlip(p=hflip_prob), + T.ToImageTensor(), + T.ConvertImageDtype(torch.float), + ] elif data_augmentation == "ssd": - self.transforms = T.Compose( - [ - T.RandomPhotometricDistort(), - T.RandomZoomOut(fill=list(mean)), - T.RandomIoUCrop(), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms = [ + T.RandomPhotometricDistort(), + T.RandomZoomOut(fill=list(mean)), + T.RandomIoUCrop(), + T.RandomHorizontalFlip(p=hflip_prob), + T.ToImageTensor(), + T.ConvertImageDtype(torch.float), + ] elif data_augmentation == "ssdlite": - self.transforms = T.Compose( - [ - T.RandomIoUCrop(), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms = [ + T.RandomIoUCrop(), + T.RandomHorizontalFlip(p=hflip_prob), + T.ToImageTensor(), + T.ConvertImageDtype(torch.float), + ] else: raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"') - def __call__(self, img, target): - return self.transforms(img, target) + super().__init__(transforms) -class DetectionPresetEval: +class DetectionPresetEval(T.Compose): def __init__(self): - self.transforms = T.Compose( + super().__init__( [ - T.PILToTensor(), + T.ToImageTensor(), T.ConvertImageDtype(torch.float), ] ) - - def __call__(self, img, target): - return self.transforms(img, target) diff --git a/references/detection/train.py b/references/detection/train.py index dea483c5f75..0662adcd173 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -31,12 +31,12 @@ from coco_utils import get_coco, get_coco_kp from engine import evaluate, train_one_epoch from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler +from torchvision.prototype import transforms as T from torchvision.transforms import InterpolationMode -from transforms import SimpleCopyPaste def copypaste_collate_fn(batch): - copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR) + copypaste = T.SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR) return copypaste(*utils.collate_fn(batch)) diff --git a/references/detection/transforms.py b/references/detection/transforms.py deleted file mode 100644 index d26bf6eac85..00000000000 --- a/references/detection/transforms.py +++ /dev/null @@ -1,593 +0,0 @@ -from typing import Dict, List, Optional, Tuple, Union - -import torch -import torchvision -from torch import nn, Tensor -from torchvision import ops -from torchvision.transforms import functional as F, InterpolationMode, transforms as T - - -def _flip_coco_person_keypoints(kps, width): - 
flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] - flipped_data = kps[:, flip_inds] - flipped_data[..., 0] = width - flipped_data[..., 0] - # Maintain COCO convention that if visibility == 0, then x, y = 0 - inds = flipped_data[..., 2] == 0 - flipped_data[inds] = 0 - return flipped_data - - -class Compose: - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, image, target): - for t in self.transforms: - image, target = t(image, target) - return image, target - - -class RandomHorizontalFlip(T.RandomHorizontalFlip): - def forward( - self, image: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: - if torch.rand(1) < self.p: - image = F.hflip(image) - if target is not None: - _, _, width = F.get_dimensions(image) - target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]] - if "masks" in target: - target["masks"] = target["masks"].flip(-1) - if "keypoints" in target: - keypoints = target["keypoints"] - keypoints = _flip_coco_person_keypoints(keypoints, width) - target["keypoints"] = keypoints - return image, target - - -class PILToTensor(nn.Module): - def forward( - self, image: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: - image = F.pil_to_tensor(image) - return image, target - - -class ConvertImageDtype(nn.Module): - def __init__(self, dtype: torch.dtype) -> None: - super().__init__() - self.dtype = dtype - - def forward( - self, image: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: - image = F.convert_image_dtype(image, self.dtype) - return image, target - - -class RandomIoUCrop(nn.Module): - def __init__( - self, - min_scale: float = 0.3, - max_scale: float = 1.0, - min_aspect_ratio: float = 0.5, - max_aspect_ratio: float = 2.0, - sampler_options: Optional[List[float]] = None, - trials: int = 40, - ): - super().__init__() - # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 - self.min_scale = min_scale - self.max_scale = max_scale - self.min_aspect_ratio = min_aspect_ratio - self.max_aspect_ratio = max_aspect_ratio - if sampler_options is None: - sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0] - self.options = sampler_options - self.trials = trials - - def forward( - self, image: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: - if target is None: - raise ValueError("The targets can't be None for this transform.") - - if isinstance(image, torch.Tensor): - if image.ndimension() not in {2, 3}: - raise ValueError(f"image should be 2/3 dimensional. 
Got {image.ndimension()} dimensions.") - elif image.ndimension() == 2: - image = image.unsqueeze(0) - - _, orig_h, orig_w = F.get_dimensions(image) - - while True: - # sample an option - idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) - min_jaccard_overlap = self.options[idx] - if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option - return image, target - - for _ in range(self.trials): - # check the aspect ratio limitations - r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) - new_w = int(orig_w * r[0]) - new_h = int(orig_h * r[1]) - aspect_ratio = new_w / new_h - if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio): - continue - - # check for 0 area crops - r = torch.rand(2) - left = int((orig_w - new_w) * r[0]) - top = int((orig_h - new_h) * r[1]) - right = left + new_w - bottom = top + new_h - if left == right or top == bottom: - continue - - # check for any valid boxes with centers within the crop area - cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2]) - cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3]) - is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom) - if not is_within_crop_area.any(): - continue - - # check at least 1 box with jaccard limitations - boxes = target["boxes"][is_within_crop_area] - ious = torchvision.ops.boxes.box_iou( - boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device) - ) - if ious.max() < min_jaccard_overlap: - continue - - # keep only valid boxes and perform cropping - target["boxes"] = boxes - target["labels"] = target["labels"][is_within_crop_area] - target["boxes"][:, 0::2] -= left - target["boxes"][:, 1::2] -= top - target["boxes"][:, 0::2].clamp_(min=0, max=new_w) - target["boxes"][:, 1::2].clamp_(min=0, max=new_h) - image = F.crop(image, top, left, new_h, new_w) - - return image, target - - -class RandomZoomOut(nn.Module): - def __init__( - self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 - ): - super().__init__() - if fill is None: - fill = [0.0, 0.0, 0.0] - self.fill = fill - self.side_range = side_range - if side_range[0] < 1.0 or side_range[0] > side_range[1]: - raise ValueError(f"Invalid canvas side range provided {side_range}.") - self.p = p - - @torch.jit.unused - def _get_fill_value(self, is_pil): - # type: (bool) -> int - # We fake the type to make it work on JIT - return tuple(int(x) for x in self.fill) if is_pil else 0 - - def forward( - self, image: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: - if isinstance(image, torch.Tensor): - if image.ndimension() not in {2, 3}: - raise ValueError(f"image should be 2/3 dimensional. 
Got {image.ndimension()} dimensions.") - elif image.ndimension() == 2: - image = image.unsqueeze(0) - - if torch.rand(1) >= self.p: - return image, target - - _, orig_h, orig_w = F.get_dimensions(image) - - r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) - canvas_width = int(orig_w * r) - canvas_height = int(orig_h * r) - - r = torch.rand(2) - left = int((canvas_width - orig_w) * r[0]) - top = int((canvas_height - orig_h) * r[1]) - right = canvas_width - (left + orig_w) - bottom = canvas_height - (top + orig_h) - - if torch.jit.is_scripting(): - fill = 0 - else: - fill = self._get_fill_value(F._is_pil_image(image)) - - image = F.pad(image, [left, top, right, bottom], fill=fill) - if isinstance(image, torch.Tensor): - # PyTorch's pad supports only integers on fill. So we need to overwrite the colour - v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1) - image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[ - ..., :, (left + orig_w) : - ] = v - - if target is not None: - target["boxes"][:, 0::2] += left - target["boxes"][:, 1::2] += top - - return image, target - - -class RandomPhotometricDistort(nn.Module): - def __init__( - self, - contrast: Tuple[float, float] = (0.5, 1.5), - saturation: Tuple[float, float] = (0.5, 1.5), - hue: Tuple[float, float] = (-0.05, 0.05), - brightness: Tuple[float, float] = (0.875, 1.125), - p: float = 0.5, - ): - super().__init__() - self._brightness = T.ColorJitter(brightness=brightness) - self._contrast = T.ColorJitter(contrast=contrast) - self._hue = T.ColorJitter(hue=hue) - self._saturation = T.ColorJitter(saturation=saturation) - self.p = p - - def forward( - self, image: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: - if isinstance(image, torch.Tensor): - if image.ndimension() not in {2, 3}: - raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.") - elif image.ndimension() == 2: - image = image.unsqueeze(0) - - r = torch.rand(7) - - if r[0] < self.p: - image = self._brightness(image) - - contrast_before = r[1] < 0.5 - if contrast_before: - if r[2] < self.p: - image = self._contrast(image) - - if r[3] < self.p: - image = self._saturation(image) - - if r[4] < self.p: - image = self._hue(image) - - if not contrast_before: - if r[5] < self.p: - image = self._contrast(image) - - if r[6] < self.p: - channels, _, _ = F.get_dimensions(image) - permutation = torch.randperm(channels) - - is_pil = F._is_pil_image(image) - if is_pil: - image = F.pil_to_tensor(image) - image = F.convert_image_dtype(image) - image = image[..., permutation, :, :] - if is_pil: - image = F.to_pil_image(image) - - return image, target - - -class ScaleJitter(nn.Module): - """Randomly resizes the image and its bounding boxes within the specified scale range. - The class implements the Scale Jitter augmentation as described in the paper - `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" `_. - - Args: - target_size (tuple of ints): The target size for the transform provided in (height, weight) format. - scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the - range a <= scale <= b. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. 
- """ - - def __init__( - self, - target_size: Tuple[int, int], - scale_range: Tuple[float, float] = (0.1, 2.0), - interpolation: InterpolationMode = InterpolationMode.BILINEAR, - ): - super().__init__() - self.target_size = target_size - self.scale_range = scale_range - self.interpolation = interpolation - - def forward( - self, image: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: - if isinstance(image, torch.Tensor): - if image.ndimension() not in {2, 3}: - raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.") - elif image.ndimension() == 2: - image = image.unsqueeze(0) - - _, orig_height, orig_width = F.get_dimensions(image) - - scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) - r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale - new_width = int(orig_width * r) - new_height = int(orig_height * r) - - image = F.resize(image, [new_height, new_width], interpolation=self.interpolation) - - if target is not None: - target["boxes"][:, 0::2] *= new_width / orig_width - target["boxes"][:, 1::2] *= new_height / orig_height - if "masks" in target: - target["masks"] = F.resize( - target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST - ) - - return image, target - - -class FixedSizeCrop(nn.Module): - def __init__(self, size, fill=0, padding_mode="constant"): - super().__init__() - size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) - self.crop_height = size[0] - self.crop_width = size[1] - self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch. - self.padding_mode = padding_mode - - def _pad(self, img, target, padding): - # Taken from the functional_tensor.py pad - if isinstance(padding, int): - pad_left = pad_right = pad_top = pad_bottom = padding - elif len(padding) == 1: - pad_left = pad_right = pad_top = pad_bottom = padding[0] - elif len(padding) == 2: - pad_left = pad_right = padding[0] - pad_top = pad_bottom = padding[1] - else: - pad_left = padding[0] - pad_top = padding[1] - pad_right = padding[2] - pad_bottom = padding[3] - - padding = [pad_left, pad_top, pad_right, pad_bottom] - img = F.pad(img, padding, self.fill, self.padding_mode) - if target is not None: - target["boxes"][:, 0::2] += pad_left - target["boxes"][:, 1::2] += pad_top - if "masks" in target: - target["masks"] = F.pad(target["masks"], padding, 0, "constant") - - return img, target - - def _crop(self, img, target, top, left, height, width): - img = F.crop(img, top, left, height, width) - if target is not None: - boxes = target["boxes"] - boxes[:, 0::2] -= left - boxes[:, 1::2] -= top - boxes[:, 0::2].clamp_(min=0, max=width) - boxes[:, 1::2].clamp_(min=0, max=height) - - is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3]) - - target["boxes"] = boxes[is_valid] - target["labels"] = target["labels"][is_valid] - if "masks" in target: - target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width) - - return img, target - - def forward(self, img, target=None): - _, height, width = F.get_dimensions(img) - new_height = min(height, self.crop_height) - new_width = min(width, self.crop_width) - - if new_height != height or new_width != width: - offset_height = max(height - self.crop_height, 0) - offset_width = max(width - self.crop_width, 0) - - r = torch.rand(1) - top = int(offset_height * r) - left = int(offset_width * r) - - img, 
target = self._crop(img, target, top, left, new_height, new_width) - - pad_bottom = max(self.crop_height - new_height, 0) - pad_right = max(self.crop_width - new_width, 0) - if pad_bottom != 0 or pad_right != 0: - img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom]) - - return img, target - - -class RandomShortestSize(nn.Module): - def __init__( - self, - min_size: Union[List[int], Tuple[int], int], - max_size: int, - interpolation: InterpolationMode = InterpolationMode.BILINEAR, - ): - super().__init__() - self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) - self.max_size = max_size - self.interpolation = interpolation - - def forward( - self, image: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: - _, orig_height, orig_width = F.get_dimensions(image) - - min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()] - r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width)) - - new_width = int(orig_width * r) - new_height = int(orig_height * r) - - image = F.resize(image, [new_height, new_width], interpolation=self.interpolation) - - if target is not None: - target["boxes"][:, 0::2] *= new_width / orig_width - target["boxes"][:, 1::2] *= new_height / orig_height - if "masks" in target: - target["masks"] = F.resize( - target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST - ) - - return image, target - - -def _copy_paste( - image: torch.Tensor, - target: Dict[str, Tensor], - paste_image: torch.Tensor, - paste_target: Dict[str, Tensor], - blending: bool = True, - resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR, -) -> Tuple[torch.Tensor, Dict[str, Tensor]]: - - # Random paste targets selection: - num_masks = len(paste_target["masks"]) - - if num_masks < 1: - # Such degerante case with num_masks=0 can happen with LSJ - # Let's just return (image, target) - return image, target - - # We have to please torch script by explicitly specifying dtype as torch.long - random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device) - random_selection = torch.unique(random_selection).to(torch.long) - - paste_masks = paste_target["masks"][random_selection] - paste_boxes = paste_target["boxes"][random_selection] - paste_labels = paste_target["labels"][random_selection] - - masks = target["masks"] - - # We resize source and paste data if they have different sizes - # This is something we introduced here as originally the algorithm works - # on equal-sized data (for example, coming from LSJ data augmentations) - size1 = image.shape[-2:] - size2 = paste_image.shape[-2:] - if size1 != size2: - paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation) - paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST) - # resize bboxes: - ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device) - paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape) - - paste_alpha_mask = paste_masks.sum(dim=0) > 0 - - if blending: - paste_alpha_mask = F.gaussian_blur( - paste_alpha_mask.unsqueeze(0), - kernel_size=(5, 5), - sigma=[ - 2.0, - ], - ) - - # Copy-paste images: - image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask) - - # Copy-paste masks: - masks = masks * (~paste_alpha_mask) - non_all_zero_masks = masks.sum((-1, -2)) > 0 - masks = masks[non_all_zero_masks] - - # Do a shallow copy 
of the target dict - out_target = {k: v for k, v in target.items()} - - out_target["masks"] = torch.cat([masks, paste_masks]) - - # Copy-paste boxes and labels - boxes = ops.masks_to_boxes(masks) - out_target["boxes"] = torch.cat([boxes, paste_boxes]) - - labels = target["labels"][non_all_zero_masks] - out_target["labels"] = torch.cat([labels, paste_labels]) - - # Update additional optional keys: area and iscrowd if exist - if "area" in target: - out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32) - - if "iscrowd" in target and "iscrowd" in paste_target: - # target['iscrowd'] size can be differ from mask size (non_all_zero_masks) - # For example, if previous transforms geometrically modifies masks/boxes/labels but - # does not update "iscrowd" - if len(target["iscrowd"]) == len(non_all_zero_masks): - iscrowd = target["iscrowd"][non_all_zero_masks] - paste_iscrowd = paste_target["iscrowd"][random_selection] - out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd]) - - # Check for degenerated boxes and remove them - boxes = out_target["boxes"] - degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] - if degenerate_boxes.any(): - valid_targets = ~degenerate_boxes.any(dim=1) - - out_target["boxes"] = boxes[valid_targets] - out_target["masks"] = out_target["masks"][valid_targets] - out_target["labels"] = out_target["labels"][valid_targets] - - if "area" in out_target: - out_target["area"] = out_target["area"][valid_targets] - if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets): - out_target["iscrowd"] = out_target["iscrowd"][valid_targets] - - return image, out_target - - -class SimpleCopyPaste(torch.nn.Module): - def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR): - super().__init__() - self.resize_interpolation = resize_interpolation - self.blending = blending - - def forward( - self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]] - ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]: - torch._assert( - isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]), - "images should be a list of tensors", - ) - torch._assert( - isinstance(targets, (list, tuple)) and len(images) == len(targets), - "targets should be a list of the same size as images", - ) - for target in targets: - # Can not check for instance type dict with inside torch.jit.script - # torch._assert(isinstance(target, dict), "targets item should be a dict") - for k in ["masks", "boxes", "labels"]: - torch._assert(k in target, f"Key {k} should be present in targets") - torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor") - - # images = [t1, t2, ..., tN] - # Let's define paste_images as shifted list of input images - # paste_images = [t2, t3, ..., tN, t1] - # FYI: in TF they mix data on the dataset level - images_rolled = images[-1:] + images[:-1] - targets_rolled = targets[-1:] + targets[:-1] - - output_images: List[torch.Tensor] = [] - output_targets: List[Dict[str, Tensor]] = [] - - for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled): - output_image, output_data = _copy_paste( - image, - target, - paste_image, - paste_target, - blending=self.blending, - resize_interpolation=self.resize_interpolation, - ) - output_images.append(output_image) - output_targets.append(output_data) - - return output_images, output_targets - - def __repr__(self) -> str: - s = f"{self.__class__.__name__}(blending={self.blending}, 
resize_interpolation={self.resize_interpolation})" - return s From 6fd5e5013d8cab3033e7e7c01cc9cab99787b9b7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 26 Aug 2022 16:38:33 +0200 Subject: [PATCH 07/49] [skip ci] From 6fcffb2a6fe0697ee2f5b6aa35518f34b5d28dde Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 30 Aug 2022 14:55:18 +0200 Subject: [PATCH 08/49] [skip ci] From 7cb08d51d50c46b2cce4fa98480fed9414447308 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 1 Sep 2022 09:56:27 +0200 Subject: [PATCH 09/49] [skip ci] fix scripts --- references/classification/train.py | 2 +- references/detection/coco_utils.py | 2 +- references/detection/engine.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 8745c9222d0..405cc571ebd 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -216,7 +216,7 @@ def main(args): WrapIntoFeatures(), transforms.LabelToOneHot(num_categories=num_classes), transforms.ToDtype(torch.float, features.OneHotLabel), - transforms.RandomChoice(*mixup_or_cutmix), + transforms.RandomChoice(mixup_or_cutmix), ] ) collate_fn = lambda batch: batch_transform(default_collate(batch)) # noqa: E731 diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index ee0a5fdd99a..df2666ebd40 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -98,7 +98,7 @@ def __call__(self, sample): masks=features.SegmentationMask(target["masks"]), image_id=int(target["image_id"]), area=target["area"].tolist(), - iscrowd=target["iscrowd"].to(torch.bool).tolist(), + iscrowd=target["iscrowd"].bool().tolist(), ) return image, wrapped_target diff --git a/references/detection/engine.py b/references/detection/engine.py index 0e5d55f189d..0e9bfffdf8a 100644 --- a/references/detection/engine.py +++ b/references/detection/engine.py @@ -26,7 +26,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc for images, targets in metric_logger.log_every(data_loader, print_freq, header): images = list(image.to(device) for image in images) - targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets] with torch.cuda.amp.autocast(enabled=scaler is not None): loss_dict = model(images, targets) losses = sum(loss for loss in loss_dict.values()) @@ -97,7 +97,7 @@ def evaluate(model, data_loader, device): outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] model_time = time.time() - model_time - res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} + res = {target["image_id"]: output for target, output in zip(targets, outputs)} evaluator_time = time.time() coco_evaluator.update(res) evaluator_time = time.time() - evaluator_time From a98c05d2f6dfb81b85f2523ca8d86d644249c549 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 7 Sep 2022 17:52:28 +0200 Subject: [PATCH 10/49] [SKIP CI] CircleCI From 49e653f16a1b9bf35915be7de8670ad3077a0291 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 7 Sep 2022 17:52:36 +0200 Subject: [PATCH 11/49] [skip ci] From 8df904353812eb0baa02b7babc9c0841e8b8e296 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 13 Sep 2022 17:14:43 +0200 Subject: [PATCH 12/49] update segmentation references --- references/segmentation/coco_utils.py | 10 ++- references/segmentation/presets.py | 
43 ++++----- references/segmentation/train.py | 7 +- references/segmentation/transforms.py | 120 +++++--------------------- 4 files changed, 53 insertions(+), 127 deletions(-) diff --git a/references/segmentation/coco_utils.py b/references/segmentation/coco_utils.py index e02434012f1..a3311ff5271 100644 --- a/references/segmentation/coco_utils.py +++ b/references/segmentation/coco_utils.py @@ -6,7 +6,7 @@ import torchvision from PIL import Image from pycocotools import mask as coco_mask -from transforms import Compose +from torchvision.prototype.transforms import Compose class FilterAndRemapCocoCategories: @@ -14,7 +14,9 @@ def __init__(self, categories, remap=True): self.categories = categories self.remap = remap - def __call__(self, image, anno): + def __call__(self, sample): + image, anno = sample + anno = [obj for obj in anno if obj["category_id"] in self.categories] if not self.remap: return image, anno @@ -42,7 +44,9 @@ def convert_coco_poly_to_mask(segmentations, height, width): class ConvertCocoPolysToMask: - def __call__(self, image, anno): + def __call__(self, sample): + image, anno = sample + w, h = image.size segmentations = [obj["segmentation"] for obj in anno] cats = [obj["category_id"] for obj in anno] diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index ed02ae660e4..bc1a7d8b6a9 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -1,39 +1,40 @@ import torch -import transforms as T +from torchvision.prototype import features, transforms as T -class SegmentationPresetTrain: - def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): - min_size = int(0.5 * base_size) - max_size = int(2.0 * base_size) +class WrapIntoFeatures(T.Transform): + def forward(self, sample): + image, segmentation_mask = sample + return image, features.SegmentationMask(segmentation_mask.squeeze(0), dtype=torch.int64) + - trans = [T.RandomResize(min_size, max_size)] +class SegmentationPresetTrain(T.Compose): + def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): + transforms = [ + T.ToImageTensor(), + WrapIntoFeatures(), + T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size)), + ] if hflip_prob > 0: - trans.append(T.RandomHorizontalFlip(hflip_prob)) - trans.extend( + transforms.append(T.RandomHorizontalFlip(hflip_prob)) + transforms.extend( [ - T.RandomCrop(crop_size), - T.PILToTensor(), + T.RandomCrop(crop_size, pad_if_needed=True), T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std), ] ) - self.transforms = T.Compose(trans) + super().__init__(transforms) - def __call__(self, img, target): - return self.transforms(img, target) - -class SegmentationPresetEval: +class SegmentationPresetEval(T.Compose): def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): - self.transforms = T.Compose( + super().__init__( [ - T.RandomResize(base_size, base_size), - T.PILToTensor(), + T.ToImageTensor(), + WrapIntoFeatures(), + T.Resize(base_size), T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std), ] ) - - def __call__(self, img, target): - return self.transforms(img, target) diff --git a/references/segmentation/train.py b/references/segmentation/train.py index bb57e65b801..ec5f96025c0 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -36,7 +36,8 @@ def get_transform(train, args): weights = 
torchvision.models.get_weight(args.weights) trans = weights.transforms() - def preprocessing(img, target): + def preprocessing(sample): + img, target = sample img = trans(img) size = F.get_dimensions(img)[1:] target = F.resize(target, size, interpolation=InterpolationMode.NEAREST) @@ -134,8 +135,8 @@ def main(args): else: torch.backends.cudnn.benchmark = True - dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args)) - dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args)) + dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True, args=args)) + dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False, args=args)) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 518048db2fa..ee9af5eeb53 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -1,100 +1,20 @@ -import random - -import numpy as np -import torch -from torchvision import transforms as T -from torchvision.transforms import functional as F - - -def pad_if_smaller(img, size, fill=0): - min_size = min(img.size) - if min_size < size: - ow, oh = img.size - padh = size - oh if oh < size else 0 - padw = size - ow if ow < size else 0 - img = F.pad(img, (0, 0, padw, padh), fill=fill) - return img - - -class Compose: - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, image, target): - for t in self.transforms: - image, target = t(image, target) - return image, target - - -class RandomResize: - def __init__(self, min_size, max_size=None): - self.min_size = min_size - if max_size is None: - max_size = min_size - self.max_size = max_size - - def __call__(self, image, target): - size = random.randint(self.min_size, self.max_size) - image = F.resize(image, size) - target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST) - return image, target - - -class RandomHorizontalFlip: - def __init__(self, flip_prob): - self.flip_prob = flip_prob - - def __call__(self, image, target): - if random.random() < self.flip_prob: - image = F.hflip(image) - target = F.hflip(target) - return image, target - - -class RandomCrop: - def __init__(self, size): - self.size = size - - def __call__(self, image, target): - image = pad_if_smaller(image, self.size) - target = pad_if_smaller(target, self.size, fill=255) - crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) - image = F.crop(image, *crop_params) - target = F.crop(target, *crop_params) - return image, target - - -class CenterCrop: - def __init__(self, size): - self.size = size - - def __call__(self, image, target): - image = F.center_crop(image, self.size) - target = F.center_crop(target, self.size) - return image, target - - -class PILToTensor: - def __call__(self, image, target): - image = F.pil_to_tensor(image) - target = torch.as_tensor(np.array(target), dtype=torch.int64) - return image, target - - -class ConvertImageDtype: - def __init__(self, dtype): - self.dtype = dtype - - def __call__(self, image, target): - image = F.convert_image_dtype(image, self.dtype) - return image, target - - -class Normalize: - def __init__(self, mean, std): - self.mean = mean - self.std = std - - def __call__(self, image, target): - image = F.normalize(image, mean=self.mean, std=self.std) - return image, target +from 
torchvision.prototype import features, transforms + + +class RandomCrop(transforms.RandomCrop): + def _transform(self, inpt, params): + if not isinstance(inpt, features.SegmentationMask): + return super()._transform(inpt, params) + + # `SegmentationMask`'s should be padded with 255 to indicate an area that should not be used in the loss + # calculation. See + # https://stackoverflow.com/questions/49629933/ground-truth-pixel-labels-in-pascal-voc-for-semantic-segmentation + # for details. + # FIXME: Using different values for `fill` based on the input type is not supported by `transforms.RandomCrop`. + # Thus, we emulate it here. See https://github.com/pytorch/vision/issues/6568. + fill = self.fill + try: + self.fill = 255 + return super()._transform(inpt, params) + finally: + self.fill = fill From 99e6c3665d96d84f2a3f57ea8c4649ef5c30f6a0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 14 Sep 2022 09:56:10 +0200 Subject: [PATCH 13/49] [skip ci] From 94ac15d529f86eb050ae3913560442ff41bff0dd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 14 Sep 2022 11:52:04 +0200 Subject: [PATCH 14/49] [skip ci] From 51307b7329e4f45e935b610dcf81248256e6376f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 14 Sep 2022 12:16:42 +0200 Subject: [PATCH 15/49] [skip ci] fix workaround --- references/segmentation/presets.py | 3 ++- references/segmentation/transforms.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index 21c53f318c3..e460ce54bd1 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -1,5 +1,6 @@ import torch from torchvision.prototype import features, transforms as T +from transforms import RandomCrop class WrapIntoFeatures(T.Transform): @@ -19,7 +20,7 @@ def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, transforms.append(T.RandomHorizontalFlip(hflip_prob)) transforms.extend( [ - T.RandomCrop(crop_size, pad_if_needed=True), + RandomCrop(crop_size, pad_if_needed=True), T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std), ] diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index ee9af5eeb53..e64400df362 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -3,11 +3,10 @@ class RandomCrop(transforms.RandomCrop): def _transform(self, inpt, params): - if not isinstance(inpt, features.SegmentationMask): + if not isinstance(inpt, features.Mask): return super()._transform(inpt, params) - # `SegmentationMask`'s should be padded with 255 to indicate an area that should not be used in the loss - # calculation. See + # `Mask`'s should be padded with 255 to indicate an area that should not be used in the loss calculation. See # https://stackoverflow.com/questions/49629933/ground-truth-pixel-labels-in-pascal-voc-for-semantic-segmentation # for details. # FIXME: Using different values for `fill` based on the input type is not supported by `transforms.RandomCrop`. 
From 9dad6e024aff5232db67668cd6089c50e9b45488 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 14 Sep 2022 20:23:48 +0200 Subject: [PATCH 16/49] only wrap segmentation mask --- references/segmentation/presets.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index e460ce54bd1..2eb93c36678 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -1,18 +1,18 @@ import torch from torchvision.prototype import features, transforms as T +from torchvision.prototype.transforms import functional as F from transforms import RandomCrop class WrapIntoFeatures(T.Transform): def forward(self, sample): image, mask = sample - return image, features.Mask(mask.squeeze(0), dtype=torch.int64) + return image, features.Mask(F.to_image_tensor(mask).squeeze(0), dtype=torch.int64) class SegmentationPresetTrain(T.Compose): def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): transforms = [ - T.ToImageTensor(), WrapIntoFeatures(), T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size)), ] @@ -21,6 +21,7 @@ def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, transforms.extend( [ RandomCrop(crop_size, pad_if_needed=True), + T.ToImageTensor(), T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std), ] @@ -32,9 +33,9 @@ class SegmentationPresetEval(T.Compose): def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): super().__init__( [ - T.ToImageTensor(), WrapIntoFeatures(), T.Resize(base_size), + T.ToImageTensor(), T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std), ] From f5f17169dfa30c6cab22e955768671165476a3ef Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 14 Sep 2022 21:43:39 +0200 Subject: [PATCH 17/49] fix pretrained weights test only --- references/segmentation/coco_utils.py | 1 - references/segmentation/train.py | 15 +++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/references/segmentation/coco_utils.py b/references/segmentation/coco_utils.py index a3311ff5271..863790d9c91 100644 --- a/references/segmentation/coco_utils.py +++ b/references/segmentation/coco_utils.py @@ -94,7 +94,6 @@ def get_coco(root, image_set, transforms): PATHS = { "train": ("train2017", os.path.join("annotations", "instances_train2017.json")), "val": ("val2017", os.path.join("annotations", "instances_val2017.json")), - # "train": ("val2017", os.path.join("annotations", "instances_val2017.json")) } CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72] diff --git a/references/segmentation/train.py b/references/segmentation/train.py index ec5f96025c0..edaba9c9a3c 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -11,15 +11,18 @@ from coco_utils import get_coco from torch import nn from torch.optim.lr_scheduler import PolynomialLR -from torchvision.transforms import functional as F, InterpolationMode +from torchvision.prototype.transforms import Compose, functional as F, InterpolationMode def get_dataset(dir_path, name, image_set, transform): - def sbd(*args, **kwargs): - return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs) + def voc(*args, transforms, **kwargs): + return torchvision.datasets.VOCSegmentation(*args, transforms=Compose([transforms]), **kwargs) + + def sbd(*args, transforms, **kwargs): + return 
torchvision.datasets.SBDataset(*args, mode="segmentation", transforms=Compose([transforms]), **kwargs) paths = { - "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21), + "voc": (dir_path, voc, 21), "voc_aug": (dir_path, sbd, 21), "coco": (dir_path, get_coco, 21), } @@ -39,9 +42,9 @@ def get_transform(train, args): def preprocessing(sample): img, target = sample img = trans(img) - size = F.get_dimensions(img)[1:] + size = F.get_image_size(img) target = F.resize(target, size, interpolation=InterpolationMode.NEAREST) - return img, F.pil_to_tensor(target) + return img, F.to_image_tensor(target) return preprocessing else: From 2aefd09933d094b5875a109d9c594fa0060427bd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 14 Sep 2022 21:43:45 +0200 Subject: [PATCH 18/49] [skip ci] From a2893a1adc99de6a0f5b2e1050d20aaeaa22e100 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 14 Sep 2022 21:53:01 +0100 Subject: [PATCH 19/49] Restore get_dimensions --- references/segmentation/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/segmentation/train.py b/references/segmentation/train.py index edaba9c9a3c..2bc26e83b4b 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -42,7 +42,7 @@ def get_transform(train, args): def preprocessing(sample): img, target = sample img = trans(img) - size = F.get_image_size(img) + size = F.get_dimensions(img)[1:] target = F.resize(target, size, interpolation=InterpolationMode.NEAREST) return img, F.to_image_tensor(target) From e912976a5b30018fda5cc50d9169fc1b697ac84b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 11:39:30 +0200 Subject: [PATCH 20/49] fix segmentation transforms --- references/segmentation/presets.py | 9 +++++-- references/segmentation/transforms.py | 35 +++++++++++++++------------ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index 2eb93c36678..b3c7e96f862 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -1,7 +1,9 @@ +from collections import defaultdict + import torch from torchvision.prototype import features, transforms as T from torchvision.prototype.transforms import functional as F -from transforms import RandomCrop +from transforms import PadIfSmaller class WrapIntoFeatures(T.Transform): @@ -20,7 +22,10 @@ def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, transforms.append(T.RandomHorizontalFlip(hflip_prob)) transforms.extend( [ - RandomCrop(crop_size, pad_if_needed=True), + # We need a custom pad transform here, since the padding we want to perform here is fundamentally + # different from the padding in `RandomCrop` if `pad_if_needed=True`. 
+ PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {features.Mask: 255})), + T.RandomCrop(crop_size), T.ToImageTensor(), T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std), diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index e64400df362..46c3a7fb5de 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -1,19 +1,24 @@ -from torchvision.prototype import features, transforms +from torchvision.prototype import transforms +from torchvision.prototype.transforms import functional as F -class RandomCrop(transforms.RandomCrop): +class PadIfSmaller(transforms.Transform): + def __init__(self, size, fill=0): + super().__init__() + self.size = size + self.fill = transforms._geometry._setup_fill_arg(fill) + + def _get_params(self, sample): + _, height, width = transforms._utils.query_chw(sample) + padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] + needs_padding = any(padding) + return dict(padding=padding, needs_padding=needs_padding) + def _transform(self, inpt, params): - if not isinstance(inpt, features.Mask): - return super()._transform(inpt, params) + if not params["needs_padding"]: + return inpt + + fill = self.fill[type(inpt)] + fill = F._geometry._convert_fill_arg(fill) - # `Mask`'s should be padded with 255 to indicate an area that should not be used in the loss calculation. See - # https://stackoverflow.com/questions/49629933/ground-truth-pixel-labels-in-pascal-voc-for-semantic-segmentation - # for details. - # FIXME: Using different values for `fill` based on the input type is not supported by `transforms.RandomCrop`. - # Thus, we emulate it here. See https://github.com/pytorch/vision/issues/6568. - fill = self.fill - try: - self.fill = 255 - return super()._transform(inpt, params) - finally: - self.fill = fill + return F.pad(inpt, padding=params["padding"], fill=fill) From 2e7e16842abe3677997eaf22393e7423c8706117 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 11:40:24 +0200 Subject: [PATCH 21/49] [skip ci] From 585c64a29e114a666f0007a10c2c586fd2c42f95 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 12:09:06 +0200 Subject: [PATCH 22/49] fix mask rewrapping --- references/segmentation/presets.py | 9 +-------- references/segmentation/train.py | 4 +++- references/segmentation/transforms.py | 10 +++++++++- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index b3c7e96f862..9d851a1506e 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -2,14 +2,7 @@ import torch from torchvision.prototype import features, transforms as T -from torchvision.prototype.transforms import functional as F -from transforms import PadIfSmaller - - -class WrapIntoFeatures(T.Transform): - def forward(self, sample): - image, mask = sample - return image, features.Mask(F.to_image_tensor(mask).squeeze(0), dtype=torch.int64) +from transforms import PadIfSmaller, WrapIntoFeatures class SegmentationPresetTrain(T.Compose): diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 2bc26e83b4b..50e20662e91 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -12,6 +12,7 @@ from torch import nn from torch.optim.lr_scheduler import PolynomialLR from torchvision.prototype.transforms import Compose, functional as F, InterpolationMode +from transforms import 
WrapIntoFeatures def get_dataset(dir_path, name, image_set, transform): @@ -38,13 +39,14 @@ def get_transform(train, args): elif args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() + wrap = WrapIntoFeatures() def preprocessing(sample): img, target = sample img = trans(img) size = F.get_dimensions(img)[1:] target = F.resize(target, size, interpolation=InterpolationMode.NEAREST) - return img, F.to_image_tensor(target) + return wrap((img, target)) return preprocessing else: diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 46c3a7fb5de..485ebf71821 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -1,7 +1,15 @@ -from torchvision.prototype import transforms +import torch + +from torchvision.prototype import features, transforms from torchvision.prototype.transforms import functional as F +class WrapIntoFeatures(transforms.Transform): + def forward(self, sample): + image, mask = sample + return image, features.Mask(F.pil_to_tensor(mask).squeeze(0), dtype=torch.int64) + + class PadIfSmaller(transforms.Transform): def __init__(self, size, fill=0): super().__init__() From 93d7a325082f9f91a1328c4b54d323f90247cec4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 12:09:23 +0200 Subject: [PATCH 23/49] [skip ci] From 766af6c062b3c5dbabc0204fe69be2484da6cc1a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 23 Sep 2022 20:15:35 +0100 Subject: [PATCH 24/49] Fix merge issue --- references/classification/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/classification/train.py b/references/classification/train.py index 61f4867cc0b..50b43a9848d 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -231,7 +231,7 @@ def main(args): ) def collate_fn(batch): - return mixupcutmix(*default_collate(batch)) + return batch_transform(*default_collate(batch)) data_loader = torch.utils.data.DataLoader( dataset, From cb6c90ee8c273fb755f01e092c1c772ead68ae96 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 23 Sep 2022 23:04:21 +0100 Subject: [PATCH 25/49] Tensor Backend + antialiasing=True --- references/classification/presets.py | 12 ++++++------ references/classification/transforms.py | 5 +++-- references/detection/coco_utils.py | 3 ++- references/detection/presets.py | 19 +++++++++---------- references/detection/train.py | 2 +- references/segmentation/presets.py | 6 ++---- references/segmentation/transforms.py | 2 +- 7 files changed, 24 insertions(+), 25 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 456af31bdc8..4ddaec18bff 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -18,10 +18,11 @@ def __init__( random_erase_prob=0.0, center_crop=False, ): - trans = ( - [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + trans = [transforms.ToImageTensor()] + trans.append( + transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True) if center_crop - else [transforms.CenterCrop(crop_size)] + else transforms.CenterCrop(crop_size) ) if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(p=hflip_prob)) @@ -37,7 +38,6 @@ def __init__( trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.extend( [ - transforms.ToImageTensor(), transforms.ConvertImageDtype(torch.float), 
transforms.Normalize(mean=mean, std=std), ] @@ -64,9 +64,9 @@ def __init__( self.transforms = transforms.Compose( [ - transforms.Resize(resize_size, interpolation=interpolation), - transforms.CenterCrop(crop_size), transforms.ToImageTensor(), + transforms.Resize(resize_size, interpolation=interpolation, antialias=True), + transforms.CenterCrop(crop_size), transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] diff --git a/references/classification/transforms.py b/references/classification/transforms.py index 2438bc45730..77c6997a77e 100644 --- a/references/classification/transforms.py +++ b/references/classification/transforms.py @@ -1,8 +1,9 @@ from torch import nn from torchvision.prototype import features +from torchvision.prototype.transforms import functional as F class WrapIntoFeatures(nn.Module): def forward(self, sample): - input, target = sample - return features.Image(input), features.Label(target) + image, target = sample + return F.to_image_tensor(image), features.Label(target) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 5257daabd8a..b392a38c7c6 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -7,6 +7,7 @@ from pycocotools import mask as coco_mask from pycocotools.coco import COCO from torchvision.prototype import features, transforms as T +from torchvision.prototype.transforms import functional as F def convert_coco_poly_to_mask(segmentations, height, width): @@ -101,7 +102,7 @@ def __call__(self, sample): iscrowd=target["iscrowd"].bool().tolist(), ) - return image, wrapped_target + return F.to_image_tensor(image), wrapped_target def _coco_remove_images_without_annotations(dataset, cat_list=None): diff --git a/references/detection/presets.py b/references/detection/presets.py index af4dcf72bfe..1e89815f32e 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -1,5 +1,7 @@ +from collections import defaultdict + import torch -from torchvision.prototype import transforms as T +from torchvision.prototype import features, transforms as T class DetectionPresetTrain(T.Compose): @@ -7,38 +9,35 @@ def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104 if data_augmentation == "hflip": transforms = [ T.RandomHorizontalFlip(p=hflip_prob), - T.ToImageTensor(), T.ConvertImageDtype(torch.float), ] elif data_augmentation == "lsj": transforms = [ - T.ScaleJitter(target_size=(1024, 1024)), - T.FixedSizeCrop(size=(1024, 1024), fill=mean), + T.ScaleJitter(target_size=(1024, 1024), antialias=True), + T.FixedSizeCrop(size=(1024, 1024), fill=defaultdict(lambda: mean, {features.Mask: 0})), T.RandomHorizontalFlip(p=hflip_prob), - T.ToImageTensor(), T.ConvertImageDtype(torch.float), ] elif data_augmentation == "multiscale": transforms = [ - T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333), + T.RandomShortestSize( + min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True + ), T.RandomHorizontalFlip(p=hflip_prob), - T.ToImageTensor(), T.ConvertImageDtype(torch.float), ] elif data_augmentation == "ssd": transforms = [ T.RandomPhotometricDistort(), - T.RandomZoomOut(fill=list(mean)), + T.RandomZoomOut(fill=defaultdict(lambda: mean, {features.Mask: 0})), T.RandomIoUCrop(), T.RandomHorizontalFlip(p=hflip_prob), - T.ToImageTensor(), T.ConvertImageDtype(torch.float), ] elif data_augmentation == "ssdlite": transforms = [ T.RandomIoUCrop(), 
T.RandomHorizontalFlip(p=hflip_prob), - T.ToImageTensor(), T.ConvertImageDtype(torch.float), ] else: diff --git a/references/detection/train.py b/references/detection/train.py index 0662adcd173..6146da1adec 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -36,7 +36,7 @@ def copypaste_collate_fn(batch): - copypaste = T.SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR) + copypaste = T.SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR, antialias=True) return copypaste(*utils.collate_fn(batch)) diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index 9d851a1506e..84da2229731 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -9,7 +9,7 @@ class SegmentationPresetTrain(T.Compose): def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): transforms = [ WrapIntoFeatures(), - T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size)), + T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size), antialias=True), ] if hflip_prob > 0: transforms.append(T.RandomHorizontalFlip(hflip_prob)) @@ -19,7 +19,6 @@ def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, # different from the padding in `RandomCrop` if `pad_if_needed=True`. PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {features.Mask: 255})), T.RandomCrop(crop_size), - T.ToImageTensor(), T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std), ] @@ -32,8 +31,7 @@ def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, super().__init__( [ WrapIntoFeatures(), - T.Resize(base_size), - T.ToImageTensor(), + T.Resize(base_size, antialias=True), T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std), ] diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 485ebf71821..46b2609c1a6 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -7,7 +7,7 @@ class WrapIntoFeatures(transforms.Transform): def forward(self, sample): image, mask = sample - return image, features.Mask(F.pil_to_tensor(mask).squeeze(0), dtype=torch.int64) + return F.to_image_tensor(image), features.Mask(F.pil_to_tensor(mask).squeeze(0), dtype=torch.int64) class PadIfSmaller(transforms.Transform): From e9c480e905f4d42a2b0f6a92af40d66096918fab Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 23 Sep 2022 23:04:43 +0100 Subject: [PATCH 26/49] Switch to view to reshape to avoid incompatibilities with size/stride --- torchvision/transforms/functional_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 4613a95e926..1866056b5c7 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -880,7 +880,7 @@ def _scale_channel(img_chan: Tensor) -> Tensor: if img_chan.is_cuda: hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255) else: - hist = torch.bincount(img_chan.view(-1), minlength=256) + hist = torch.bincount(img_chan.reshape(-1), minlength=256) nonzero_hist = hist[hist != 0] step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor") From 6ef4d828e4d8dbdd2a98790108fed0ec6def6469 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sun, 25 Sep 2022 16:40:04 +0100 Subject: 
[PATCH 27/49] Cherrypick PR #6642 --- references/classification/presets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 4ddaec18bff..4a7818d0377 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -20,9 +20,9 @@ def __init__( ): trans = [transforms.ToImageTensor()] trans.append( - transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True) + transforms.CenterCrop(crop_size) if center_crop - else transforms.CenterCrop(crop_size) + else transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True) ) if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(p=hflip_prob)) From 758de467eea8a90c97cf062e7f5edf39133eb0e5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 10 Oct 2022 13:48:00 +0200 Subject: [PATCH 28/49] [skip ci] add support for video_classification --- references/video_classification/datasets.py | 12 ++++++++---- references/video_classification/presets.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/references/video_classification/datasets.py b/references/video_classification/datasets.py index dec1e16b856..c1c68431518 100644 --- a/references/video_classification/datasets.py +++ b/references/video_classification/datasets.py @@ -1,14 +1,18 @@ from typing import Tuple -import torchvision -from torch import Tensor +import torch +from torchvision import datasets +from torchvision.prototype import features -class KineticsWithVideoId(torchvision.datasets.Kinetics): - def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]: +class KineticsWithVideoId(datasets.Kinetics): + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, features.Label, int]: video, audio, info, video_idx = self.video_clips.get_clip(idx) label = self.samples[video_idx][1] + video = features.Video(video) + label = features.Label(label, categories=self.classes) + if self.transform is not None: video = self.transform(video) diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index ef774052257..cac7a1f62e1 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -1,5 +1,5 @@ import torch -from torchvision.transforms import transforms +from torchvision.prototype import transforms from transforms import ConvertBCHWtoCBHW @@ -19,7 +19,13 @@ def __init__( ] if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) - trans.extend([transforms.Normalize(mean=mean, std=std), transforms.RandomCrop(crop_size), ConvertBCHWtoCBHW()]) + trans.extend( + [ + transforms.Normalize(mean=mean, std=std), + transforms.RandomCrop(crop_size), + ConvertBCHWtoCBHW(), + ] + ) self.transforms = transforms.Compose(trans) def __call__(self, x): From 669b1baa43dbfd99e28bc8888d640919977229b5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 11 Oct 2022 12:31:13 +0100 Subject: [PATCH 29/49] Restoring original reference transforms so that test can run --- references/classification/transforms.py | 193 +++++++- references/detection/transforms.py | 594 ++++++++++++++++++++++++ references/segmentation/transforms.py | 120 ++++- 3 files changed, 894 insertions(+), 13 deletions(-) create mode 100644 references/detection/transforms.py diff --git a/references/classification/transforms.py b/references/classification/transforms.py index 77c6997a77e..7a665057d8f 100644 --- 
a/references/classification/transforms.py +++ b/references/classification/transforms.py @@ -1,9 +1,194 @@ -from torch import nn +import math +from typing import Tuple + +import torch +from torch import Tensor from torchvision.prototype import features -from torchvision.prototype.transforms import functional as F +from torchvision.prototype.transforms import functional as PF +from torchvision.transforms import functional as F -class WrapIntoFeatures(nn.Module): +class WrapIntoFeatures(torch.nn.Module): def forward(self, sample): image, target = sample - return F.to_image_tensor(image), features.Label(target) + return PF.to_image_tensor(image), features.Label(target) + + +# Original Transforms can be removed: + + +class RandomMixup(torch.nn.Module): + """Randomly apply Mixup to the provided batch and targets. + The class implements the data augmentations as described in the paper + `"mixup: Beyond Empirical Risk Minimization" `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for mixup. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__() + + if num_classes < 1: + raise ValueError( + f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}" + ) + + if alpha <= 0: + raise ValueError("Alpha param can't be zero.") + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") + if target.ndim != 1: + raise ValueError(f"Target ndim should be 1. Got {target.ndim}") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if target.dtype != torch.int64: + raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # Implemented as on mixup paper, page 3. + lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) + batch_rolled.mul_(1.0 - lambda_param) + batch.mul_(lambda_param).add_(batch_rolled) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_classes={self.num_classes}" + f", p={self.p}" + f", alpha={self.alpha}" + f", inplace={self.inplace}" + f")" + ) + return s + + +class RandomCutmix(torch.nn.Module): + """Randomly apply Cutmix to the provided batch and targets. 
+ The class implements the data augmentations as described in the paper + `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" + `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for cutmix. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__() + if num_classes < 1: + raise ValueError("Please provide a valid positive value for the num_classes.") + if alpha <= 0: + raise ValueError("Alpha param can't be zero.") + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") + if target.ndim != 1: + raise ValueError(f"Target ndim should be 1. Got {target.ndim}") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if target.dtype != torch.int64: + raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # Implemented as on cutmix paper, page 12 (with minor corrections on typos). 
+ lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) + _, H, W = F.get_dimensions(batch) + + r_x = torch.randint(W, (1,)) + r_y = torch.randint(H, (1,)) + + r = 0.5 * math.sqrt(1.0 - lambda_param) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + + batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] + lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_classes={self.num_classes}" + f", p={self.p}" + f", alpha={self.alpha}" + f", inplace={self.inplace}" + f")" + ) + return s diff --git a/references/detection/transforms.py b/references/detection/transforms.py new file mode 100644 index 00000000000..080dc5a5b1c --- /dev/null +++ b/references/detection/transforms.py @@ -0,0 +1,594 @@ +# Original Transforms can be removed: +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torchvision +from torch import nn, Tensor +from torchvision import ops +from torchvision.transforms import functional as F, InterpolationMode, transforms as T + + +def _flip_coco_person_keypoints(kps, width): + flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] + flipped_data = kps[:, flip_inds] + flipped_data[..., 0] = width - flipped_data[..., 0] + # Maintain COCO convention that if visibility == 0, then x, y = 0 + inds = flipped_data[..., 2] == 0 + flipped_data[inds] = 0 + return flipped_data + + +class Compose: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + +class RandomHorizontalFlip(T.RandomHorizontalFlip): + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + if torch.rand(1) < self.p: + image = F.hflip(image) + if target is not None: + _, _, width = F.get_dimensions(image) + target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]] + if "masks" in target: + target["masks"] = target["masks"].flip(-1) + if "keypoints" in target: + keypoints = target["keypoints"] + keypoints = _flip_coco_person_keypoints(keypoints, width) + target["keypoints"] = keypoints + return image, target + + +class PILToTensor(nn.Module): + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + image = F.pil_to_tensor(image) + return image, target + + +class ConvertImageDtype(nn.Module): + def __init__(self, dtype: torch.dtype) -> None: + super().__init__() + self.dtype = dtype + + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + image = F.convert_image_dtype(image, self.dtype) + return image, target + + +class RandomIoUCrop(nn.Module): + def __init__( + self, + min_scale: float = 0.3, + max_scale: float = 1.0, + min_aspect_ratio: float = 0.5, + max_aspect_ratio: float = 2.0, + sampler_options: Optional[List[float]] = None, + trials: int = 40, + ): + super().__init__() + # Configuration similar to 
https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 + self.min_scale = min_scale + self.max_scale = max_scale + self.min_aspect_ratio = min_aspect_ratio + self.max_aspect_ratio = max_aspect_ratio + if sampler_options is None: + sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0] + self.options = sampler_options + self.trials = trials + + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + if target is None: + raise ValueError("The targets can't be None for this transform.") + + if isinstance(image, torch.Tensor): + if image.ndimension() not in {2, 3}: + raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.") + elif image.ndimension() == 2: + image = image.unsqueeze(0) + + _, orig_h, orig_w = F.get_dimensions(image) + + while True: + # sample an option + idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) + min_jaccard_overlap = self.options[idx] + if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option + return image, target + + for _ in range(self.trials): + # check the aspect ratio limitations + r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) + new_w = int(orig_w * r[0]) + new_h = int(orig_h * r[1]) + aspect_ratio = new_w / new_h + if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio): + continue + + # check for 0 area crops + r = torch.rand(2) + left = int((orig_w - new_w) * r[0]) + top = int((orig_h - new_h) * r[1]) + right = left + new_w + bottom = top + new_h + if left == right or top == bottom: + continue + + # check for any valid boxes with centers within the crop area + cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2]) + cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3]) + is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom) + if not is_within_crop_area.any(): + continue + + # check at least 1 box with jaccard limitations + boxes = target["boxes"][is_within_crop_area] + ious = torchvision.ops.boxes.box_iou( + boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device) + ) + if ious.max() < min_jaccard_overlap: + continue + + # keep only valid boxes and perform cropping + target["boxes"] = boxes + target["labels"] = target["labels"][is_within_crop_area] + target["boxes"][:, 0::2] -= left + target["boxes"][:, 1::2] -= top + target["boxes"][:, 0::2].clamp_(min=0, max=new_w) + target["boxes"][:, 1::2].clamp_(min=0, max=new_h) + image = F.crop(image, top, left, new_h, new_w) + + return image, target + + +class RandomZoomOut(nn.Module): + def __init__( + self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 + ): + super().__init__() + if fill is None: + fill = [0.0, 0.0, 0.0] + self.fill = fill + self.side_range = side_range + if side_range[0] < 1.0 or side_range[0] > side_range[1]: + raise ValueError(f"Invalid canvas side range provided {side_range}.") + self.p = p + + @torch.jit.unused + def _get_fill_value(self, is_pil): + # type: (bool) -> int + # We fake the type to make it work on JIT + return tuple(int(x) for x in self.fill) if is_pil else 0 + + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + if isinstance(image, torch.Tensor): + if image.ndimension() not in {2, 3}: + raise ValueError(f"image should be 2/3 dimensional. 
Got {image.ndimension()} dimensions.") + elif image.ndimension() == 2: + image = image.unsqueeze(0) + + if torch.rand(1) >= self.p: + return image, target + + _, orig_h, orig_w = F.get_dimensions(image) + + r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) + canvas_width = int(orig_w * r) + canvas_height = int(orig_h * r) + + r = torch.rand(2) + left = int((canvas_width - orig_w) * r[0]) + top = int((canvas_height - orig_h) * r[1]) + right = canvas_width - (left + orig_w) + bottom = canvas_height - (top + orig_h) + + if torch.jit.is_scripting(): + fill = 0 + else: + fill = self._get_fill_value(F._is_pil_image(image)) + + image = F.pad(image, [left, top, right, bottom], fill=fill) + if isinstance(image, torch.Tensor): + # PyTorch's pad supports only integers on fill. So we need to overwrite the colour + v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1) + image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[ + ..., :, (left + orig_w) : + ] = v + + if target is not None: + target["boxes"][:, 0::2] += left + target["boxes"][:, 1::2] += top + + return image, target + + +class RandomPhotometricDistort(nn.Module): + def __init__( + self, + contrast: Tuple[float, float] = (0.5, 1.5), + saturation: Tuple[float, float] = (0.5, 1.5), + hue: Tuple[float, float] = (-0.05, 0.05), + brightness: Tuple[float, float] = (0.875, 1.125), + p: float = 0.5, + ): + super().__init__() + self._brightness = T.ColorJitter(brightness=brightness) + self._contrast = T.ColorJitter(contrast=contrast) + self._hue = T.ColorJitter(hue=hue) + self._saturation = T.ColorJitter(saturation=saturation) + self.p = p + + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + if isinstance(image, torch.Tensor): + if image.ndimension() not in {2, 3}: + raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.") + elif image.ndimension() == 2: + image = image.unsqueeze(0) + + r = torch.rand(7) + + if r[0] < self.p: + image = self._brightness(image) + + contrast_before = r[1] < 0.5 + if contrast_before: + if r[2] < self.p: + image = self._contrast(image) + + if r[3] < self.p: + image = self._saturation(image) + + if r[4] < self.p: + image = self._hue(image) + + if not contrast_before: + if r[5] < self.p: + image = self._contrast(image) + + if r[6] < self.p: + channels, _, _ = F.get_dimensions(image) + permutation = torch.randperm(channels) + + is_pil = F._is_pil_image(image) + if is_pil: + image = F.pil_to_tensor(image) + image = F.convert_image_dtype(image) + image = image[..., permutation, :, :] + if is_pil: + image = F.to_pil_image(image) + + return image, target + + +class ScaleJitter(nn.Module): + """Randomly resizes the image and its bounding boxes within the specified scale range. + The class implements the Scale Jitter augmentation as described in the paper + `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" `_. + + Args: + target_size (tuple of ints): The target size for the transform provided in (height, weight) format. + scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the + range a <= scale <= b. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. 
+ """ + + def __init__( + self, + target_size: Tuple[int, int], + scale_range: Tuple[float, float] = (0.1, 2.0), + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ): + super().__init__() + self.target_size = target_size + self.scale_range = scale_range + self.interpolation = interpolation + + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + if isinstance(image, torch.Tensor): + if image.ndimension() not in {2, 3}: + raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.") + elif image.ndimension() == 2: + image = image.unsqueeze(0) + + _, orig_height, orig_width = F.get_dimensions(image) + + scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) + r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + image = F.resize(image, [new_height, new_width], interpolation=self.interpolation) + + if target is not None: + target["boxes"][:, 0::2] *= new_width / orig_width + target["boxes"][:, 1::2] *= new_height / orig_height + if "masks" in target: + target["masks"] = F.resize( + target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST + ) + + return image, target + + +class FixedSizeCrop(nn.Module): + def __init__(self, size, fill=0, padding_mode="constant"): + super().__init__() + size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) + self.crop_height = size[0] + self.crop_width = size[1] + self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch. + self.padding_mode = padding_mode + + def _pad(self, img, target, padding): + # Taken from the functional_tensor.py pad + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + elif len(padding) == 1: + pad_left = pad_right = pad_top = pad_bottom = padding[0] + elif len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + else: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + padding = [pad_left, pad_top, pad_right, pad_bottom] + img = F.pad(img, padding, self.fill, self.padding_mode) + if target is not None: + target["boxes"][:, 0::2] += pad_left + target["boxes"][:, 1::2] += pad_top + if "masks" in target: + target["masks"] = F.pad(target["masks"], padding, 0, "constant") + + return img, target + + def _crop(self, img, target, top, left, height, width): + img = F.crop(img, top, left, height, width) + if target is not None: + boxes = target["boxes"] + boxes[:, 0::2] -= left + boxes[:, 1::2] -= top + boxes[:, 0::2].clamp_(min=0, max=width) + boxes[:, 1::2].clamp_(min=0, max=height) + + is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3]) + + target["boxes"] = boxes[is_valid] + target["labels"] = target["labels"][is_valid] + if "masks" in target: + target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width) + + return img, target + + def forward(self, img, target=None): + _, height, width = F.get_dimensions(img) + new_height = min(height, self.crop_height) + new_width = min(width, self.crop_width) + + if new_height != height or new_width != width: + offset_height = max(height - self.crop_height, 0) + offset_width = max(width - self.crop_width, 0) + + r = torch.rand(1) + top = int(offset_height * r) + left = int(offset_width * r) + + img, 
target = self._crop(img, target, top, left, new_height, new_width) + + pad_bottom = max(self.crop_height - new_height, 0) + pad_right = max(self.crop_width - new_width, 0) + if pad_bottom != 0 or pad_right != 0: + img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom]) + + return img, target + + +class RandomShortestSize(nn.Module): + def __init__( + self, + min_size: Union[List[int], Tuple[int], int], + max_size: int, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ): + super().__init__() + self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) + self.max_size = max_size + self.interpolation = interpolation + + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + _, orig_height, orig_width = F.get_dimensions(image) + + min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()] + r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width)) + + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + image = F.resize(image, [new_height, new_width], interpolation=self.interpolation) + + if target is not None: + target["boxes"][:, 0::2] *= new_width / orig_width + target["boxes"][:, 1::2] *= new_height / orig_height + if "masks" in target: + target["masks"] = F.resize( + target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST + ) + + return image, target + + +def _copy_paste( + image: torch.Tensor, + target: Dict[str, Tensor], + paste_image: torch.Tensor, + paste_target: Dict[str, Tensor], + blending: bool = True, + resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR, +) -> Tuple[torch.Tensor, Dict[str, Tensor]]: + + # Random paste targets selection: + num_masks = len(paste_target["masks"]) + + if num_masks < 1: + # Such degerante case with num_masks=0 can happen with LSJ + # Let's just return (image, target) + return image, target + + # We have to please torch script by explicitly specifying dtype as torch.long + random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device) + random_selection = torch.unique(random_selection).to(torch.long) + + paste_masks = paste_target["masks"][random_selection] + paste_boxes = paste_target["boxes"][random_selection] + paste_labels = paste_target["labels"][random_selection] + + masks = target["masks"] + + # We resize source and paste data if they have different sizes + # This is something we introduced here as originally the algorithm works + # on equal-sized data (for example, coming from LSJ data augmentations) + size1 = image.shape[-2:] + size2 = paste_image.shape[-2:] + if size1 != size2: + paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation) + paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST) + # resize bboxes: + ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device) + paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape) + + paste_alpha_mask = paste_masks.sum(dim=0) > 0 + + if blending: + paste_alpha_mask = F.gaussian_blur( + paste_alpha_mask.unsqueeze(0), + kernel_size=(5, 5), + sigma=[ + 2.0, + ], + ) + + # Copy-paste images: + image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask) + + # Copy-paste masks: + masks = masks * (~paste_alpha_mask) + non_all_zero_masks = masks.sum((-1, -2)) > 0 + masks = masks[non_all_zero_masks] + + # Do a shallow copy 
of the target dict + out_target = {k: v for k, v in target.items()} + + out_target["masks"] = torch.cat([masks, paste_masks]) + + # Copy-paste boxes and labels + boxes = ops.masks_to_boxes(masks) + out_target["boxes"] = torch.cat([boxes, paste_boxes]) + + labels = target["labels"][non_all_zero_masks] + out_target["labels"] = torch.cat([labels, paste_labels]) + + # Update additional optional keys: area and iscrowd if exist + if "area" in target: + out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32) + + if "iscrowd" in target and "iscrowd" in paste_target: + # target['iscrowd'] size can be differ from mask size (non_all_zero_masks) + # For example, if previous transforms geometrically modifies masks/boxes/labels but + # does not update "iscrowd" + if len(target["iscrowd"]) == len(non_all_zero_masks): + iscrowd = target["iscrowd"][non_all_zero_masks] + paste_iscrowd = paste_target["iscrowd"][random_selection] + out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd]) + + # Check for degenerated boxes and remove them + boxes = out_target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + valid_targets = ~degenerate_boxes.any(dim=1) + + out_target["boxes"] = boxes[valid_targets] + out_target["masks"] = out_target["masks"][valid_targets] + out_target["labels"] = out_target["labels"][valid_targets] + + if "area" in out_target: + out_target["area"] = out_target["area"][valid_targets] + if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets): + out_target["iscrowd"] = out_target["iscrowd"][valid_targets] + + return image, out_target + + +class SimpleCopyPaste(torch.nn.Module): + def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR): + super().__init__() + self.resize_interpolation = resize_interpolation + self.blending = blending + + def forward( + self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]] + ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]: + torch._assert( + isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]), + "images should be a list of tensors", + ) + torch._assert( + isinstance(targets, (list, tuple)) and len(images) == len(targets), + "targets should be a list of the same size as images", + ) + for target in targets: + # Can not check for instance type dict with inside torch.jit.script + # torch._assert(isinstance(target, dict), "targets item should be a dict") + for k in ["masks", "boxes", "labels"]: + torch._assert(k in target, f"Key {k} should be present in targets") + torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor") + + # images = [t1, t2, ..., tN] + # Let's define paste_images as shifted list of input images + # paste_images = [t2, t3, ..., tN, t1] + # FYI: in TF they mix data on the dataset level + images_rolled = images[-1:] + images[:-1] + targets_rolled = targets[-1:] + targets[:-1] + + output_images: List[torch.Tensor] = [] + output_targets: List[Dict[str, Tensor]] = [] + + for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled): + output_image, output_data = _copy_paste( + image, + target, + paste_image, + paste_target, + blending=self.blending, + resize_interpolation=self.resize_interpolation, + ) + output_images.append(output_image) + output_targets.append(output_data) + + return output_images, output_targets + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(blending={self.blending}, 
resize_interpolation={self.resize_interpolation})" + return s diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 46b2609c1a6..9a8f375d28c 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -1,23 +1,28 @@ +import random + +import numpy as np import torch -from torchvision.prototype import features, transforms -from torchvision.prototype.transforms import functional as F +from torchvision import transforms as T +from torchvision.prototype import features, transforms as PT +from torchvision.prototype.transforms import functional as PF +from torchvision.transforms import functional as F -class WrapIntoFeatures(transforms.Transform): +class WrapIntoFeatures(PT.Transform): def forward(self, sample): image, mask = sample - return F.to_image_tensor(image), features.Mask(F.pil_to_tensor(mask).squeeze(0), dtype=torch.int64) + return PF.to_image_tensor(image), features.Mask(PF.pil_to_tensor(mask).squeeze(0), dtype=torch.int64) -class PadIfSmaller(transforms.Transform): +class PadIfSmaller(PT.Transform): def __init__(self, size, fill=0): super().__init__() self.size = size - self.fill = transforms._geometry._setup_fill_arg(fill) + self.fill = PT._geometry._setup_fill_arg(fill) def _get_params(self, sample): - _, height, width = transforms._utils.query_chw(sample) + _, height, width = PT._utils.query_chw(sample) padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] needs_padding = any(padding) return dict(padding=padding, needs_padding=needs_padding) @@ -27,6 +32,103 @@ def _transform(self, inpt, params): return inpt fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) + fill = PF._geometry._convert_fill_arg(fill) + + return PF.pad(inpt, padding=params["padding"], fill=fill) + + +# Original Transforms can be removed: + + +def pad_if_smaller(img, size, fill=0): + min_size = min(img.size) + if min_size < size: + ow, oh = img.size + padh = size - oh if oh < size else 0 + padw = size - ow if ow < size else 0 + img = F.pad(img, (0, 0, padw, padh), fill=fill) + return img + + +class Compose: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + +class RandomResize: + def __init__(self, min_size, max_size=None): + self.min_size = min_size + if max_size is None: + max_size = min_size + self.max_size = max_size + + def __call__(self, image, target): + size = random.randint(self.min_size, self.max_size) + image = F.resize(image, size) + target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST) + return image, target + + +class RandomHorizontalFlip: + def __init__(self, flip_prob): + self.flip_prob = flip_prob + + def __call__(self, image, target): + if random.random() < self.flip_prob: + image = F.hflip(image) + target = F.hflip(target) + return image, target + + +class RandomCrop: + def __init__(self, size): + self.size = size + + def __call__(self, image, target): + image = pad_if_smaller(image, self.size) + target = pad_if_smaller(target, self.size, fill=255) + crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) + image = F.crop(image, *crop_params) + target = F.crop(target, *crop_params) + return image, target + + +class CenterCrop: + def __init__(self, size): + self.size = size + + def __call__(self, image, target): + image = F.center_crop(image, self.size) + target = F.center_crop(target, self.size) + 
return image, target + + +class PILToTensor: + def __call__(self, image, target): + image = F.pil_to_tensor(image) + target = torch.as_tensor(np.array(target), dtype=torch.int64) + return image, target + + +class ConvertImageDtype: + def __init__(self, dtype): + self.dtype = dtype + + def __call__(self, image, target): + image = F.convert_image_dtype(image, self.dtype) + return image, target + + +class Normalize: + def __init__(self, mean, std): + self.mean = mean + self.std = std - return F.pad(inpt, padding=params["padding"], fill=fill) + def __call__(self, image, target): + image = F.normalize(image, mean=self.mean, std=self.std) + return image, target From 591a773ac0adb6cff2a2fbc970f10134cad675df Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 11 Oct 2022 19:08:32 +0100 Subject: [PATCH 30/49] Adding AA, Random Erase, MixUp/CutMix and a different resize/crop strategy. --- references/classification/train.py | 12 ++--- references/video_classification/presets.py | 44 +++++++++++++--- references/video_classification/train.py | 51 +++++++++++++++++-- references/video_classification/transforms.py | 8 +++ 4 files changed, 96 insertions(+), 19 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 43c33df7db5..6fcfc8c175a 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -131,19 +131,15 @@ def load_data(traindir, valdir, args): print(f"Loading dataset_train from {cache_path}") dataset, _ = torch.load(cache_path) else: - auto_augment_policy = getattr(args, "auto_augment", None) - random_erase_prob = getattr(args, "random_erase", 0.0) - ra_magnitude = args.ra_magnitude - augmix_severity = args.augmix_severity dataset = torchvision.datasets.ImageFolder( traindir, presets.ClassificationPresetTrain( crop_size=train_crop_size, interpolation=interpolation, - auto_augment_policy=auto_augment_policy, - random_erase_prob=random_erase_prob, - ra_magnitude=ra_magnitude, - augmix_severity=augmix_severity, + auto_augment_policy=args.auto_augment, + random_erase_prob=args.random_erase, + ra_magnitude=args.ra_magnitude, + augmix_severity=args.augmix_severity, ), target_transform=lambda target: features.Label(target), ) diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index cac7a1f62e1..94e4cf20e38 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -11,21 +11,41 @@ def __init__( resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989), + interpolation=transforms.InterpolationMode.BILINEAR, hflip_prob=0.5, + auto_augment_policy=None, + ra_magnitude=9, + augmix_severity=3, + random_erase_prob=0.0, ): trans = [ - transforms.ConvertImageDtype(torch.float32), - transforms.Resize(resize_size), + transforms.RandomShortestSize( + min_size=resize_size[0], max_size=resize_size[1], interpolation=interpolation, antialias=True + ), + transforms.RandomCrop(crop_size), ] if hflip_prob > 0: - trans.append(transforms.RandomHorizontalFlip(hflip_prob)) + trans.append(transforms.RandomHorizontalFlip(p=hflip_prob)) + if auto_augment_policy is not None: + if auto_augment_policy == "ra": + trans.append(transforms.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) + elif auto_augment_policy == "ta_wide": + trans.append(transforms.TrivialAugmentWide(interpolation=interpolation)) + elif auto_augment_policy == "augmix": + trans.append(transforms.AugMix(interpolation=interpolation, 
severity=augmix_severity)) + else: + aa_policy = transforms.AutoAugmentPolicy(auto_augment_policy) + trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.extend( [ + transforms.ConvertImageDtype(torch.float32), transforms.Normalize(mean=mean, std=std), - transforms.RandomCrop(crop_size), - ConvertBCHWtoCBHW(), ] ) + if random_erase_prob > 0: + trans.append(transforms.RandomErasing(p=random_erase_prob)) + trans.append(ConvertBCHWtoCBHW()) + self.transforms = transforms.Compose(trans) def __call__(self, x): @@ -33,13 +53,21 @@ def __call__(self, x): class VideoClassificationPresetEval: - def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): + def __init__( + self, + *, + crop_size, + resize_size, + mean=(0.43216, 0.394666, 0.37645), + std=(0.22803, 0.22145, 0.216989), + interpolation=transforms.InterpolationMode.BILINEAR, + ): self.transforms = transforms.Compose( [ + transforms.Resize(resize_size, interpolation=interpolation, antialias=True), + transforms.CenterCrop(crop_size), transforms.ConvertImageDtype(torch.float32), - transforms.Resize(resize_size), transforms.Normalize(mean=mean, std=std), - transforms.CenterCrop(crop_size), ConvertBCHWtoCBHW(), ] ) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index e26231bb914..404c766e0fe 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -13,6 +13,9 @@ from torch import nn from torch.utils.data.dataloader import default_collate from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler +from torchvision.prototype import features, transforms +from torchvision.transforms.functional import InterpolationMode +from transforms import WrapIntoFeatures def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): @@ -153,6 +156,7 @@ def main(args): val_crop_size = tuple(args.val_crop_size) train_resize_size = tuple(args.train_resize_size) train_crop_size = tuple(args.train_crop_size) + interpolation = InterpolationMode(args.interpolation) traindir = os.path.join(args.data_path, "train") valdir = os.path.join(args.data_path, "val") @@ -160,7 +164,15 @@ def main(args): print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir, args) - transform_train = presets.VideoClassificationPresetTrain(crop_size=train_crop_size, resize_size=train_resize_size) + transform_train = presets.VideoClassificationPresetTrain( + crop_size=train_crop_size, + resize_size=train_resize_size, + interpolation=interpolation, + auto_augment_policy=args.auto_augment, + random_erase_prob=args.random_erase, + ra_magnitude=args.ra_magnitude, + augmix_severity=args.augmix_severity, + ) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_train from {cache_path}") @@ -197,7 +209,11 @@ def main(args): weights = torchvision.models.get_weight(args.weights) transform_test = weights.transforms() else: - transform_test = presets.VideoClassificationPresetEval(crop_size=val_crop_size, resize_size=val_resize_size) + transform_test = presets.VideoClassificationPresetEval( + crop_size=val_crop_size, + resize_size=val_resize_size, + interpolation=interpolation, + ) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_test from {cache_path}") @@ -232,13 +248,33 @@ def main(args): train_sampler = DistributedSampler(train_sampler) 
test_sampler = DistributedSampler(test_sampler, shuffle=False) + train_collate_fn = collate_fn + num_classes = len(dataset.classes) + mixup_or_cutmix = [] + if args.mixup_alpha > 0.0: + mixup_or_cutmix.append(transforms.RandomMixup(alpha=args.mixup_alpha, p=1.0)) + if args.cutmix_alpha > 0.0: + mixup_or_cutmix.append(transforms.RandomCutmix(alpha=args.cutmix_alpha, p=1.0)) + if mixup_or_cutmix: + batch_transform = transforms.Compose( + [ + WrapIntoFeatures(), + transforms.LabelToOneHot(num_categories=num_classes), + transforms.ToDtype(torch.float, features.OneHotLabel), + transforms.RandomChoice(mixup_or_cutmix), + ] + ) + + def train_collate_fn(batch): + return batch_transform(*default_collate(batch)) + data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True, - collate_fn=collate_fn, + collate_fn=train_collate_fn, ) data_loader_test = torch.utils.data.DataLoader( @@ -396,6 +432,12 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) + parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)") + parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)") + parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") + parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy") + parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy") + parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") parser.add_argument( "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." 
) @@ -404,6 +446,9 @@ def get_args_parser(add_help=True): parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") + parser.add_argument( + "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" + ) parser.add_argument( "--val-resize-size", default=(128, 171), diff --git a/references/video_classification/transforms.py b/references/video_classification/transforms.py index 2a7cc2a4a66..f651b19997a 100644 --- a/references/video_classification/transforms.py +++ b/references/video_classification/transforms.py @@ -1,6 +1,14 @@ import torch import torch.nn as nn +from torchvision.prototype import features + + +class WrapIntoFeatures(torch.nn.Module): + def forward(self, sample): + video, _, target, id = sample + return features.Video(video), features.Label(target), id + class ConvertBCHWtoCBHW(nn.Module): """Convert tensor from (B, C, H, W) to (C, B, H, W)""" From 9e95b7835a4a57935658a642507a3deeb31c4aa5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 13 Oct 2022 22:13:01 +0100 Subject: [PATCH 31/49] image_size to spatial_size --- references/detection/coco_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index b392a38c7c6..a8c36cc49b2 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -92,7 +92,7 @@ def __call__(self, sample): boxes=features.BoundingBox( target["boxes"], format=features.BoundingBoxFormat.XYXY, - image_size=(image.height, image.width), + spatial_size=(image.height, image.width), ), # TODO: add categories labels=features.Label(target["labels"], categories=None), From 00d1b9bb92262ea9e0859d8358a2f4fa5357e128 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 12:59:31 +0100 Subject: [PATCH 32/49] Update the RandomShortestSize behaviour on Video presets. --- references/video_classification/presets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index 94e4cf20e38..7d6db49ab5c 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -20,7 +20,7 @@ def __init__( ): trans = [ transforms.RandomShortestSize( - min_size=resize_size[0], max_size=resize_size[1], interpolation=interpolation, antialias=True + min_size=list(range(resize_size[0], resize_size[1] + 1)), interpolation=interpolation, antialias=True ), transforms.RandomCrop(crop_size), ] From ef3dc554747f2229b34184d84f2dc475b0d324fc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 13:05:38 +0100 Subject: [PATCH 33/49] Fix ToDtype transform to accept dictionaries. 
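Usage sketch (illustration only, not part of the diff below; it simply mirrors the two call sites changed in this patch): the transform now takes a mapping from feature type to target dtype rather than a single dtype plus a type filter.

import torch
from torchvision.prototype import features, transforms

# Cast one-hot labels to float32 so mixup/cutmix can blend them.
cast_labels = transforms.ToDtype({features.OneHotLabel: torch.float})
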
--- references/classification/train.py | 2 +- references/video_classification/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 6fcfc8c175a..ff374a42275 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -219,7 +219,7 @@ def main(args): [ WrapIntoFeatures(), transforms.LabelToOneHot(num_categories=num_classes), - transforms.ToDtype(torch.float, features.OneHotLabel), + transforms.ToDtype({features.OneHotLabel: torch.float}), transforms.RandomChoice(mixup_or_cutmix), ] ) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 404c766e0fe..6b59a05971d 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -260,7 +260,7 @@ def main(args): [ WrapIntoFeatures(), transforms.LabelToOneHot(num_categories=num_classes), - transforms.ToDtype(torch.float, features.OneHotLabel), + transforms.ToDtype({features.OneHotLabel: torch.float}), transforms.RandomChoice(mixup_or_cutmix), ] ) From 25c4664abf3a01c673780d41224f75bf344b7b82 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 14:04:31 +0100 Subject: [PATCH 34/49] Fix issue with collate and audio using Philip's proposal. --- references/video_classification/datasets.py | 6 +++--- references/video_classification/train.py | 13 +++---------- references/video_classification/transforms.py | 2 +- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/references/video_classification/datasets.py b/references/video_classification/datasets.py index c1c68431518..81af001ae2c 100644 --- a/references/video_classification/datasets.py +++ b/references/video_classification/datasets.py @@ -6,8 +6,8 @@ class KineticsWithVideoId(datasets.Kinetics): - def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, features.Label, int]: - video, audio, info, video_idx = self.video_clips.get_clip(idx) + def __getitem__(self, idx): + video, _, info, video_idx = self.video_clips.get_clip(idx) label = self.samples[video_idx][1] video = features.Video(video) @@ -16,4 +16,4 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, features.La if self.transform is not None: video = self.transform(video) - return video, audio, label, video_idx + return video, label, video_idx diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 6b59a05971d..a403e9ef571 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -129,12 +129,6 @@ def _get_cache_path(filepath, args): return cache_path -def collate_fn(batch): - # remove audio from the batch - batch = [(d[0], d[2], d[3]) for d in batch] - return default_collate(batch) - - def main(args): if args.output_dir: utils.mkdir(args.output_dir) @@ -248,7 +242,7 @@ def main(args): train_sampler = DistributedSampler(train_sampler) test_sampler = DistributedSampler(test_sampler, shuffle=False) - train_collate_fn = collate_fn + collate_fn = None num_classes = len(dataset.classes) mixup_or_cutmix = [] if args.mixup_alpha > 0.0: @@ -265,7 +259,7 @@ def main(args): ] ) - def train_collate_fn(batch): + def collate_fn(batch): return batch_transform(*default_collate(batch)) data_loader = torch.utils.data.DataLoader( @@ -274,7 +268,7 @@ def train_collate_fn(batch): sampler=train_sampler, num_workers=args.workers, pin_memory=True, - collate_fn=train_collate_fn, + 
collate_fn=collate_fn, ) data_loader_test = torch.utils.data.DataLoader( @@ -283,7 +277,6 @@ def train_collate_fn(batch): sampler=test_sampler, num_workers=args.workers, pin_memory=True, - collate_fn=collate_fn, ) print("Creating model") diff --git a/references/video_classification/transforms.py b/references/video_classification/transforms.py index f651b19997a..8f4ba9f273d 100644 --- a/references/video_classification/transforms.py +++ b/references/video_classification/transforms.py @@ -6,7 +6,7 @@ class WrapIntoFeatures(torch.nn.Module): def forward(self, sample): - video, _, target, id = sample + video, target, id = sample return features.Video(video), features.Label(target), id From 091948e0fefe17fe5c52feb5d854a8828a588908 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 14:07:51 +0100 Subject: [PATCH 35/49] Fix linter --- references/video_classification/datasets.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/references/video_classification/datasets.py b/references/video_classification/datasets.py index 81af001ae2c..1e506db3e2a 100644 --- a/references/video_classification/datasets.py +++ b/references/video_classification/datasets.py @@ -1,6 +1,3 @@ -from typing import Tuple - -import torch from torchvision import datasets from torchvision.prototype import features From 6b2358734fb42e24608cc966b76c8c9193f91939 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 15:32:19 +0100 Subject: [PATCH 36/49] Fix ToDtype parameters. --- references/classification/train.py | 2 +- references/video_classification/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index ff374a42275..8a7e6e44cf7 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -219,7 +219,7 @@ def main(args): [ WrapIntoFeatures(), transforms.LabelToOneHot(num_categories=num_classes), - transforms.ToDtype({features.OneHotLabel: torch.float}), + transforms.ToDtype({features.OneHotLabel: torch.float, features.Image: None}), transforms.RandomChoice(mixup_or_cutmix), ] ) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index a403e9ef571..b603020ecad 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -254,7 +254,7 @@ def main(args): [ WrapIntoFeatures(), transforms.LabelToOneHot(num_categories=num_classes), - transforms.ToDtype({features.OneHotLabel: torch.float}), + transforms.ToDtype({features.OneHotLabel: torch.float, features.Video: None}), transforms.RandomChoice(mixup_or_cutmix), ] ) From eb37f8f1ac5593a5733029fb9f53bcb5cc473aa4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 15:42:10 +0100 Subject: [PATCH 37/49] Wrapping id into a no-op. 
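Sketch of the intent (not part of the diff; the `_Feature` behaviour described here is inferred from the surrounding patches rather than from documented API): the clip id is wrapped into the generic feature type so that the tensor-only batch transforms can be told to pass it through untouched.

import torch
from torchvision.prototype import features

ids = torch.tensor([3, 1, 4])      # batched clip ids coming out of default_collate
ids = features._Feature(ids)       # tagged as a plain feature; the next patch maps
                                   # features._Feature to None in ToDtype so it is left as-is
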
--- references/video_classification/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/video_classification/transforms.py b/references/video_classification/transforms.py index 8f4ba9f273d..e443ea3cd98 100644 --- a/references/video_classification/transforms.py +++ b/references/video_classification/transforms.py @@ -7,7 +7,7 @@ class WrapIntoFeatures(torch.nn.Module): def forward(self, sample): video, target, id = sample - return features.Video(video), features.Label(target), id + return features.Video(video), features.Label(target), features._Feature(id) class ConvertBCHWtoCBHW(nn.Module): From bb468ba82ca1503c9ba704ae76fe0948e6058b77 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 15:47:49 +0100 Subject: [PATCH 38/49] Define `_Feature` in the dict. --- references/video_classification/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index b603020ecad..77a0c0927a9 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -254,7 +254,7 @@ def main(args): [ WrapIntoFeatures(), transforms.LabelToOneHot(num_categories=num_classes), - transforms.ToDtype({features.OneHotLabel: torch.float, features.Video: None}), + transforms.ToDtype({features.OneHotLabel: torch.float, features.Video: None, features._Feature: None}), transforms.RandomChoice(mixup_or_cutmix), ] ) From 59288764ef757ae5d736cb61f872367a6f53ce54 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 16:20:38 +0100 Subject: [PATCH 39/49] Handling hot-encoded tensors in `accuracy` --- references/video_classification/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/references/video_classification/utils.py b/references/video_classification/utils.py index 934f62f66ae..2d576066f2c 100644 --- a/references/video_classification/utils.py +++ b/references/video_classification/utils.py @@ -161,6 +161,8 @@ def accuracy(output, target, topk=(1,)): with torch.inference_mode(): maxk = max(topk) batch_size = target.size(0) + if target.ndim == 2: + target = target.max(dim=1)[1] _, pred = output.topk(maxk, 1, True, True) pred = pred.t() From b63e607651e51f8b6764b6eca7906f76922a4c26 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 17:10:47 +0100 Subject: [PATCH 40/49] Handle ConvertBCHWtoCBHW interactions with mixup/cutmix. 
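Background for this and the previous patch (sketch only, made-up numbers): once mixup/cutmix is enabled the targets reach `accuracy` as 2-D soft/one-hot tensors, so the argmax added above recovers hard labels before the top-k comparison.

import torch

target = torch.tensor([[0.7, 0.3, 0.0],
                       [0.0, 0.1, 0.9]])   # (B, num_classes) soft targets after mixup
if target.ndim == 2:
    target = target.max(dim=1)[1]          # -> tensor([0, 2]), usable by topk-based accuracy
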
--- references/video_classification/train.py | 3 ++- references/video_classification/transforms.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 77a0c0927a9..173bd48878b 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -15,7 +15,7 @@ from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler from torchvision.prototype import features, transforms from torchvision.transforms.functional import InterpolationMode -from transforms import WrapIntoFeatures +from transforms import ConvertBCHWtoCBHW, WrapIntoFeatures def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): @@ -256,6 +256,7 @@ def main(args): transforms.LabelToOneHot(num_categories=num_classes), transforms.ToDtype({features.OneHotLabel: torch.float, features.Video: None, features._Feature: None}), transforms.RandomChoice(mixup_or_cutmix), + ConvertBCHWtoCBHW(), ] ) diff --git a/references/video_classification/transforms.py b/references/video_classification/transforms.py index e443ea3cd98..44c088828e6 100644 --- a/references/video_classification/transforms.py +++ b/references/video_classification/transforms.py @@ -7,11 +7,13 @@ class WrapIntoFeatures(torch.nn.Module): def forward(self, sample): video, target, id = sample + video = video.transpose(-4, -3) # convert back to (B, C, H, W) return features.Video(video), features.Label(target), features._Feature(id) class ConvertBCHWtoCBHW(nn.Module): """Convert tensor from (B, C, H, W) to (C, B, H, W)""" - def forward(self, vid: torch.Tensor) -> torch.Tensor: - return vid.permute(1, 0, 2, 3) + def forward(self, *inputs): + inputs[0].transpose_(-4, -3) + return inputs From d5f153218b4ec548fce174b4803c6686c565d772 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 20:06:03 +0100 Subject: [PATCH 41/49] Add Permute Transform. 
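Shape bookkeeping behind the previous patch and the `Permute` transform added below (sketch only; the sizes are made up). The presets emit clips as (C, T, H, W), so the collated batch is (N, C, T, H, W); the wrapper transposes time and channels so the batch again ends in (C, H, W) for mixup/cutmix, and the trailing transform flips it back for the model.

import torch

batch = torch.rand(4, 3, 16, 112, 112)   # (N, C, T, H, W): collated clips from the presets
batch = batch.transpose(-4, -3)          # (N, T, C, H, W): what WrapIntoFeatures produces
batch = batch.transpose(-4, -3)          # (N, C, T, H, W): restored before the model
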
--- references/video_classification/presets.py | 7 +++-- references/video_classification/train.py | 6 +++-- references/video_classification/transforms.py | 11 ++++---- torchvision/prototype/transforms/__init__.py | 11 +++++++- torchvision/prototype/transforms/_misc.py | 27 +++++++++++++++++-- 5 files changed, 47 insertions(+), 15 deletions(-) diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index 7d6db49ab5c..b6db5fab151 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -1,6 +1,5 @@ import torch -from torchvision.prototype import transforms -from transforms import ConvertBCHWtoCBHW +from torchvision.prototype import features, transforms class VideoClassificationPresetTrain: @@ -44,7 +43,7 @@ def __init__( ) if random_erase_prob > 0: trans.append(transforms.RandomErasing(p=random_erase_prob)) - trans.append(ConvertBCHWtoCBHW()) + trans.append(transforms.Permute({torch.Tensor: (1, 0, 2, 3), features.Label: None})) self.transforms = transforms.Compose(trans) @@ -68,7 +67,7 @@ def __init__( transforms.CenterCrop(crop_size), transforms.ConvertImageDtype(torch.float32), transforms.Normalize(mean=mean, std=std), - ConvertBCHWtoCBHW(), + transforms.Permute((1, 0, 2, 3)), ] ) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 173bd48878b..50228924bd3 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -15,7 +15,7 @@ from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler from torchvision.prototype import features, transforms from torchvision.transforms.functional import InterpolationMode -from transforms import ConvertBCHWtoCBHW, WrapIntoFeatures +from transforms import WrapIntoFeatures def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): @@ -256,7 +256,9 @@ def main(args): transforms.LabelToOneHot(num_categories=num_classes), transforms.ToDtype({features.OneHotLabel: torch.float, features.Video: None, features._Feature: None}), transforms.RandomChoice(mixup_or_cutmix), - ConvertBCHWtoCBHW(), + transforms.Permute( + {features.Video: (0, 2, 1, 3, 4), features.OneHotLabel: None, features._Feature: None} + ), ] ) diff --git a/references/video_classification/transforms.py b/references/video_classification/transforms.py index 44c088828e6..9b922db5ecb 100644 --- a/references/video_classification/transforms.py +++ b/references/video_classification/transforms.py @@ -6,14 +6,13 @@ class WrapIntoFeatures(torch.nn.Module): def forward(self, sample): - video, target, id = sample - video = video.transpose(-4, -3) # convert back to (B, C, H, W) - return features.Video(video), features.Label(target), features._Feature(id) + video_cthw, target, id = sample + video_tchw = video_cthw.transpose(-4, -3) + return features.Video(video_tchw), features.Label(target), features._Feature(id) class ConvertBCHWtoCBHW(nn.Module): """Convert tensor from (B, C, H, W) to (C, B, H, W)""" - def forward(self, *inputs): - inputs[0].transpose_(-4, -3) - return inputs + def forward(self, vid: torch.Tensor) -> torch.Tensor: + return vid.permute(1, 0, 2, 3) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 5324db63496..4ff1e93ced8 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -40,7 +40,16 @@ 
TenCrop, ) from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype -from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype +from ._misc import ( + GaussianBlur, + Identity, + Lambda, + LinearTransformation, + Normalize, + Permute, + RemoveSmallBoundingBoxes, + ToDtype, +) from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index b31c688dc30..b23dd2dd87a 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,6 +1,6 @@ import functools from collections import defaultdict -from typing import Any, Callable, Dict, List, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union import PIL.Image @@ -148,7 +148,7 @@ class ToDtype(Transform): def _default_dtype(self, dtype: torch.dtype) -> torch.dtype: return dtype - def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None: + def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None: super().__init__() if not isinstance(dtype, dict): # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. @@ -163,6 +163,29 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt.to(dtype=dtype) +class Permute(Transform): + _transformed_types = (torch.Tensor,) + + def _default_dims(self, dims: Sequence[int]) -> Sequence[int]: + return dims + + def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None: + super().__init__() + if not isinstance(dims, dict): + dims = defaultdict(functools.partial(self._default_dims, dims)) + self.dims = dims + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + dims = self.dims[type(inpt)] + if dims is None: + return inpt + output = inpt.permute(dims) + if isinstance(inpt, features._Feature): + # TODO: handle properly the colour space if Image/Video + output = inpt.wrap_like(inpt, output) + return output + + class RemoveSmallBoundingBoxes(Transform): _transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel) From 6a0a32c6b2cddbe4f1f1e7b283aec19f7ac6fb00 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 21 Oct 2022 12:52:50 +0100 Subject: [PATCH 42/49] Switch to `TransposeDimensions` --- references/video_classification/presets.py | 4 ++-- references/video_classification/train.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index b6db5fab151..accc80da4fb 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -43,7 +43,7 @@ def __init__( ) if random_erase_prob > 0: trans.append(transforms.RandomErasing(p=random_erase_prob)) - trans.append(transforms.Permute({torch.Tensor: (1, 0, 2, 3), features.Label: None})) + trans.append(transforms.TransposeDimensions((-3, -4))) self.transforms = transforms.Compose(trans) @@ -67,7 +67,7 @@ def __init__( transforms.CenterCrop(crop_size), transforms.ConvertImageDtype(torch.float32), transforms.Normalize(mean=mean, std=std), - transforms.Permute((1, 0, 2, 3)), + transforms.TransposeDimensions((-3, -4)), ] ) 
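Quick equivalence check for reviewers (sketch; the motivation stated here is my reading of the change, not taken from the patch): on a 4-D clip the negative-dim transpose matches the old fixed permutation, and unlike `permute(1, 0, 2, 3)` it also applies to the 5-D batched case.

import torch

clip = torch.rand(16, 3, 112, 112)              # (T, C, H, W)
assert torch.equal(clip.permute(1, 0, 2, 3), clip.transpose(-3, -4))

batch = torch.rand(4, 16, 3, 112, 112)          # (N, T, C, H, W)
_ = batch.transpose(-3, -4)                     # works; a fixed 4-D permutation would not
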
diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 50228924bd3..0126791159d 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -256,9 +256,7 @@ def main(args): transforms.LabelToOneHot(num_categories=num_classes), transforms.ToDtype({features.OneHotLabel: torch.float, features.Video: None, features._Feature: None}), transforms.RandomChoice(mixup_or_cutmix), - transforms.Permute( - {features.Video: (0, 2, 1, 3, 4), features.OneHotLabel: None, features._Feature: None} - ), + transforms.TransposeDimensions((-3, -4)), ] ) From 9d0a0a3cff1c50e45e5746a494b8d6170cc46433 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 27 Oct 2022 13:56:02 +0100 Subject: [PATCH 43/49] Fix linter. --- references/video_classification/presets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index accc80da4fb..d45c90a60be 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -1,5 +1,5 @@ import torch -from torchvision.prototype import features, transforms +from torchvision.prototype import transforms class VideoClassificationPresetTrain: From d8b520288601549ae2a95cf072a86c43b941ae7b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 4 Nov 2022 16:38:19 +0000 Subject: [PATCH 44/49] Fix method location. --- references/segmentation/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 9a8f375d28c..2f023aa1e04 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -32,7 +32,7 @@ def _transform(self, inpt, params): return inpt fill = self.fill[type(inpt)] - fill = PF._geometry._convert_fill_arg(fill) + fill = PF._utils._convert_fill_arg(fill) return PF.pad(inpt, padding=params["padding"], fill=fill) From 959af2d2286538215abba2bcf7cdbaf5b4c311c8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 7 Nov 2022 11:02:58 +0000 Subject: [PATCH 45/49] Fixing minor bug --- references/segmentation/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 2f023aa1e04..c0c24e5babb 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -32,7 +32,7 @@ def _transform(self, inpt, params): return inpt fill = self.fill[type(inpt)] - fill = PF._utils._convert_fill_arg(fill) + fill = PT._utils._convert_fill_arg(fill) return PF.pad(inpt, padding=params["padding"], fill=fill) From 8f07159b8f1b6a6e961f0613431b7adcfafbee82 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Nov 2022 10:39:35 +0000 Subject: [PATCH 46/49] Convert to floats at the beginning. 
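What the reordering below actually changes (sketch only): `ConvertImageDtype` rescales uint8 images into float tensors in [0, 1], so moving it to the front means the crop/resize/augment ops run on float data instead of uint8. (This is reverted again in the next patch.)

import torch
from torchvision.transforms.functional import convert_image_dtype

img_u8 = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
img_f32 = convert_image_dtype(img_u8, torch.float)   # float32, values scaled to [0.0, 1.0]
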
--- references/classification/presets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 0eb82a18fc8..04bacdcdf2a 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -19,6 +19,7 @@ def __init__( ): trans = [ transforms.ToImageTensor(), + transforms.ConvertImageDtype(torch.float), transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True), ] if hflip_prob > 0: @@ -35,7 +36,6 @@ def __init__( trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.extend( [ - transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] ) @@ -62,9 +62,9 @@ def __init__( self.transforms = transforms.Compose( [ transforms.ToImageTensor(), + transforms.ConvertImageDtype(torch.float), transforms.Resize(resize_size, interpolation=interpolation, antialias=True), transforms.CenterCrop(crop_size), - transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] ) From 8344ce9a1837da26f1fa1542bd7b18c407a0b515 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Nov 2022 10:41:40 +0000 Subject: [PATCH 47/49] Revert "Convert to floats at the beginning." This reverts commit 8f07159b8f1b6a6e961f0613431b7adcfafbee82. --- references/classification/presets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 04bacdcdf2a..0eb82a18fc8 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -19,7 +19,6 @@ def __init__( ): trans = [ transforms.ToImageTensor(), - transforms.ConvertImageDtype(torch.float), transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True), ] if hflip_prob > 0: @@ -36,6 +35,7 @@ def __init__( trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.extend( [ + transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] ) @@ -62,9 +62,9 @@ def __init__( self.transforms = transforms.Compose( [ transforms.ToImageTensor(), - transforms.ConvertImageDtype(torch.float), transforms.Resize(resize_size, interpolation=interpolation, antialias=True), transforms.CenterCrop(crop_size), + transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] ) From 8b530360cd5bfc82e66d7c12035c630592d3444f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Nov 2022 12:09:12 +0000 Subject: [PATCH 48/49] Switch to PIL backend --- references/classification/presets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 0eb82a18fc8..106c2d72572 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -18,7 +18,6 @@ def __init__( random_erase_prob=0.0, ): trans = [ - transforms.ToImageTensor(), transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True), ] if hflip_prob > 0: @@ -35,6 +34,7 @@ def __init__( trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.extend( [ + transforms.PILToTensor(), transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] @@ -61,9 +61,9 @@ def __init__( self.transforms = transforms.Compose( [ - transforms.ToImageTensor(), transforms.Resize(resize_size, 
interpolation=interpolation, antialias=True), transforms.CenterCrop(crop_size), + transforms.PILToTensor(), transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] From c7f2ac830905f65273d7b60cef22f538ddc34101 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Nov 2022 12:21:16 +0000 Subject: [PATCH 49/49] Revert "Switch to PIL backend" This reverts commit 8b530360cd5bfc82e66d7c12035c630592d3444f. --- references/classification/presets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 106c2d72572..0eb82a18fc8 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -18,6 +18,7 @@ def __init__( random_erase_prob=0.0, ): trans = [ + transforms.ToImageTensor(), transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True), ] if hflip_prob > 0: @@ -34,7 +35,6 @@ def __init__( trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.extend( [ - transforms.PILToTensor(), transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] @@ -61,9 +61,9 @@ def __init__( self.transforms = transforms.Compose( [ + transforms.ToImageTensor(), transforms.Resize(resize_size, interpolation=interpolation, antialias=True), transforms.CenterCrop(crop_size), - transforms.PILToTensor(), transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ]
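
Net result of the last two reverts (sketch only; the sizes and normalization statistics below are placeholders, not values taken from this patch): the eval preset stays on the tensor backend and converts to float only after resize and crop.

import torch
from torchvision.prototype import transforms

preprocess = transforms.Compose(
    [
        transforms.ToImageTensor(),
        transforms.Resize(256, antialias=True),
        transforms.CenterCrop(224),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)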