diff --git a/references/classification/presets.py b/references/classification/presets.py
index 5d1bf1cc714..0eb82a18fc8 100644
--- a/references/classification/presets.py
+++ b/references/classification/presets.py
@@ -1,5 +1,5 @@
 import torch
-from torchvision.transforms import autoaugment, transforms
+from torchvision.prototype import transforms
 from torchvision.transforms.functional import InterpolationMode
@@ -17,22 +17,24 @@ def __init__(
         augmix_severity=3,
         random_erase_prob=0.0,
     ):
-        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+        trans = [
+            transforms.ToImageTensor(),
+            transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True),
+        ]
         if hflip_prob > 0:
-            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+            trans.append(transforms.RandomHorizontalFlip(p=hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+                trans.append(transforms.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
-                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
+                trans.append(transforms.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
+                trans.append(transforms.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
-                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = transforms.AutoAugmentPolicy(auto_augment_policy)
+                trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation))
 
         trans.extend(
             [
-                transforms.PILToTensor(),
                 transforms.ConvertImageDtype(torch.float),
                 transforms.Normalize(mean=mean, std=std),
             ]
@@ -59,9 +61,9 @@ def __init__(
         self.transforms = transforms.Compose(
             [
-                transforms.Resize(resize_size, interpolation=interpolation),
+                transforms.ToImageTensor(),
+                transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
                 transforms.CenterCrop(crop_size),
-                transforms.PILToTensor(),
                 transforms.ConvertImageDtype(torch.float),
                 transforms.Normalize(mean=mean, std=std),
             ]
diff --git a/references/classification/train.py b/references/classification/train.py
index 10ba22bce03..c32c1d28c5d 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -4,14 +4,17 @@
 import warnings
 
 import presets
+from sampler import RASampler
+from transforms import WrapIntoFeatures
+import utils  # usort: skip
+
 import torch
 import torch.utils.data
 import torchvision
-import transforms
-import utils
-from sampler import RASampler
+
 from torch import nn
 from torch.utils.data.dataloader import default_collate
+from torchvision.prototype import features, transforms
 from torchvision.transforms.functional import InterpolationMode
@@ -144,6 +147,7 @@ def load_data(traindir, valdir, args):
                 ra_magnitude=ra_magnitude,
                 augmix_severity=augmix_severity,
             ),
+            target_transform=lambda target: features.Label(target),
         )
         if args.cache_dataset:
             print(f"Saving dataset_train to {cache_path}")
@@ -168,7 +172,8 @@ def load_data(traindir, valdir, args):
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            preprocessing,
+            transform=preprocessing,
+            target_transform=lambda target: features.Label(target),
         )
         if args.cache_dataset:
             print(f"Saving dataset_test to {cache_path}")
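Note on the classification presets above: the pipeline now starts from a plain PIL image, converts it to an image tensor first, and only then applies the spatial and color augmentations, while the integer class index is wrapped into a `features.Label` via `target_transform`. A minimal usage sketch, assuming a torchvision build that ships the `torchvision.prototype` namespace exactly as imported in this diff; the crop size and normalization statistics are illustrative ImageNet values, and the dataset path is a placeholder:

```python
import torch
import torchvision
from torchvision.prototype import features, transforms

train_preset = transforms.Compose(
    [
        transforms.ToImageTensor(),  # PIL image -> image tensor feature
        transforms.RandomResizedCrop(176, antialias=True),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

# mirrors the ImageFolder call in the hunk above
dataset = torchvision.datasets.ImageFolder(
    "path/to/train",
    transform=train_preset,
    target_transform=lambda target: features.Label(target),
)
```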
@@ -210,16 +215,23 @@ def main(args):
     collate_fn = None
     num_classes = len(dataset.classes)
-    mixup_transforms = []
+    mixup_or_cutmix = []
     if args.mixup_alpha > 0.0:
-        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+        mixup_or_cutmix.append(transforms.RandomMixup(alpha=args.mixup_alpha, p=1.0))
     if args.cutmix_alpha > 0.0:
-        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
-    if mixup_transforms:
-        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
+        mixup_or_cutmix.append(transforms.RandomCutmix(alpha=args.cutmix_alpha, p=1.0))
+    if mixup_or_cutmix:
+        batch_transform = transforms.Compose(
+            [
+                WrapIntoFeatures(),
+                transforms.LabelToOneHot(num_categories=num_classes),
+                transforms.ToDtype({features.OneHotLabel: torch.float, features.Image: None}),
+                transforms.RandomChoice(mixup_or_cutmix),
+            ]
+        )
 
         def collate_fn(batch):
-            return mixupcutmix(*default_collate(batch))
+            return batch_transform(*default_collate(batch))
 
     data_loader = torch.utils.data.DataLoader(
         dataset,
diff --git a/references/classification/transforms.py b/references/classification/transforms.py
index 9a8ef7877d6..7a665057d8f 100644
--- a/references/classification/transforms.py
+++ b/references/classification/transforms.py
@@ -3,9 +3,20 @@
 import torch
 from torch import Tensor
+from torchvision.prototype import features
+from torchvision.prototype.transforms import functional as PF
 from torchvision.transforms import functional as F
 
 
+class WrapIntoFeatures(torch.nn.Module):
+    def forward(self, sample):
+        image, target = sample
+        return PF.to_image_tensor(image), features.Label(target)
+
+
+# Original Transforms can be removed:
+
+
 class RandomMixup(torch.nn.Module):
     """Randomly apply Mixup to the provided batch and targets.
     The class implements the data augmentations as described in the paper
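The batch-level mixup/cutmix path above changes shape compared to the old reference: `RandomMixup`/`RandomCutmix` no longer receive `num_classes`; instead `WrapIntoFeatures` marks the collated images and labels as prototype features, `LabelToOneHot` expands the labels, and `ToDtype` casts the one-hot targets to float so they can be mixed. A sketch of the resulting collate function, under the assumption (matching its use above) that the prototype `Compose` accepts the two positional tensors produced by `default_collate`; the class count and alpha are hypothetical:

```python
import torch
from torch.utils.data.dataloader import default_collate
from torchvision.prototype import features, transforms
from transforms import WrapIntoFeatures  # the classification helper added in this diff

num_classes = 1000  # hypothetical
batch_transform = transforms.Compose(
    [
        WrapIntoFeatures(),
        transforms.LabelToOneHot(num_categories=num_classes),
        transforms.ToDtype({features.OneHotLabel: torch.float, features.Image: None}),
        transforms.RandomChoice([transforms.RandomMixup(alpha=0.2, p=1.0)]),
    ]
)

def collate_fn(batch):
    # default_collate stacks the samples first; the transform then mixes the whole batch
    return batch_transform(*default_collate(batch))
```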
diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py
index 38c8279c35e..e71462fcf88 100644
--- a/references/detection/coco_utils.py
+++ b/references/detection/coco_utils.py
@@ -1,30 +1,13 @@
-import copy
 import os
 
 import torch
 import torch.utils.data
 import torchvision
-import transforms as T
+
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
-
-
-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target
+from torchvision.prototype import features, transforms as T
+from torchvision.prototype.transforms import functional as F
 
 
 def convert_coco_poly_to_mask(segmentations, height, width):
@@ -45,7 +28,8 @@ def convert_coco_poly_to_mask(segmentations, height, width):
 
 
 class ConvertCocoPolysToMask:
-    def __call__(self, image, target):
+    def __call__(self, sample):
+        image, target = sample
         w, h = image.size
 
         image_id = target["image_id"]
@@ -100,6 +84,27 @@ def __call__(self, image, target):
         return image, target
 
 
+class WrapIntoFeatures:
+    def __call__(self, sample):
+        image, target = sample
+
+        wrapped_target = dict(
+            boxes=features.BoundingBox(
+                target["boxes"],
+                format=features.BoundingBoxFormat.XYXY,
+                spatial_size=(image.height, image.width),
+            ),
+            # TODO: add categories
+            labels=features.Label(target["labels"], categories=None),
+            masks=features.Mask(target["masks"]),
+            image_id=int(target["image_id"]),
+            area=target["area"].tolist(),
+            iscrowd=target["iscrowd"].bool().tolist(),
+        )
+
+        return F.to_image_tensor(image), wrapped_target
+
+
 def _coco_remove_images_without_annotations(dataset, cat_list=None):
     def _has_only_empty_bbox(anno):
         return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
@@ -225,10 +230,12 @@ def get_coco(root, image_set, transforms, mode="instances"):
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
         "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
-        # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
     }
 
-    t = [ConvertCocoPolysToMask()]
+    t = [
+        ConvertCocoPolysToMask(),
+        WrapIntoFeatures(),
+    ]
 
     if transforms is not None:
         t.append(transforms)
@@ -243,8 +250,6 @@ def get_coco(root, image_set, transforms, mode="instances"):
     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset)
 
-    # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
-
     return dataset
diff --git a/references/detection/engine.py b/references/detection/engine.py
index 0e5d55f189d..0e9bfffdf8a 100644
--- a/references/detection/engine.py
+++ b/references/detection/engine.py
@@ -26,7 +26,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             loss_dict = model(images, targets)
             losses = sum(loss for loss in loss_dict.values())
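The detection `WrapIntoFeatures` above deliberately turns `image_id`, `area`, and `iscrowd` into plain Python values, which is why `engine.py` now guards the `.to(device)` call with an `isinstance` check and keys `res` on `target["image_id"]` directly instead of calling `.item()`. A small illustration of a target in that format, with dummy values and assuming the prototype feature classes imported in `coco_utils.py`:

```python
import torch
from torchvision.prototype import features

target = dict(
    boxes=features.BoundingBox(
        torch.tensor([[10.0, 20.0, 110.0, 220.0]]),
        format=features.BoundingBoxFormat.XYXY,
        spatial_size=(480, 640),
    ),
    labels=features.Label(torch.tensor([3]), categories=None),
    masks=features.Mask(torch.zeros(1, 480, 640, dtype=torch.uint8)),
    image_id=42,        # plain int, no longer a tensor
    area=[22000.0],     # plain list
    iscrowd=[False],    # plain list
)

device = "cuda" if torch.cuda.is_available() else "cpu"
# only tensor-valued entries are moved, matching the guard in train_one_epoch
target = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in target.items()}
```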
@@ -97,7 +97,7 @@ def evaluate(model, data_loader, device):
         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
         model_time = time.time() - model_time
 
-        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
+        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
         evaluator_time = time.time()
         coco_evaluator.update(res)
         evaluator_time = time.time() - evaluator_time
diff --git a/references/detection/presets.py b/references/detection/presets.py
index 779f3f218ca..1e89815f32e 100644
--- a/references/detection/presets.py
+++ b/references/detection/presets.py
@@ -1,73 +1,56 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+from torchvision.prototype import features, transforms as T
 
 
-class DetectionPresetTrain:
+class DetectionPresetTrain(T.Compose):
     def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
         if data_augmentation == "hflip":
-            self.transforms = T.Compose(
-                [
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms = [
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ConvertImageDtype(torch.float),
+            ]
         elif data_augmentation == "lsj":
-            self.transforms = T.Compose(
-                [
-                    T.ScaleJitter(target_size=(1024, 1024)),
-                    T.FixedSizeCrop(size=(1024, 1024), fill=mean),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms = [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                T.FixedSizeCrop(size=(1024, 1024), fill=defaultdict(lambda: mean, {features.Mask: 0})),
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ConvertImageDtype(torch.float),
+            ]
         elif data_augmentation == "multiscale":
-            self.transforms = T.Compose(
-                [
-                    T.RandomShortestSize(
-                        min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
-                    ),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms = [
+                T.RandomShortestSize(
+                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
+                ),
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ConvertImageDtype(torch.float),
+            ]
         elif data_augmentation == "ssd":
-            self.transforms = T.Compose(
-                [
-                    T.RandomPhotometricDistort(),
-                    T.RandomZoomOut(fill=list(mean)),
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms = [
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=defaultdict(lambda: mean, {features.Mask: 0})),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ConvertImageDtype(torch.float),
+            ]
         elif data_augmentation == "ssdlite":
-            self.transforms = T.Compose(
-                [
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms = [
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ConvertImageDtype(torch.float),
+            ]
         else:
             raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
 
-    def __call__(self, img, target):
-        return self.transforms(img, target)
+        super().__init__(transforms)
 
 
-class DetectionPresetEval:
+class DetectionPresetEval(T.Compose):
     def __init__(self):
-        self.transforms = T.Compose(
+        super().__init__(
             [
-                T.PILToTensor(),
+                T.ToImageTensor(),
                 T.ConvertImageDtype(torch.float),
             ]
         )
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
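On the `fill=defaultdict(...)` arguments above: the prototype transforms accept a per-type fill mapping, so geometric ops that pad can fill image pixels with the dataset mean while keeping segmentation masks at 0. A sketch of the "lsj"/"ssd" fill under that assumption:

```python
from collections import defaultdict
from torchvision.prototype import features, transforms as T

mean = (123.0, 117.0, 104.0)
fill = defaultdict(lambda: mean, {features.Mask: 0})  # any other type -> mean, masks -> 0
crop = T.FixedSizeCrop(size=(1024, 1024), fill=fill)
zoom_out = T.RandomZoomOut(fill=fill)  # the "ssd" branch uses the same pattern
```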
diff --git a/references/detection/train.py b/references/detection/train.py
index dea483c5f75..6146da1adec 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -31,12 +31,12 @@
 from coco_utils import get_coco, get_coco_kp
 from engine import evaluate, train_one_epoch
 from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
+from torchvision.prototype import transforms as T
 from torchvision.transforms import InterpolationMode
-from transforms import SimpleCopyPaste
 
 
 def copypaste_collate_fn(batch):
-    copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
+    copypaste = T.SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR, antialias=True)
     return copypaste(*utils.collate_fn(batch))
diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index d26bf6eac85..080dc5a5b1c 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -1,3 +1,4 @@
+# Original Transforms can be removed:
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
diff --git a/references/segmentation/coco_utils.py b/references/segmentation/coco_utils.py
index e02434012f1..863790d9c91 100644
--- a/references/segmentation/coco_utils.py
+++ b/references/segmentation/coco_utils.py
@@ -6,7 +6,7 @@
 import torchvision
 from PIL import Image
 from pycocotools import mask as coco_mask
-from transforms import Compose
+from torchvision.prototype.transforms import Compose
 
 
 class FilterAndRemapCocoCategories:
@@ -14,7 +14,9 @@ def __init__(self, categories, remap=True):
         self.categories = categories
         self.remap = remap
 
-    def __call__(self, image, anno):
+    def __call__(self, sample):
+        image, anno = sample
+
         anno = [obj for obj in anno if obj["category_id"] in self.categories]
         if not self.remap:
             return image, anno
@@ -42,7 +44,9 @@ def convert_coco_poly_to_mask(segmentations, height, width):
 
 
 class ConvertCocoPolysToMask:
-    def __call__(self, image, anno):
+    def __call__(self, sample):
+        image, anno = sample
+
         w, h = image.size
         segmentations = [obj["segmentation"] for obj in anno]
         cats = [obj["category_id"] for obj in anno]
@@ -90,7 +94,6 @@ def get_coco(root, image_set, transforms):
     PATHS = {
         "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
         "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
-        # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
     }
     CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]
diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py
index ed02ae660e4..84da2229731 100644
--- a/references/segmentation/presets.py
+++ b/references/segmentation/presets.py
@@ -1,39 +1,38 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+from torchvision.prototype import features, transforms as T
+from transforms import PadIfSmaller, WrapIntoFeatures
 
 
-class SegmentationPresetTrain:
+class SegmentationPresetTrain(T.Compose):
     def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        min_size = int(0.5 * base_size)
-        max_size = int(2.0 * base_size)
-
-        trans = [T.RandomResize(min_size, max_size)]
+        transforms = [
+            WrapIntoFeatures(),
+            T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size), antialias=True),
+        ]
         if hflip_prob > 0:
-            trans.append(T.RandomHorizontalFlip(hflip_prob))
-        trans.extend(
+            transforms.append(T.RandomHorizontalFlip(hflip_prob))
+        transforms.extend(
             [
+                # We need a custom pad transform here, since the padding we want to perform here is fundamentally
+                # different from the padding in `RandomCrop` if `pad_if_needed=True`.
+                PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {features.Mask: 255})),
                 T.RandomCrop(crop_size),
-                T.PILToTensor(),
                 T.ConvertImageDtype(torch.float),
                 T.Normalize(mean=mean, std=std),
             ]
         )
-        self.transforms = T.Compose(trans)
+        super().__init__(transforms)
 
-    def __call__(self, img, target):
-        return self.transforms(img, target)
-
 
-class SegmentationPresetEval:
+class SegmentationPresetEval(T.Compose):
     def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        self.transforms = T.Compose(
+        super().__init__(
             [
-                T.RandomResize(base_size, base_size),
-                T.PILToTensor(),
+                WrapIntoFeatures(),
+                T.Resize(base_size, antialias=True),
                 T.ConvertImageDtype(torch.float),
                 T.Normalize(mean=mean, std=std),
             ]
         )
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 1aa72a9fe38..1a89d573ec8 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -11,15 +11,19 @@
 from coco_utils import get_coco
 from torch import nn
 from torch.optim.lr_scheduler import PolynomialLR
-from torchvision.transforms import functional as F, InterpolationMode
+from torchvision.prototype.transforms import Compose, functional as F, InterpolationMode
+from transforms import WrapIntoFeatures
 
 
 def get_dataset(dir_path, name, image_set, transform):
-    def sbd(*args, **kwargs):
-        return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
+    def voc(*args, transforms, **kwargs):
+        return torchvision.datasets.VOCSegmentation(*args, transforms=Compose([transforms]), **kwargs)
+
+    def sbd(*args, transforms, **kwargs):
+        return torchvision.datasets.SBDataset(*args, mode="segmentation", transforms=Compose([transforms]), **kwargs)
 
     paths = {
-        "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
+        "voc": (dir_path, voc, 21),
         "voc_aug": (dir_path, sbd, 21),
         "coco": (dir_path, get_coco, 21),
     }
@@ -35,12 +39,14 @@ def get_transform(train, args):
     elif args.weights and args.test_only:
         weights = torchvision.models.get_weight(args.weights)
         trans = weights.transforms()
+        wrap = WrapIntoFeatures()
 
-        def preprocessing(img, target):
+        def preprocessing(sample):
+            img, target = sample
             img = trans(img)
             size = F.get_dimensions(img)[1:]
             target = F.resize(target, size, interpolation=InterpolationMode.NEAREST)
-            return img, F.pil_to_tensor(target)
+            return wrap((img, target))
 
         return preprocessing
     else:
@@ -134,8 +140,8 @@ def main(args):
     else:
         torch.backends.cudnn.benchmark = True
 
-    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
-    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
+    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True, args=args))
+    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False, args=args))
 
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
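The extra `Compose([transforms])` wrapper in `get_dataset` bridges two calling conventions: `VOCSegmentation`/`SBDataset` invoke their joint transform as `transforms(image, target)`, i.e. with two positional arguments, while the new presets start with `WrapIntoFeatures`, which expects a single `(image, target)` sample. The sketch below assumes, consistent with its use here, that the prototype `Compose` packs multiple positional inputs into one sample before dispatching; the dataset root is a placeholder:

```python
import torchvision
from torchvision.prototype.transforms import Compose
from presets import SegmentationPresetTrain  # the preset rewritten above

preset = SegmentationPresetTrain(base_size=520, crop_size=480)
dataset = torchvision.datasets.VOCSegmentation("path/to/VOC", transforms=Compose([preset]))
```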
diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py
index 518048db2fa..c0c24e5babb 100644
--- a/references/segmentation/transforms.py
+++ b/references/segmentation/transforms.py
@@ -2,10 +2,44 @@
 import numpy as np
 import torch
+
 from torchvision import transforms as T
+from torchvision.prototype import features, transforms as PT
+from torchvision.prototype.transforms import functional as PF
 from torchvision.transforms import functional as F
 
 
+class WrapIntoFeatures(PT.Transform):
+    def forward(self, sample):
+        image, mask = sample
+        return PF.to_image_tensor(image), features.Mask(PF.pil_to_tensor(mask).squeeze(0), dtype=torch.int64)
+
+
+class PadIfSmaller(PT.Transform):
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = size
+        self.fill = PT._geometry._setup_fill_arg(fill)
+
+    def _get_params(self, sample):
+        _, height, width = PT._utils.query_chw(sample)
+        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
+        needs_padding = any(padding)
+        return dict(padding=padding, needs_padding=needs_padding)
+
+    def _transform(self, inpt, params):
+        if not params["needs_padding"]:
+            return inpt
+
+        fill = self.fill[type(inpt)]
+        fill = PT._utils._convert_fill_arg(fill)
+
+        return PF.pad(inpt, padding=params["padding"], fill=fill)
+
+
+# Original Transforms can be removed:
+
+
 def pad_if_smaller(img, size, fill=0):
     min_size = min(img.size)
     if min_size < size:
diff --git a/references/video_classification/datasets.py b/references/video_classification/datasets.py
index dec1e16b856..1e506db3e2a 100644
--- a/references/video_classification/datasets.py
+++ b/references/video_classification/datasets.py
@@ -1,15 +1,16 @@
-from typing import Tuple
+from torchvision import datasets
+from torchvision.prototype import features
 
-import torchvision
-from torch import Tensor
-
 
-class KineticsWithVideoId(torchvision.datasets.Kinetics):
-    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
-        video, audio, info, video_idx = self.video_clips.get_clip(idx)
+class KineticsWithVideoId(datasets.Kinetics):
+    def __getitem__(self, idx):
+        video, _, info, video_idx = self.video_clips.get_clip(idx)
         label = self.samples[video_idx][1]
 
+        video = features.Video(video)
+        label = features.Label(label, categories=self.classes)
+
         if self.transform is not None:
             video = self.transform(video)
 
-        return video, audio, label, video_idx
+        return video, label, video_idx
diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py
index ef774052257..d45c90a60be 100644
--- a/references/video_classification/presets.py
+++ b/references/video_classification/presets.py
@@ -1,6 +1,5 @@
 import torch
-from torchvision.transforms import transforms
-from transforms import ConvertBCHWtoCBHW
+from torchvision.prototype import transforms
 
 
 class VideoClassificationPresetTrain:
@@ -11,15 +10,41 @@ def __init__(
         self,
         crop_size,
         resize_size,
         mean=(0.43216, 0.394666, 0.37645),
         std=(0.22803, 0.22145, 0.216989),
+        interpolation=transforms.InterpolationMode.BILINEAR,
         hflip_prob=0.5,
+        auto_augment_policy=None,
+        ra_magnitude=9,
+        augmix_severity=3,
+        random_erase_prob=0.0,
     ):
         trans = [
-            transforms.ConvertImageDtype(torch.float32),
-            transforms.Resize(resize_size),
+            transforms.RandomShortestSize(
+                min_size=list(range(resize_size[0], resize_size[1] + 1)), interpolation=interpolation, antialias=True
+            ),
+            transforms.RandomCrop(crop_size),
         ]
         if hflip_prob > 0:
-            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
-        trans.extend([transforms.Normalize(mean=mean, std=std), transforms.RandomCrop(crop_size), ConvertBCHWtoCBHW()])
+            trans.append(transforms.RandomHorizontalFlip(p=hflip_prob))
+        if auto_augment_policy is not None:
+            if auto_augment_policy == "ra":
+                trans.append(transforms.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+            elif auto_augment_policy == "ta_wide":
+                trans.append(transforms.TrivialAugmentWide(interpolation=interpolation))
+            elif auto_augment_policy == "augmix":
+                trans.append(transforms.AugMix(interpolation=interpolation, severity=augmix_severity))
+            else:
+                aa_policy = transforms.AutoAugmentPolicy(auto_augment_policy)
+                trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation))
+        trans.extend(
+            [
+                transforms.ConvertImageDtype(torch.float32),
+                transforms.Normalize(mean=mean, std=std),
+            ]
+        )
+        if random_erase_prob > 0:
+            trans.append(transforms.RandomErasing(p=random_erase_prob))
+        trans.append(transforms.TransposeDimensions((-3, -4)))
+
         self.transforms = transforms.Compose(trans)
 
     def __call__(self, x):
@@ -27,14 +52,22 @@
 class VideoClassificationPresetEval:
-    def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
+    def __init__(
+        self,
+        *,
+        crop_size,
+        resize_size,
+        mean=(0.43216, 0.394666, 0.37645),
+        std=(0.22803, 0.22145, 0.216989),
+        interpolation=transforms.InterpolationMode.BILINEAR,
+    ):
         self.transforms = transforms.Compose(
             [
+                transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
+                transforms.CenterCrop(crop_size),
                 transforms.ConvertImageDtype(torch.float32),
-                transforms.Resize(resize_size),
                 transforms.Normalize(mean=mean, std=std),
-                transforms.CenterCrop(crop_size),
-                ConvertBCHWtoCBHW(),
+                transforms.TransposeDimensions((-3, -4)),
             ]
         )
diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index e26231bb914..0126791159d 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -13,6 +13,9 @@
 from torch import nn
 from torch.utils.data.dataloader import default_collate
 from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler
+from torchvision.prototype import features, transforms
+from torchvision.transforms.functional import InterpolationMode
+from transforms import WrapIntoFeatures
 
 
 def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
@@ -126,12 +129,6 @@ def _get_cache_path(filepath, args):
     return cache_path
 
 
-def collate_fn(batch):
-    # remove audio from the batch
-    batch = [(d[0], d[2], d[3]) for d in batch]
-    return default_collate(batch)
-
-
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -153,6 +150,7 @@ def main(args):
     val_crop_size = tuple(args.val_crop_size)
     train_resize_size = tuple(args.train_resize_size)
     train_crop_size = tuple(args.train_crop_size)
+    interpolation = InterpolationMode(args.interpolation)
 
     traindir = os.path.join(args.data_path, "train")
     valdir = os.path.join(args.data_path, "val")
@@ -160,7 +158,15 @@ def main(args):
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir, args)
-    transform_train = presets.VideoClassificationPresetTrain(crop_size=train_crop_size, resize_size=train_resize_size)
+    transform_train = presets.VideoClassificationPresetTrain(
+        crop_size=train_crop_size,
+        resize_size=train_resize_size,
+        interpolation=interpolation,
+        auto_augment_policy=args.auto_augment,
+        random_erase_prob=args.random_erase,
+        ra_magnitude=args.ra_magnitude,
+        augmix_severity=args.augmix_severity,
+    )
 
     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_train from {cache_path}")
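With the new keyword arguments, the video train preset can be constructed directly as in the call above. A sketch with values mirroring the reference defaults for Kinetics-400 (the policy choices are illustrative):

```python
from torchvision.transforms.functional import InterpolationMode
from presets import VideoClassificationPresetTrain  # rewritten above

transform_train = VideoClassificationPresetTrain(
    crop_size=(112, 112),
    resize_size=(128, 171),
    interpolation=InterpolationMode.BILINEAR,
    auto_augment_policy="ta_wide",  # or "ra", "augmix", or an AutoAugment policy name
    ra_magnitude=9,
    augmix_severity=3,
    random_erase_prob=0.25,
)
```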
@@ -197,7 +203,11 @@ def main(args):
         weights = torchvision.models.get_weight(args.weights)
         transform_test = weights.transforms()
     else:
-        transform_test = presets.VideoClassificationPresetEval(crop_size=val_crop_size, resize_size=val_resize_size)
+        transform_test = presets.VideoClassificationPresetEval(
+            crop_size=val_crop_size,
+            resize_size=val_resize_size,
+            interpolation=interpolation,
+        )
 
     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_test from {cache_path}")
@@ -232,6 +242,27 @@ def main(args):
         train_sampler = DistributedSampler(train_sampler)
         test_sampler = DistributedSampler(test_sampler, shuffle=False)
 
+    collate_fn = None
+    num_classes = len(dataset.classes)
+    mixup_or_cutmix = []
+    if args.mixup_alpha > 0.0:
+        mixup_or_cutmix.append(transforms.RandomMixup(alpha=args.mixup_alpha, p=1.0))
+    if args.cutmix_alpha > 0.0:
+        mixup_or_cutmix.append(transforms.RandomCutmix(alpha=args.cutmix_alpha, p=1.0))
+    if mixup_or_cutmix:
+        batch_transform = transforms.Compose(
+            [
+                WrapIntoFeatures(),
+                transforms.LabelToOneHot(num_categories=num_classes),
+                transforms.ToDtype({features.OneHotLabel: torch.float, features.Video: None, features._Feature: None}),
+                transforms.RandomChoice(mixup_or_cutmix),
+                transforms.TransposeDimensions((-3, -4)),
+            ]
+        )
+
+        def collate_fn(batch):
+            return batch_transform(*default_collate(batch))
+
     data_loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=args.batch_size,
@@ -247,7 +278,6 @@ def main(args):
         sampler=test_sampler,
         num_workers=args.workers,
         pin_memory=True,
-        collate_fn=collate_fn,
     )
 
     print("Creating model")
@@ -396,6 +426,12 @@ def get_args_parser(add_help=True):
         help="Only test the model",
         action="store_true",
     )
+    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
+    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
+    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
+    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
+    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
+    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
     parser.add_argument(
         "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
     )
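The video collate path above has to juggle layouts: the train preset ends with `TransposeDimensions((-3, -4))`, which swaps the time and channel dimensions into the layout video models expect, so the `WrapIntoFeatures` helper in `transforms.py` (further down in this diff) transposes the collated batch back before mixup/cutmix runs, and the trailing `TransposeDimensions` restores the model layout. Illustrative shapes only:

```python
import torch

batch = torch.rand(8, 3, 16, 112, 112)   # (B, C, T, H, W) as collated from the dataset
batch_tchw = batch.transpose(-4, -3)     # what WrapIntoFeatures does before mixup/cutmix
assert batch_tchw.shape == (8, 16, 3, 112, 112)
restored = batch_tchw.transpose(-4, -3)  # effect of the final TransposeDimensions((-3, -4))
assert restored.shape == batch.shape
```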
@@ -404,6 +440,9 @@ def get_args_parser(add_help=True):
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
+    parser.add_argument(
+        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
+    )
     parser.add_argument(
         "--val-resize-size",
         default=(128, 171),
diff --git a/references/video_classification/transforms.py b/references/video_classification/transforms.py
index 2a7cc2a4a66..9b922db5ecb 100644
--- a/references/video_classification/transforms.py
+++ b/references/video_classification/transforms.py
@@ -1,6 +1,15 @@
 import torch
 import torch.nn as nn
+from torchvision.prototype import features
+
+
+class WrapIntoFeatures(torch.nn.Module):
+    def forward(self, sample):
+        video_cthw, target, id = sample
+        video_tchw = video_cthw.transpose(-4, -3)
+        return features.Video(video_tchw), features.Label(target), features._Feature(id)
+
 
 class ConvertBCHWtoCBHW(nn.Module):
     """Convert tensor from (B, C, H, W) to (C, B, H, W)"""
diff --git a/references/video_classification/utils.py b/references/video_classification/utils.py
index 934f62f66ae..2d576066f2c 100644
--- a/references/video_classification/utils.py
+++ b/references/video_classification/utils.py
@@ -161,6 +161,8 @@ def accuracy(output, target, topk=(1,)):
     with torch.inference_mode():
         maxk = max(topk)
         batch_size = target.size(0)
+        if target.ndim == 2:
+            target = target.max(dim=1)[1]
 
         _, pred = output.topk(maxk, 1, True, True)
         pred = pred.t()
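Finally, the `accuracy()` guard added above exists because mixup/cutmix turn the targets into float one-hot (or soft) vectors; reducing them back to class indices keeps the top-k computation unchanged. Sketch with illustrative shapes:

```python
import torch

output = torch.rand(8, 400)  # model logits
target = torch.nn.functional.one_hot(torch.randint(0, 400, (8,)), num_classes=400).float()
if target.ndim == 2:
    target = target.max(dim=1)[1]  # recover hard labels, as in utils.accuracy
```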