diff --git a/references/classification/presets.py b/references/classification/presets.py
index 5d1bf1cc714..1db2e37edcc 100644
--- a/references/classification/presets.py
+++ b/references/classification/presets.py
@@ -1,6 +1,5 @@
 import torch
-from torchvision.transforms import autoaugment, transforms
-from torchvision.transforms.functional import InterpolationMode
+import torchvision.transforms.v2 as transforms
 
 
 class ClassificationPresetTrain:
@@ -10,29 +9,44 @@ def __init__(
         crop_size,
         mean=(0.485, 0.456, 0.406),
         std=(0.229, 0.224, 0.225),
-        interpolation=InterpolationMode.BILINEAR,
+        interpolation=transforms.InterpolationMode.BILINEAR,
         hflip_prob=0.5,
         auto_augment_policy=None,
         ra_magnitude=9,
         augmix_severity=3,
         random_erase_prob=0.0,
+        backend="pil",
     ):
-        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+        trans = []
+
+        backend = backend.lower()
+        if backend == "datapoint":
+            trans.append(transforms.ToImageTensor())
+        elif backend == "tensor":
+            trans.append(transforms.PILToTensor())
+        else:
+            assert backend == "pil"
+
+        trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
         if hflip_prob > 0:
-            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+            trans.append(transforms.RandomHorizontalFlip(p=hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+                trans.append(transforms.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
-                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
+                trans.append(transforms.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
+                trans.append(transforms.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
-                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = transforms.AutoAugmentPolicy(auto_augment_policy)
+                trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation))
+
+        if backend == "pil":
+            # Note: we could also just use pure tensors?
+            trans.append(transforms.ToImageTensor())
+
         trans.extend(
             [
-                transforms.PILToTensor(),
                 transforms.ConvertImageDtype(torch.float),
                 transforms.Normalize(mean=mean, std=std),
             ]
@@ -54,18 +68,33 @@ def __init__(
         resize_size=256,
         mean=(0.485, 0.456, 0.406),
         std=(0.229, 0.224, 0.225),
-        interpolation=InterpolationMode.BILINEAR,
+        interpolation=transforms.InterpolationMode.BILINEAR,
+        backend="pil",
     ):
+        trans = []
 
-        self.transforms = transforms.Compose(
-            [
-                transforms.Resize(resize_size, interpolation=interpolation),
-                transforms.CenterCrop(crop_size),
-                transforms.PILToTensor(),
-                transforms.ConvertImageDtype(torch.float),
-                transforms.Normalize(mean=mean, std=std),
-            ]
-        )
+        backend = backend.lower()
+        if backend == "datapoint":
+            trans.append(transforms.ToImageTensor())
+        elif backend == "tensor":
+            trans.append(transforms.PILToTensor())
+        else:
+            assert backend == "pil"
+
+        trans += [
+            transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
+            transforms.CenterCrop(crop_size),
+        ]
+
+        if backend == "pil":
+            trans.append(transforms.ToImageTensor())
+
+        trans += [
+            transforms.ConvertImageDtype(torch.float),
+            transforms.Normalize(mean=mean, std=std),
+        ]
+
+        self.transforms = transforms.Compose(trans)
 
     def __call__(self, img):
         return self.transforms(img)
diff --git a/references/classification/train.py b/references/classification/train.py
index 10ba22bce03..d8f9441fdcf 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -4,15 +4,17 @@
 import warnings
 
 import presets
+from sampler import RASampler
+import utils  # usort: skip
+
 import torch
 import torch.utils.data
 import torchvision
+
 import transforms
-import utils
-from sampler import RASampler
 from torch import nn
 from torch.utils.data.dataloader import default_collate
-from torchvision.transforms.functional import InterpolationMode
+from torchvision.transforms.v2 import InterpolationMode
 
 
 def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
@@ -143,6 +145,7 @@ def load_data(traindir, valdir, args):
                 random_erase_prob=random_erase_prob,
                 ra_magnitude=ra_magnitude,
                 augmix_severity=augmix_severity,
+                backend=args.backend,
             ),
         )
     if args.cache_dataset:
@@ -163,12 +166,15 @@ def load_data(traindir, valdir, args):
             preprocessing = weights.transforms()
         else:
             preprocessing = presets.ClassificationPresetEval(
-                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
+                crop_size=val_crop_size,
+                resize_size=val_resize_size,
+                interpolation=interpolation,
+                backend=args.backend,
             )
 
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            preprocessing,
+            transform=preprocessing,
         )
         if args.cache_dataset:
             print(f"Saving dataset_test to {cache_path}")
@@ -507,6 +513,7 @@ def get_args_parser(add_help=True):
         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+    parser.add_argument("--backend", default="PIL", type=str, help="PIL, tensor or datapoint - case insensitive")
 
     return parser
diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py
index 38c8279c35e..5c57536d576 100644
--- a/references/detection/coco_utils.py
+++ b/references/detection/coco_utils.py
@@ -1,30 +1,12 @@
-import copy
 import os
 
 import torch
 import torch.utils.data
 import torchvision
-import transforms as T
+
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
-
-
-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target
+from torchvision.datasets import wrap_dataset_for_transforms_v2
 
 
 def convert_coco_poly_to_mask(segmentations, height, width):
@@ -44,8 +26,10 @@ def convert_coco_poly_to_mask(segmentations, height, width):
     return masks
 
 
+# TODO: Is this still needed?
 class ConvertCocoPolysToMask:
-    def __call__(self, image, target):
+    def __call__(self, sample):
+        image, target = sample
         w, h = image.size
 
         image_id = target["image_id"]
@@ -126,10 +110,10 @@ def _has_valid_annotation(anno):
             return True
         return False
 
-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
+    # if not isinstance(dataset, torchvision.datasets.CocoDetection):
+    #     raise TypeError(
+    #         f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
+    #     )
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -201,23 +185,15 @@ def get_coco_api_from_dataset(dataset):
             break
         if isinstance(dataset, torch.utils.data.Subset):
             dataset = dataset.dataset
-    if isinstance(dataset, torchvision.datasets.CocoDetection):
+    # TODO: the v2 wrapper hides the original CocoDetection behind `_dataset`; find a cleaner check
+    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
+        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
+    ):
         return dataset.coco
     return convert_to_coco_api(dataset)
 
 
-class CocoDetection(torchvision.datasets.CocoDetection):
-    def __init__(self, img_folder, ann_file, transforms):
-        super().__init__(img_folder, ann_file)
-        self._transforms = transforms
-
-    def __getitem__(self, idx):
-        img, target = super().__getitem__(idx)
-        image_id = self.ids[idx]
-        target = dict(image_id=image_id, annotations=target)
-        if self._transforms is not None:
-            img, target = self._transforms(img, target)
-        return img, target
+# TODO: Maybe not critical but the wrapper doesn't work on sub-classes
 
 
 def get_coco(root, image_set, transforms, mode="instances"):
@@ -225,26 +201,19 @@ def get_coco(root, image_set, transforms, mode="instances"):
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
         "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
-        # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
     }
 
-    t = [ConvertCocoPolysToMask()]
-
-    if transforms is not None:
-        t.append(transforms)
-    transforms = T.Compose(t)
-
     img_folder, ann_file = PATHS[image_set]
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)
 
-    dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
+    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+    # TODO: need to update target_keys to handle masks for segmentation!
+    dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
 
     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset)
 
-        # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
-
     return dataset
diff --git a/references/detection/engine.py b/references/detection/engine.py
index 0e5d55f189d..a2fb82fc2a7 100644
--- a/references/detection/engine.py
+++ b/references/detection/engine.py
@@ -26,7 +26,9 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+        for t in targets:
+            assert t["labels"].shape[0] == t["boxes"].shape[0], f"{t['labels'].shape} {t['boxes'].shape}"
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             loss_dict = model(images, targets)
             losses = sum(loss for loss in loss_dict.values())
@@ -97,7 +99,7 @@ def evaluate(model, data_loader, device):
         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
         model_time = time.time() - model_time
 
-        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
+        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
         evaluator_time = time.time()
         coco_evaluator.update(res)
         evaluator_time = time.time() - evaluator_time
diff --git a/references/detection/group_by_aspect_ratio.py b/references/detection/group_by_aspect_ratio.py
index d12e14b540c..f55a715358f 100644
--- a/references/detection/group_by_aspect_ratio.py
+++ b/references/detection/group_by_aspect_ratio.py
@@ -164,7 +164,10 @@ def compute_aspect_ratios(dataset, indices=None):
     if hasattr(dataset, "get_height_and_width"):
         return _compute_aspect_ratios_custom_dataset(dataset, indices)
 
-    if isinstance(dataset, torchvision.datasets.CocoDetection):
+    # TODO: the v2 wrapper hides the original CocoDetection behind `_dataset`; find a cleaner check
+    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
+        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
+    ):
         return _compute_aspect_ratios_coco_dataset(dataset, indices)
 
     if isinstance(dataset, torchvision.datasets.VOCDetection):
diff --git a/references/detection/presets.py b/references/detection/presets.py
index 779f3f218ca..01713db294c 100644
--- a/references/detection/presets.py
+++ b/references/detection/presets.py
@@ -1,73 +1,86 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+import torchvision
+import transforms as reference_transforms
+
+torchvision.disable_beta_transforms_warning()
+import torchvision.transforms.v2 as T
+from torchvision import datapoints
+
+
+# TODO: Should we provide a transform that filters out keys?
+
+class DetectionPresetTrain(T.Compose):
+    def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0), backend="PIL"):
+        transforms = []
+
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        else:
+            assert backend == "pil"
 
-class DetectionPresetTrain:
-    def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
         if data_augmentation == "hflip":
-            self.transforms = T.Compose(
-                [
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "lsj":
-            self.transforms = T.Compose(
-                [
-                    T.ScaleJitter(target_size=(1024, 1024)),
-                    T.FixedSizeCrop(size=(1024, 1024), fill=mean),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                reference_transforms.FixedSizeCrop(
+                    size=(1024, 1024), fill=defaultdict(lambda: mean, {datapoints.Mask: 0})
+                ),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "multiscale":
-            self.transforms = T.Compose(
-                [
-                    T.RandomShortestSize(
-                        min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
-                    ),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.RandomShortestSize(
+                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
+                ),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "ssd":
-            self.transforms = T.Compose(
-                [
-                    T.RandomPhotometricDistort(),
-                    T.RandomZoomOut(fill=list(mean)),
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=defaultdict(lambda: mean, {datapoints.Mask: 0})),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "ssdlite":
-            self.transforms = T.Compose(
-                [
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         else:
             raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
 
-    def __call__(self, img, target):
-        return self.transforms(img, target)
+        if backend == "pil":
+            # Note: we could also just use pure tensors?
+            transforms.append(T.ToImageTensor())
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
+            T.SanitizeBoundingBox(),
+        ]
 
-class DetectionPresetEval:
-    def __init__(self):
-        self.transforms = T.Compose(
-            [
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-            ]
-        )
+        super().__init__(transforms)
+
+
+class DetectionPresetEval(T.Compose):
+    def __init__(self, backend="pil"):
+
+        transforms = []
+
+        backend = backend.lower()
+        if backend == "tensor":
+            transforms.append(T.PILToTensor())
+        else:  # for datapoint **and** PIL
+            transforms.append(T.ToImageTensor())
 
-    def __call__(self, img, target):
-        return self.transforms(img, target)
+        transforms.append(T.ConvertImageDtype(torch.float))
+        super().__init__(transforms)
diff --git a/references/detection/train.py b/references/detection/train.py
index dea483c5f75..a991603610d 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -31,12 +31,12 @@
 from coco_utils import get_coco, get_coco_kp
 from engine import evaluate, train_one_epoch
 from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
+from torchvision.prototype import transforms as T
 from torchvision.transforms import InterpolationMode
-from transforms import SimpleCopyPaste
 
 
 def copypaste_collate_fn(batch):
-    copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
+    copypaste = T.SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR, antialias=True)
     return copypaste(*utils.collate_fn(batch))
 
 
@@ -50,13 +50,13 @@ def get_dataset(name, image_set, transform, data_path):
 
 def get_transform(train, args):
     if train:
-        return presets.DetectionPresetTrain(data_augmentation=args.data_augmentation)
+        return presets.DetectionPresetTrain(data_augmentation=args.data_augmentation, backend=args.backend)
     elif args.weights and args.test_only:
         weights = torchvision.models.get_weight(args.weights)
         trans = weights.transforms()
         return lambda img, target: (trans(img), target)
     else:
-        return presets.DetectionPresetEval()
+        return presets.DetectionPresetEval(backend=args.backend)
 
 
 def get_args_parser(add_help=True):
@@ -158,11 +158,13 @@ def get_args_parser(add_help=True):
         action="store_true",
         help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
     )
+    parser.add_argument("--backend", default="PIL", type=str, help="PIL, tensor or datapoint - case insensitive")
 
     return parser
 
 
 def main(args):
+
     if args.output_dir:
         utils.mkdir(args.output_dir)
diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index d26bf6eac85..080dc5a5b1c 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -1,3 +1,4 @@
+# Original Transforms can be removed:
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py
index ed02ae660e4..92a0908ac8a 100644
--- a/references/segmentation/presets.py
+++ b/references/segmentation/presets.py
@@ -1,39 +1,82 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+import torchvision
+
+torchvision.disable_beta_transforms_warning()
+import torchvision.transforms.v2 as T
+from torchvision import datapoints
+from transforms import PadIfSmaller, WrapIntoFeatures
+
+
+class SegmentationPresetTrain(T.Compose):
+    def __init__(
+        self,
+        *,
+        base_size,
+        crop_size,
+        hflip_prob=0.5,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        backend="pil",
+    ):
+
+        transforms = []
+        transforms.append(WrapIntoFeatures())
 
-class SegmentationPresetTrain:
-    def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        min_size = int(0.5 * base_size)
-        max_size = int(2.0 * base_size)
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        else:
+            assert backend == "pil"
+
+        transforms.append(T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size), antialias=True))
 
-        trans = [T.RandomResize(min_size, max_size)]
         if hflip_prob > 0:
-            trans.append(T.RandomHorizontalFlip(hflip_prob))
-        trans.extend(
-            [
-                T.RandomCrop(crop_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
-            ]
-        )
-        self.transforms = T.Compose(trans)
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
-
-
-class SegmentationPresetEval:
-    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        self.transforms = T.Compose(
-            [
-                T.RandomResize(base_size, base_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
-            ]
-        )
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
+            transforms.append(T.RandomHorizontalFlip(hflip_prob))
+
+        transforms += [
+            # We need a custom pad transform here, since the padding we want to perform here is fundamentally
+            # different from the padding in `RandomCrop` if `pad_if_needed=True`.
+            PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
+            T.RandomCrop(crop_size),
+        ]
+
+        if backend == "pil":
+            transforms.append(T.ToImageTensor())
+
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+
+        super().__init__(transforms)
+
+
+class SegmentationPresetEval(T.Compose):
+    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil"):
+        transforms = []
+
+        transforms.append(WrapIntoFeatures())
+
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        else:
+            assert backend == "pil"
+
+        transforms.append(T.Resize(base_size, antialias=True))
+
+        if backend == "pil":
+            transforms.append(T.ToImageTensor())
+
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+        super().__init__(transforms)
diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 1aa72a9fe38..b06bd9ae985 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -31,7 +31,7 @@ def sbd(*args, **kwargs):
 
 def get_transform(train, args):
     if train:
-        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
+        return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend)
     elif args.weights and args.test_only:
         weights = torchvision.models.get_weight(args.weights)
         trans = weights.transforms()
@@ -44,7 +44,7 @@ def preprocessing(img, target):
 
         return preprocessing
     else:
-        return presets.SegmentationPresetEval(base_size=520)
+        return presets.SegmentationPresetEval(base_size=520, backend=args.backend)
 
 
 def criterion(inputs, target):
@@ -306,6 +306,7 @@ def get_args_parser(add_help=True):
 
     # Mixed precision training parameters
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+    parser.add_argument("--backend", default="PIL", type=str, help="PIL, tensor or datapoint - case insensitive")
 
     return parser
diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py
index 518048db2fa..cea3668247e 100644
--- a/references/segmentation/transforms.py
+++ b/references/segmentation/transforms.py
@@ -2,10 +2,41 @@
 
 import numpy as np
 import torch
-from torchvision import transforms as T
+import torchvision.transforms.v2 as PT
+import torchvision.transforms.v2.functional as PF
+from torchvision import datapoints, transforms as T
 from torchvision.transforms import functional as F
 
 
+class WrapIntoFeatures(PT.Transform):
+    def forward(self, sample):
+        image, mask = sample
+        # return PF.to_image_tensor(image), datapoints.Mask(PF.pil_to_tensor(mask).squeeze(0), dtype=torch.int64)
+        return image, datapoints.Mask(PF.pil_to_tensor(mask).squeeze(0), dtype=torch.int64)
+
+
+class PadIfSmaller(PT.Transform):
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = size
+        self.fill = PT._geometry._setup_fill_arg(fill)
+
+    def _get_params(self, sample):
+        _, height, width = PT.utils.query_chw(sample)
+        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
+        needs_padding = any(padding)
+        return dict(padding=padding, needs_padding=needs_padding)
+
+    def _transform(self, inpt, params):
+        if not params["needs_padding"]:
+            return inpt
+
+        fill = self.fill[type(inpt)]
+        fill = PT._utils._convert_fill_arg(fill)
+
+        return PF.pad(inpt, padding=params["padding"], fill=fill)
+
+
 def pad_if_smaller(img, size, fill=0):
     min_size = min(img.size)
     if min_size < size:
diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py
index b7e2a42259f..83d94da97cd 100644
--- a/torchvision/transforms/v2/_meta.py
+++ b/torchvision/transforms/v2/_meta.py
@@ -31,6 +31,10 @@ def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> da
         return F.convert_format_bounding_box(inpt, new_format=self.format)  # type: ignore[return-value]
 
 
+# TODO: This doesn't convert PIL?
+# This means that the detection preset fails right now if we convert PIL ->
+# Tensor at the very end: PIL images are passed-through so instead of getting a
+# float tensor we actually get a uint8 which raises an error
 class ConvertDtype(Transform):
     """[BETA] Convert input image or video to the given ``dtype`` and scale the values accordingly.
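
Usage note (editorial, not part of the patch): the new `backend` switch only changes what the preset pipelines start from; everything downstream is the same v2 transforms. A minimal sketch, assuming it is run from `references/classification/` after this patch is applied; the crop/resize sizes below are illustrative, not values taken from the patch:

    # references/classification/ after this patch
    import presets

    # start from a datapoints.Image (ToImageTensor is inserted first)
    train_tf = presets.ClassificationPresetTrain(crop_size=176, backend="datapoint")
    # keep PIL until just before normalization; the value is lower-cased internally
    val_tf = presets.ClassificationPresetEval(crop_size=224, resize_size=232, backend="PIL")

The same switch is exposed on the classification, detection and segmentation training scripts as `--backend` with the values `PIL`, `tensor` or `datapoint` (parsed case-insensitively).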