diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 29c2bc1358a..046550209b0 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1,5 +1,6 @@
 import itertools
 import re
+from collections import defaultdict
 
 import numpy as np
 
@@ -1988,3 +1989,154 @@ def test__transform(self, inpt):
         assert type(output) is type(inpt)
         assert output.shape[-4] == num_samples
         assert output.dtype == inpt.dtype
+
+
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("label_type", (torch.Tensor, int))
+@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
+
+    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
+    if image_type is PIL.Image:
+        image = to_pil_image(image[0])
+    elif image_type is torch.Tensor:
+        image = image.as_subclass(torch.Tensor)
+        assert is_simple_tensor(image)
+
+    label = 1 if label_type is int else torch.tensor([1])
+
+    if dataset_return_type is dict:
+        sample = {
+            "image": image,
+            "label": label,
+        }
+    else:
+        sample = image, label
+
+    t = transforms.Compose(
+        [
+            transforms.RandomResizedCrop((224, 224)),
+            transforms.RandomHorizontalFlip(p=1),
+            transforms.RandAugment(),
+            transforms.TrivialAugmentWide(),
+            transforms.AugMix(),
+            transforms.AutoAugment(),
+            to_tensor(),
+            # TODO: ConvertImageDtype is a pass-through on PIL images, is that
+            # intended? This results in a failure if we convert to tensor after
+            # it, because the image would still be uint8, which makes Normalize
+            # fail.
+            transforms.ConvertImageDtype(torch.float),
+            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
+            transforms.RandomErasing(p=1),
+        ]
+    )
+
+    out = t(sample)
+
+    assert type(out) == type(sample)
+
+    if dataset_return_type is tuple:
+        out_image, out_label = out
+    else:
+        assert out.keys() == sample.keys()
+        out_image, out_label = out.values()
+
+    assert out_image.shape[-2:] == (224, 224)
+    assert out_label == label
+
+
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("label_type", (torch.Tensor, list))
+@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
+    if data_augmentation == "hflip":
+        t = [
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "lsj":
+        t = [
+            transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
+            # Note: replaced FixedSizeCrop with RandomCrop, because we're
+            # leaving FixedSizeCrop in prototype for now, and it expects Label
+            # classes which we won't release yet.
+            # transforms.FixedSizeCrop(
+            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
+            # ),
+            transforms.RandomCrop((1024, 1024), pad_if_needed=True),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "multiscale":
+        t = [
+            transforms.RandomShortestSize(
+                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
+            ),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "ssd":
+        t = [
+            transforms.RandomPhotometricDistort(p=1),
+            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
+            # TODO: put back IoUCrop once we remove its hard requirement for Labels
+            # transforms.RandomIoUCrop(),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "ssdlite":
+        t = [
+            # TODO: put back IoUCrop once we remove its hard requirement for Labels
+            # transforms.RandomIoUCrop(),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    t = transforms.Compose(t)
+
+    num_boxes = 5
+    H = W = 250
+
+    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
+    if image_type is PIL.Image:
+        image = to_pil_image(image[0])
+    elif image_type is torch.Tensor:
+        image = image.as_subclass(torch.Tensor)
+        assert is_simple_tensor(image)
+
+    label = torch.randint(0, 10, size=(num_boxes,))
+    if label_type is list:
+        label = label.tolist()
+
+    # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
+    boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
+    boxes[:, 2:] += boxes[:, :2]
+    boxes = boxes.clamp(min=0, max=min(H, W))
+    boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W))
+
+    masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
+
+    sample = {
+        "image": image,
+        "label": label,
+        "boxes": boxes,
+        "masks": masks,
+    }
+
+    out = t(sample)
+
+    if to_tensor is transforms.ToTensor and image_type is not datapoints.Image:
+        assert is_simple_tensor(out["image"])
+    else:
+        assert isinstance(out["image"], datapoints.Image)
+    assert isinstance(out["label"], type(sample["label"]))
+
+    out["label"] = torch.tensor(out["label"])
+    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
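Aside for reviewers (not part of the patch): the `defaultdict` fill exercised above maps datapoint types to fill values, with the lambda supplying the default for any type not listed, which is why `collections` is now imported. A minimal sketch of the pattern, assuming the prototype import paths:

import torch
from collections import defaultdict
from torchvision.prototype import datapoints, transforms  # assumed prototype paths

# Masks get padded with 0; any other queried type falls back to the
# ImageNet mean pixel produced by the default factory.
fill = defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
assert fill[datapoints.Mask] == 0
assert fill[datapoints.Image] == (123.0, 117.0, 104.0)

zoom_out = transforms.RandomZoomOut(fill=fill, p=1.0)
image = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
padded = zoom_out(image)
assert padded.shape[-1] >= image.shape[-1]  # zoom-out only ever pads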
diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 50b17068aaf..89bead236b2 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -37,10 +37,11 @@ def _flatten_and_extract_image_or_video(
         unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
     ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
+        needs_transform_list = self._needs_transform_list(flat_inputs)
 
         image_or_videos = []
-        for idx, inpt in enumerate(flat_inputs):
-            if check_type(
+        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
+            if needs_transform and check_type(
                 inpt,
                 (
                     datapoints.Image,
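Aside (not part of the patch): the gating above means the auto-augment machinery only considers inputs that the pass-through heuristic marks as transformable, so a simple tensor riding along with an explicit image is skipped rather than being picked up as a second image candidate. A hedged sketch of the caller-side effect, assuming the prototype import paths and that the candidate check also matches simple tensors:

import torch
from torchvision.prototype import datapoints, transforms  # assumed prototype paths

aa = transforms.AutoAugment()
image = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
extra = torch.rand(4)  # simple tensor riding along with the sample

# Only the explicit image is extracted and augmented; the simple tensor is
# passed through untouched instead of being treated as another image.
out_image, out_extra = aa(image, extra)
assert isinstance(out_image, datapoints.Image)
assert out_extra is extra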
diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py
index a360e076b1d..09e313e5bed 100644
--- a/torchvision/prototype/transforms/_color.py
+++ b/torchvision/prototype/transforms/_color.py
@@ -169,7 +169,8 @@ def _permute_channels(
         if isinstance(orig_inpt, PIL.Image.Image):
             inpt = F.pil_to_tensor(inpt)
 
-        output = inpt[..., permutation, :, :]
+        # TODO: Find a better fix than as_subclass???
+        output = inpt[..., permutation, :, :].as_subclass(type(inpt))
 
         if isinstance(orig_inpt, PIL.Image.Image):
             output = F.to_image_pil(output)
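Aside (not part of the patch): the `as_subclass` call is needed because indexing a datapoint can hand back a plain `torch.Tensor` (the datapoint dispatch unwraps generic ops), dropping the `datapoints.Image`/`Video` type; `Tensor.as_subclass` re-wraps the result without copying. A minimal sketch of that mechanic, using a hypothetical stand-in class:

import torch

class MyImage(torch.Tensor):  # hypothetical stand-in for datapoints.Image
    pass

plain = torch.rand(3, 4, 4)
image = plain.as_subclass(MyImage)  # re-types the tensor without copying data
assert isinstance(image, MyImage)
assert image.data_ptr() == plain.data_ptr()  # same underlying storage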
diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py
index c49306cc523..16c30565d36 100644
--- a/torchvision/prototype/transforms/_transform.py
+++ b/torchvision/prototype/transforms/_transform.py
@@ -36,8 +36,19 @@ def forward(self, *inputs: Any) -> Any:
 
         self._check_inputs(flat_inputs)
 
-        params = self._get_params(flat_inputs)
+        needs_transform_list = self._needs_transform_list(flat_inputs)
+        params = self._get_params(
+            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
+        )
 
+        flat_outputs = [
+            self._transform(inpt, params) if needs_transform else inpt
+            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
+        ]
+
+        return tree_unflatten(flat_outputs, spec)
+
+    def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
         # Below is a heuristic on how to deal with simple tensor inputs:
         # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
         #    (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
@@ -53,7 +64,8 @@ def forward(self, *inputs: Any) -> Any:
         # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
         # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
         # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
-        flat_outputs = []
+
+        needs_transform_list = []
         transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
         for inpt in flat_inputs:
             needs_transform = True
@@ -65,10 +77,8 @@ def forward(self, *inputs: Any) -> Any:
                     transform_simple_tensor = False
                 else:
                     needs_transform = False
-
-            flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt)
-
-        return tree_unflatten(flat_outputs, spec)
+            needs_transform_list.append(needs_transform)
+        return needs_transform_list
 
     def extra_repr(self) -> str:
         extra = []
@@ -159,10 +169,14 @@ def forward(self, *inputs: Any) -> Any:
         if torch.rand(1) >= self.p:
             return inputs
 
-        params = self._get_params(flat_inputs)
+        needs_transform_list = self._needs_transform_list(flat_inputs)
+        params = self._get_params(
+            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
+        )
 
         flat_outputs = [
-            self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
+            self._transform(inpt, params) if needs_transform else inpt
+            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
         ]
 
         return tree_unflatten(flat_outputs, spec)
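Aside (not part of the patch): after this refactor, `_get_params` only ever receives the inputs that will actually be transformed, and the pass-through decision lives in one place, `_needs_transform_list`. A hedged sketch of the heuristic's two user-visible branches, assuming the prototype import paths:

import torch
from torchvision.prototype import datapoints, transforms  # assumed prototype paths

flip = transforms.RandomHorizontalFlip(p=1)

# 1. With an explicit image in the sample, a simple tensor is passed through
#    untouched instead of being flipped as if it were an image:
image = datapoints.Image(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))
extra = torch.zeros(4)  # simple tensor, not a datapoint
out_image, out_extra = flip(image, extra)
assert isinstance(out_image, datapoints.Image)
assert out_extra is extra

# 2. Without one, a lone simple tensor is treated as the image and flipped:
tensor_image = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
out = flip(tensor_image)
assert torch.equal(out, tensor_image.flip(-1))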