diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index dc3de480d1f..fb7c7341992 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -3,7 +3,12 @@ import pytest import torch from common_utils import assert_equal -from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels +from test_prototype_transforms_functional import ( + make_images, + make_bounding_boxes, + make_one_hot_labels, + make_segmentation_masks, +) from torchvision.prototype import transforms, features from torchvision.transforms.functional import to_pil_image, pil_to_tensor @@ -25,23 +30,23 @@ def make_vanilla_tensor_bounding_boxes(*args, **kwargs): yield bounding_box.data -def parametrize(transforms_with_inputs): +def parametrize(transforms_with_inpts): return pytest.mark.parametrize( - ("transform", "input"), + ("transform", "inpt"), [ pytest.param( transform, - input, - id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}", + inpt, + id=f"{type(transform).__name__}-{type(inpt).__module__}.{type(inpt).__name__}-{idx}", ) - for transform, inputs in transforms_with_inputs - for idx, input in enumerate(inputs) + for transform, inpts in transforms_with_inpts + for idx, inpt in enumerate(inpts) ], ) def parametrize_from_transforms(*transforms): - transforms_with_inputs = [] + transforms_with_inpts = [] for transform in transforms: for creation_fn in [ make_images, @@ -49,19 +54,20 @@ def parametrize_from_transforms(*transforms): make_one_hot_labels, make_vanilla_tensor_images, make_pil_images, + make_segmentation_masks, ]: - inputs = list(creation_fn()) + inpts = list(creation_fn()) try: - output = transform(inputs[0]) - except Exception: + output = transform(inpts[0]) + except (TypeError, RuntimeError): continue else: - if output is inputs[0]: + if output is inpts[0]: continue - transforms_with_inputs.append((transform, inputs)) + transforms_with_inpts.append((transform, inpts)) - return parametrize(transforms_with_inputs) + return parametrize(transforms_with_inpts) class TestSmoke: @@ -69,12 +75,14 @@ class TestSmoke: transforms.RandomErasing(p=1.0), transforms.Resize([16, 16]), transforms.CenterCrop([16, 16]), + transforms.RandomResizedCrop([16, 16]), transforms.ConvertImageDtype(), transforms.RandomHorizontalFlip(), transforms.Pad(5), ) - def test_common(self, transform, input): - transform(input) + def test_common(self, transform, inpt): + output = transform(inpt) + assert type(output) == type(inpt) @parametrize( [ @@ -96,8 +104,8 @@ def test_common(self, transform, input): ] ] ) - def test_mixup_cutmix(self, transform, input): - transform(input) + def test_mixup_cutmix(self, transform, inpt): + transform(inpt) @parametrize( [ @@ -127,8 +135,8 @@ def test_mixup_cutmix(self, transform, input): ) ] ) - def test_auto_augment(self, transform, input): - transform(input) + def test_auto_augment(self, transform, inpt): + transform(inpt) @parametrize( [ @@ -144,8 +152,8 @@ def test_auto_augment(self, transform, input): ), ] ) - def test_normalize(self, transform, input): - transform(input) + def test_normalize(self, transform, inpt): + transform(inpt) @parametrize( [ @@ -159,8 +167,8 @@ def test_normalize(self, transform, input): ) ] ) - def test_random_resized_crop(self, transform, input): - transform(input) + def test_random_resized_crop(self, transform, inpt): + transform(inpt) @parametrize( [ @@ -188,58 +196,58 @@ def test_random_resized_crop(self, transform, input): ) ] ) - def 
test_convert_image_color_space(self, transform, input): - transform(input) + def test_convert_image_color_space(self, transform, inpt): + transform(inpt) @pytest.mark.parametrize("p", [0.0, 1.0]) class TestRandomHorizontalFlip: - def input_expected_image_tensor(self, p, dtype=torch.float32): - input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype) + def inpt_expected_image_tensor(self, p, dtype=torch.float32): + inpt = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype) expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype) - return input, expected if p == 1 else input + return inpt, expected if p == 1 else inpt def test_simple_tensor(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(input) + actual = transform(inpt) assert_equal(expected, actual) def test_pil_image(self, p): - input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8) + inpt, expected = self.inpt_expected_image_tensor(p, dtype=torch.uint8) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(to_pil_image(input)) + actual = transform(to_pil_image(inpt)) assert_equal(expected, pil_to_tensor(actual)) def test_features_image(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(features.Image(input)) + actual = transform(features.Image(inpt)) assert_equal(features.Image(expected), actual) def test_features_segmentation_mask(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(features.SegmentationMask(input)) + actual = transform(features.SegmentationMask(inpt)) assert_equal(features.SegmentationMask(expected), actual) def test_features_bounding_box(self, p): - input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) + inpt = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(input) + actual = transform(inpt) - expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input - expected = features.BoundingBox.new_like(input, data=expected_image_tensor) + expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else inpt + expected = features.BoundingBox.new_like(inpt, data=expected_image_tensor) assert_equal(expected, actual) assert actual.format == expected.format assert actual.image_size == expected.image_size @@ -247,52 +255,52 @@ def test_features_bounding_box(self, p): @pytest.mark.parametrize("p", [0.0, 1.0]) class TestRandomVerticalFlip: - def input_expected_image_tensor(self, p, dtype=torch.float32): - input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype) + def inpt_expected_image_tensor(self, p, dtype=torch.float32): + inpt = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype) expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype) - return input, expected if p == 1 else input + return inpt, expected if p == 1 else inpt def test_simple_tensor(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomVerticalFlip(p=p) - 
actual = transform(input) + actual = transform(inpt) assert_equal(expected, actual) def test_pil_image(self, p): - input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8) + inpt, expected = self.inpt_expected_image_tensor(p, dtype=torch.uint8) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(to_pil_image(input)) + actual = transform(to_pil_image(inpt)) assert_equal(expected, pil_to_tensor(actual)) def test_features_image(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(features.Image(input)) + actual = transform(features.Image(inpt)) assert_equal(features.Image(expected), actual) def test_features_segmentation_mask(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(features.SegmentationMask(input)) + actual = transform(features.SegmentationMask(inpt)) assert_equal(features.SegmentationMask(expected), actual) def test_features_bounding_box(self, p): - input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) + inpt = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(input) + actual = transform(inpt) - expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input - expected = features.BoundingBox.new_like(input, data=expected_image_tensor) + expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else inpt + expected = features.BoundingBox.new_like(inpt, data=expected_image_tensor) assert_equal(expected, actual) assert actual.format == expected.format assert actual.image_size == expected.image_size diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 30d9b833ec8..9a26f9a225b 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -489,16 +489,40 @@ def center_crop_segmentation_mask(): and callable(kernel) and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"}) and "pil" not in name - and name - not in { - "to_image_tensor", - } + and name not in {"to_image_tensor"} ], ) def test_scriptable(kernel): jit.script(kernel) +@pytest.mark.parametrize( + "func", + [ + pytest.param(func, id=name) + for name, func in F.__dict__.items() + if not name.startswith("_") + and callable(func) + and all( + feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"} + ) + and name not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av"} + ], +) +def test_functional_mid_level(func): + finfos = [finfo for finfo in FUNCTIONAL_INFOS if f"{func.__name__}_" in finfo.name] + for finfo in finfos: + for sample_input in finfo.sample_inputs(): + expected = finfo(sample_input) + kwargs = dict(sample_input.kwargs) + for key in ["format", "image_size"]: + if key in kwargs: + del kwargs[key] + output = func(*sample_input.args, **kwargs) + torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}") + break + + @pytest.mark.parametrize( ("functional_info", "sample_input"), [ diff --git a/torchvision/prototype/features/_bounding_box.py 
b/torchvision/prototype/features/_bounding_box.py index cd5cdc69836..4e119e7bb25 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -69,3 +69,87 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: return BoundingBox.new_like( self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format ) + + def horizontal_flip(self) -> BoundingBox: + output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size) + return BoundingBox.new_like(self, output) + + def vertical_flip(self) -> BoundingBox: + output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size) + return BoundingBox.new_like(self, output) + + def resize(self, size, *, interpolation, max_size, antialias) -> BoundingBox: + interpolation, antialias # unused + output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size) + return BoundingBox.new_like(self, output, image_size=size, dtype=output.dtype) + + def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: + output = self._F.crop_bounding_box(self, self.format, top, left) + return BoundingBox.new_like(self, output, image_size=(height, width)) + + def center_crop(self, output_size) -> BoundingBox: + output = self._F.center_crop_bounding_box( + self, format=self.format, output_size=output_size, image_size=self.image_size + ) + return BoundingBox.new_like(self, output, image_size=output_size) + + def resized_crop(self, top, left, height, width, *, size, interpolation, antialias) -> BoundingBox: + interpolation, antialias # unused + output = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size) + return BoundingBox.new_like(self, output, image_size=size, dtype=output.dtype) + + def pad(self, padding, *, fill, padding_mode) -> BoundingBox: + fill # unused + if padding_mode not in ["constant"]: + raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes") + + output = self._F.pad_bounding_box(self, padding, format=self.format) + + # Update output image size: + # TODO: remove the import below and make _parse_pad_padding available + from torchvision.transforms.functional_tensor import _parse_pad_padding + + left, top, right, bottom = _parse_pad_padding(padding) + height, width = self.image_size + height += top + bottom + width += left + right + + return BoundingBox.new_like(self, output, image_size=(height, width)) + + def rotate(self, angle, *, interpolation, expand, fill, center) -> BoundingBox: + interpolation, fill # unused + output = self._F.rotate_bounding_box( + self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center + ) + # TODO: update output image size if expand is True + if expand: + raise RuntimeError("Not yet implemented") + return BoundingBox.new_like(self, output, dtype=output.dtype) + + def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) -> BoundingBox: + interpolation, fill # unused + output = self._F.affine_bounding_box( + self, + self.format, + self.image_size, + angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return BoundingBox.new_like(self, output, dtype=output.dtype) + + def perspective(self, perspective_coeffs, *, interpolation, fill) -> BoundingBox: + interpolation, fill # unused + output = self._F.perspective_bounding_box(self, self.format, 
perspective_coeffs) + return BoundingBox.new_like(self, output, dtype=output.dtype) + + def erase(self, *args) -> BoundingBox: + raise TypeError("Erase transformation does not support bounding boxes") + + def mixup(self, *args) -> BoundingBox: + raise TypeError("Mixup transformation does not support bounding boxes") + + def cutmix(self, *args) -> BoundingBox: + raise TypeError("Cutmix transformation does not support bounding boxes") diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index f8026b4d34d..0d4e9321977 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -16,7 +16,7 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, ) -> F: - return cast( + feature = cast( F, torch.Tensor._make_subclass( cast(_TensorBase, cls), @@ -25,6 +25,13 @@ def __new__( ), ) + # To avoid circular dependency between features and transforms + from ..transforms import functional + + feature._F = functional + + return feature + @classmethod def new_like( cls: Type[F], @@ -83,3 +90,118 @@ def __torch_function__( return cls.new_like(args[0], output, dtype=output.dtype, device=output.device) else: return output + + def horizontal_flip(self): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def vertical_flip(self): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def resize(self, size, *, interpolation, max_size, antialias): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def crop(self, top: int, left: int, height: int, width: int): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def center_crop(self, output_size): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def resized_crop(self, top, left, height, width, *, size, interpolation, antialias): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def pad(self, padding, *, fill, padding_mode): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def rotate(self, angle, *, interpolation, expand, fill, center): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def affine(self, angle, *, translate, scale, shear, interpolation, fill, center): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def perspective(self, perspective_coeffs, *, interpolation, fill): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_brightness(self, brightness_factor: float): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_saturation(self, saturation_factor: float): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_contrast(self, contrast_factor: float): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_sharpness(self, sharpness_factor: float): + # Just output itself + # How dangerous to do this instead of raising an error ? 
+ return self + + def adjust_hue(self, hue_factor: float): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_gamma(self, gamma: float, gain: float = 1): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def posterize(self, bits: int): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def solarize(self, threshold: float): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def autocontrast(self): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def equalize(self): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def erase(self, i, j, h, w, v): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def mixup(self, lam): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def cutmix(self, *, box, lam_adjusted): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 9206a844b6d..f7bb24fb427 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -109,3 +109,123 @@ def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image: # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we # promote this out of the prototype state return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs)) + + def horizontal_flip(self) -> Image: + output = self._F.horizontal_flip_image_tensor(self) + return Image.new_like(self, output) + + def vertical_flip(self) -> Image: + output = self._F.vertical_flip_image_tensor(self) + return Image.new_like(self, output) + + def resize(self, size, *, interpolation, max_size, antialias) -> Image: + output = self._F.resize_image_tensor( + self, size, interpolation=interpolation, max_size=max_size, antialias=antialias + ) + return Image.new_like(self, output) + + def crop(self, top: int, left: int, height: int, width: int) -> Image: + output = self._F.crop_image_tensor(self, top, left, height, width) + return Image.new_like(self, output) + + def center_crop(self, output_size) -> Image: + output = self._F.center_crop_image_tensor(self, output_size=output_size) + return Image.new_like(self, output) + + def resized_crop(self, top, left, height, width, *, size, interpolation, antialias) -> Image: + output = self._F.resized_crop_image_tensor( + self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias + ) + return Image.new_like(self, output) + + def pad(self, padding, *, fill, padding_mode) -> Image: + output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) + return Image.new_like(self, output) + + def rotate(self, angle, *, interpolation, expand, fill, center) -> Image: + output = self._F.rotate_image_tensor( + self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center + ) + return Image.new_like(self, output) + + def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) -> Image: + output = self._F.affine_image_tensor( + self, + angle, + translate=translate, + 
scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + return Image.new_like(self, output) + + def perspective(self, perspective_coeffs, *, interpolation, fill) -> Image: + output = self._F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill) + return Image.new_like(self, output) + + def adjust_brightness(self, brightness_factor: float) -> Image: + output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor) + return Image.new_like(self, output) + + def adjust_saturation(self, saturation_factor: float) -> Image: + output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor) + return Image.new_like(self, output) + + def adjust_contrast(self, contrast_factor: float) -> Image: + output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor) + return Image.new_like(self, output) + + def adjust_sharpness(self, sharpness_factor: float) -> Image: + output = self._F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor) + return Image.new_like(self, output) + + def adjust_hue(self, hue_factor: float) -> Image: + output = self._F.adjust_hue_image_tensor(self, hue_factor=hue_factor) + return Image.new_like(self, output) + + def adjust_gamma(self, gamma: float, gain: float = 1) -> Image: + output = self._F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain) + return Image.new_like(self, output) + + def posterize(self, bits: int) -> Image: + output = self._F.posterize_image_tensor(self, bits=bits) + return Image.new_like(self, output) + + def solarize(self, threshold: float) -> Image: + output = self._F.solarize_image_tensor(self, threshold=threshold) + return Image.new_like(self, output) + + def autocontrast(self) -> Image: + output = self._F.autocontrast_image_tensor(self) + return Image.new_like(self, output) + + def equalize(self) -> Image: + output = self._F.equalize_image_tensor(self) + return Image.new_like(self, output) + + def invert(self) -> Image: + output = self._F.invert_image_tensor(self) + return Image.new_like(self, output) + + def erase(self, i, j, h, w, v) -> Image: + output = self._F.erase_image_tensor(self, i, j, h, w, v) + return Image.new_like(self, output) + + def mixup(self, lam: float) -> Image: + if self.ndim < 4: + raise ValueError("Need a batch of images") + output = self.clone() + output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam)) + return Image.new_like(self, output) + + def cutmix(self, *, box: Tuple[int, int, int, int], lam_adjusted: float) -> Image: + lam_adjusted # unused + if self.ndim < 4: + raise ValueError("Need a batch of images") + x1, y1, x2, y2 = box + image_rolled = self.roll(1, -4) + output = self.clone() + output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] + return Image.new_like(self, output) diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py index e3433b7bb08..1b1ded61c1c 100644 --- a/torchvision/prototype/features/_label.py +++ b/torchvision/prototype/features/_label.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional, Sequence, cast, Union +from typing import Any, Optional, Sequence, cast, Union, Tuple import torch from torchvision.prototype.utils._internal import apply_recursively @@ -77,3 +77,14 @@ def new_like( return super().new_like( other, data, categories=categories if categories is not None else other.categories, **kwargs ) + + def mixup(self, lam) -> OneHotLabel: + 
if self.ndim < 2: + raise ValueError("Need a batch of one hot labels") + output = self.clone() + output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam)) + return OneHotLabel.new_like(self, output) + + def cutmix(self, *, box: Tuple[int, int, int, int], lam_adjusted: float) -> OneHotLabel: + box # unused + return self.mixup(lam_adjusted) diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index dc41697ae9b..9a90b6ccc53 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -1,5 +1,68 @@ +from __future__ import annotations + from ._feature import _Feature class SegmentationMask(_Feature): - pass + def horizontal_flip(self) -> SegmentationMask: + output = self._F.horizontal_flip_segmentation_mask(self) + return SegmentationMask.new_like(self, output) + + def vertical_flip(self) -> SegmentationMask: + output = self._F.vertical_flip_segmentation_mask(self) + return SegmentationMask.new_like(self, output) + + def resize(self, size, *, interpolation, max_size, antialias) -> SegmentationMask: + interpolation, antialias # unused + output = self._F.resize_segmentation_mask(self, size, max_size=max_size) + return SegmentationMask.new_like(self, output) + + def crop(self, top: int, left: int, height: int, width: int) -> SegmentationMask: + output = self._F.center_crop_segmentation_mask(self, top, left, height, width) + return SegmentationMask.new_like(self, output) + + def center_crop(self, output_size) -> SegmentationMask: + output = self._F.center_crop_segmentation_mask(self, output_size=output_size) + return SegmentationMask.new_like(self, output) + + def resized_crop(self, top, left, height, width, *, size, interpolation, antialias) -> SegmentationMask: + # TODO: untested right now + interpolation, antialias # unused + output = self._F.resized_crop_segmentation_mask(self, top, left, height, width, size=list(size)) + return SegmentationMask.new_like(self, output) + + def pad(self, padding, *, fill, padding_mode) -> SegmentationMask: + fill # unused + output = self._F.pad_segmentation_mask(self, padding, padding_mode=padding_mode) + return SegmentationMask.new_like(self, output) + + def rotate(self, angle, *, interpolation, expand, fill, center) -> SegmentationMask: + interpolation, fill # unused + output = self._F.rotate_segmentation_mask(self, angle, expand=expand, center=center) + return SegmentationMask.new_like(self, output) + + def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) -> SegmentationMask: + interpolation, fill # unused + output = self._F.affine_segmentation_mask( + self, + angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return SegmentationMask.new_like(self, output) + + def perspective(self, perspective_coeffs, *, interpolation, fill) -> SegmentationMask: + interpolation, fill # unused + output = self._F.perspective_segmentation_mask(self, perspective_coeffs) + return SegmentationMask.new_like(self, output) + + def erase(self, *args) -> SegmentationMask: + raise TypeError("Erase transformation does not support segmentation masks") + + def mixup(self, *args) -> SegmentationMask: + raise TypeError("Mixup transformation does not support segmentation masks") + + def cutmix(self, *args) -> SegmentationMask: + raise TypeError("Cutmix transformation does not support segmentation masks") diff --git a/torchvision/prototype/transforms/__init__.py 
b/torchvision/prototype/transforms/__init__.py index 5edd18890a8..d5777560089 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -8,6 +8,7 @@ from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import ( Resize, + RandomCrop, CenterCrop, RandomResizedCrop, FiveCrop, diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 82c5f52f1dc..4ad9c7302b7 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -3,12 +3,13 @@ import warnings from typing import Any, Dict, Tuple +import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F from ._transform import _RandomApplyTransform -from ._utils import query_image, get_image_dimensions, has_all, has_any, is_simple_tensor +from ._utils import query_image, get_image_dimensions, has_all class RandomErasing(_RandomApplyTransform): @@ -51,7 +52,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: if value is not None and not (len(value) in (1, img_c)): raise ValueError( - f"If value is a sequence, it should have either a single value or {img_c} (number of input channels)" + f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" ) area = img_h * img_w @@ -82,59 +83,45 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: else: i, j, h, w, v = 0, 0, img_h, img_w, image - return dict(zip("ijhwv", (i, j, h, w, v))) + return dict(i=i, j=j, h=h, w=w, v=v) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.erase_image_tensor(input, **params) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return F.erase_image_tensor(input, **params) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.erase(**params) + elif isinstance(inpt, PIL.Image.Image): + # Shouldn't we implement a fallback to tensor ? 
+ raise RuntimeError("Not implemented") + elif isinstance(inpt, torch.Tensor): + return F.erase_image_tensor(inpt, **params) else: - return input + return inpt - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - return super().forward(sample) - - -class RandomMixup(Transform): +class _BaseMixupCutmix(Transform): def __init__(self, *, alpha: float) -> None: super().__init__() self.alpha = alpha self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) + def forward(self, *inpts: Any) -> Any: + sample = inpts if len(inpts) > 1 else inpts[0] + if not has_all(sample, features.Image, features.OneHotLabel): + raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") + return super().forward(sample) + + +class RandomMixup(_BaseMixupCutmix): def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.mixup_image_tensor(input, **params) - return features.Image.new_like(input, output) - elif isinstance(input, features.OneHotLabel): - output = F.mixup_one_hot_label(input, **params) - return features.OneHotLabel.new_like(input, output) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.mixup(**params) else: - return input - - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - elif not has_all(sample, features.Image, features.OneHotLabel): - raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") - return super().forward(sample) + return inpt -class RandomCutmix(Transform): - def __init__(self, *, alpha: float) -> None: - super().__init__() - self.alpha = alpha - self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) - +class RandomCutmix(_BaseMixupCutmix): def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) @@ -158,20 +145,8 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(box=box, lam_adjusted=lam_adjusted) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.cutmix_image_tensor(input, box=params["box"]) - return features.Image.new_like(input, output) - elif isinstance(input, features.OneHotLabel): - output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"]) - return features.OneHotLabel.new_like(input, output) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.cutmix(**params) else: - return input - - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - elif not has_all(sample, features.Image, features.OneHotLabel): - raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* 
OneHotLabel's.") - return super().forward(sample) + return inpt diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 7fc62423ab8..3c43166ad96 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -106,9 +106,7 @@ def _apply_image_transform( if transform_id == "Identity": return image elif transform_id == "ShearX": - return self._dispatch_image_kernels( - F.affine_image_tensor, - F.affine_image_pil, + return F.affine( image, angle=0.0, translate=[0, 0], @@ -118,9 +116,7 @@ def _apply_image_transform( fill=fill, ) elif transform_id == "ShearY": - return self._dispatch_image_kernels( - F.affine_image_tensor, - F.affine_image_pil, + return F.affine( image, angle=0.0, translate=[0, 0], @@ -130,9 +126,7 @@ def _apply_image_transform( fill=fill, ) elif transform_id == "TranslateX": - return self._dispatch_image_kernels( - F.affine_image_tensor, - F.affine_image_pil, + return F.affine( image, angle=0.0, translate=[int(magnitude), 0], @@ -142,9 +136,7 @@ def _apply_image_transform( fill=fill, ) elif transform_id == "TranslateY": - return self._dispatch_image_kernels( - F.affine_image_tensor, - F.affine_image_pil, + return F.affine( image, angle=0.0, translate=[0, int(magnitude)], @@ -154,46 +146,25 @@ def _apply_image_transform( fill=fill, ) elif transform_id == "Rotate": - return self._dispatch_image_kernels(F.rotate_image_tensor, F.rotate_image_pil, image, angle=magnitude) + return F.rotate(image, angle=magnitude) elif transform_id == "Brightness": - return self._dispatch_image_kernels( - F.adjust_brightness_image_tensor, - F.adjust_brightness_image_pil, - image, - brightness_factor=1.0 + magnitude, - ) + return F.adjust_brightness(image, brightness_factor=1.0 + magnitude) elif transform_id == "Color": - return self._dispatch_image_kernels( - F.adjust_saturation_image_tensor, - F.adjust_saturation_image_pil, - image, - saturation_factor=1.0 + magnitude, - ) + return F.adjust_saturation(image, saturation_factor=1.0 + magnitude) elif transform_id == "Contrast": - return self._dispatch_image_kernels( - F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, image, contrast_factor=1.0 + magnitude - ) + return F.adjust_contrast(image, contrast_factor=1.0 + magnitude) elif transform_id == "Sharpness": - return self._dispatch_image_kernels( - F.adjust_sharpness_image_tensor, - F.adjust_sharpness_image_pil, - image, - sharpness_factor=1.0 + magnitude, - ) + return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude) elif transform_id == "Posterize": - return self._dispatch_image_kernels( - F.posterize_image_tensor, F.posterize_image_pil, image, bits=int(magnitude) - ) + return F.posterize(image, bits=int(magnitude)) elif transform_id == "Solarize": - return self._dispatch_image_kernels( - F.solarize_image_tensor, F.solarize_image_pil, image, threshold=magnitude - ) + return F.solarize(image, threshold=magnitude) elif transform_id == "AutoContrast": - return self._dispatch_image_kernels(F.autocontrast_image_tensor, F.autocontrast_image_pil, image) + return F.autocontrast(image) elif transform_id == "Equalize": - return self._dispatch_image_kernels(F.equalize_image_tensor, F.equalize_image_pil, image) + return F.equalize(image) elif transform_id == "Invert": - return self._dispatch_image_kernels(F.invert_image_tensor, F.invert_image_pil, image) + return F.invert(image) else: raise ValueError(f"No transform available for {transform_id}") diff --git 
a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 960020baff8..60fe46ed9ea 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -1,5 +1,4 @@ import collections.abc -import functools from typing import Any, Dict, Union, Tuple, Optional, Sequence, Callable, TypeVar import PIL.Image @@ -55,74 +54,52 @@ def _check_input( def _image_transform( self, - input: T, + inpt: T, *, kernel_tensor: Callable[..., torch.Tensor], kernel_pil: Callable[..., PIL.Image.Image], **kwargs: Any, ) -> T: - if isinstance(input, features.Image): - output = kernel_tensor(input, **kwargs) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return kernel_tensor(input, **kwargs) - elif isinstance(input, PIL.Image.Image): - return kernel_pil(input, **kwargs) # type: ignore[no-any-return] + if isinstance(inpt, features.Image): + output = kernel_tensor(inpt, **kwargs) + return features.Image.new_like(inpt, output) + elif is_simple_tensor(inpt): + return kernel_tensor(inpt, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return kernel_pil(inpt, **kwargs) # type: ignore[no-any-return] else: raise RuntimeError + @staticmethod + def _generate_value(left: float, right: float) -> float: + return float(torch.distributions.Uniform(left, right).sample()) + def _get_params(self, sample: Any) -> Dict[str, Any]: - image_transforms = [] - if self.brightness is not None: - image_transforms.append( - functools.partial( - self._image_transform, - kernel_tensor=F.adjust_brightness_image_tensor, - kernel_pil=F.adjust_brightness_image_pil, - brightness_factor=float( - torch.distributions.Uniform(self.brightness[0], self.brightness[1]).sample() - ), - ) - ) - if self.contrast is not None: - image_transforms.append( - functools.partial( - self._image_transform, - kernel_tensor=F.adjust_contrast_image_tensor, - kernel_pil=F.adjust_contrast_image_pil, - contrast_factor=float(torch.distributions.Uniform(self.contrast[0], self.contrast[1]).sample()), - ) - ) - if self.saturation is not None: - image_transforms.append( - functools.partial( - self._image_transform, - kernel_tensor=F.adjust_saturation_image_tensor, - kernel_pil=F.adjust_saturation_image_pil, - saturation_factor=float( - torch.distributions.Uniform(self.saturation[0], self.saturation[1]).sample() - ), - ) - ) - if self.hue is not None: - image_transforms.append( - functools.partial( - self._image_transform, - kernel_tensor=F.adjust_hue_image_tensor, - kernel_pil=F.adjust_hue_image_pil, - hue_factor=float(torch.distributions.Uniform(self.hue[0], self.hue[1]).sample()), - ) - ) - - return dict(image_transforms=[image_transforms[idx] for idx in torch.randperm(len(image_transforms))]) - - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)): - return input - - for transform in params["image_transforms"]: - input = transform(input) - return input + fn_idx = torch.randperm(4) + + b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1]) + c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1]) + s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1]) + h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1]) + + return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, 
hue_factor=h) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + output = inpt + brightness_factor = params["brightness_factor"] + contrast_factor = params["contrast_factor"] + saturation_factor = params["saturation_factor"] + hue_factor = params["hue_factor"] + for fn_id in params["fn_idx"]: + if fn_id == 0 and brightness_factor is not None: + output = F.adjust_brightness(output, brightness_factor=brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + output = F.adjust_contrast(output, contrast_factor=contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + output = F.adjust_saturation(output, saturation_factor=saturation_factor) + elif fn_id == 3 and hue_factor is not None: + output = F.adjust_hue(output, hue_factor=hue_factor) + return output class _RandomChannelShuffle(Transform): @@ -131,19 +108,19 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: num_channels, _, _ = get_image_dimensions(image) return dict(permutation=torch.randperm(num_channels)) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)): - return input + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)): + return inpt - image = input - if isinstance(input, PIL.Image.Image): + image = inpt + if isinstance(inpt, PIL.Image.Image): image = _F.pil_to_tensor(image) output = image[..., params["permutation"], :, :] - if isinstance(input, features.Image): - output = features.Image.new_like(input, output, color_space=features.ColorSpace.OTHER) - elif isinstance(input, PIL.Image.Image): + if isinstance(inpt, features.Image): + output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER) + elif isinstance(inpt, PIL.Image.Image): output = _F.to_pil_image(output) return output @@ -175,33 +152,25 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: contrast_before=torch.rand(()) < 0.5, ) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["brightness"]: - input = self._brightness(input) + inpt = self._brightness(inpt) if params["contrast1"] and params["contrast_before"]: - input = self._contrast(input) + inpt = self._contrast(inpt) if params["saturation"]: - input = self._saturation(input) + inpt = self._saturation(inpt) if params["saturation"]: - input = self._saturation(input) + inpt = self._saturation(inpt) if params["contrast2"] and not params["contrast_before"]: - input = self._contrast(input) + inpt = self._contrast(inpt) if params["channel_shuffle"]: - input = self._channel_shuffle(input) - return input + inpt = self._channel_shuffle(inpt) + return inpt class RandomEqualize(_RandomApplyTransform): def __init__(self, p: float = 0.5): super().__init__(p=p) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.equalize_image_tensor(input) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return F.equalize_image_tensor(input) - elif isinstance(input, PIL.Image.Image): - return F.equalize_image_pil(input) - else: - return input + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.equalize(inpt) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 0487a71416e..e7e07701cd4 
100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -2,56 +2,33 @@ import math import numbers import warnings -from typing import Any, Dict, List, Union, Sequence, Tuple, cast +from typing import Any, Dict, List, Optional, Union, Sequence, Tuple, cast import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F from torchvision.transforms.functional import pil_to_tensor, InterpolationMode -from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int + +# TODO: refactor _parse_pad_padding into +# torchvision.transforms.functional and update F_t.pad and F_pil.pad +# and remove redundancy +from torchvision.transforms.functional_tensor import _parse_pad_padding +from torchvision.transforms.transforms import _setup_size, _setup_angle, _check_sequence_input from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor +from ._utils import query_image, get_image_dimensions, has_any class RandomHorizontalFlip(_RandomApplyTransform): - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.horizontal_flip_image_tensor(input) - return features.Image.new_like(input, output) - elif isinstance(input, features.SegmentationMask): - output = F.horizontal_flip_segmentation_mask(input) - return features.SegmentationMask.new_like(input, output) - elif isinstance(input, features.BoundingBox): - output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) - return features.BoundingBox.new_like(input, output) - elif isinstance(input, PIL.Image.Image): - return F.horizontal_flip_image_pil(input) - elif is_simple_tensor(input): - return F.horizontal_flip_image_tensor(input) - else: - return input + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.horizontal_flip(inpt) class RandomVerticalFlip(_RandomApplyTransform): - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.vertical_flip_image_tensor(input) - return features.Image.new_like(input, output) - elif isinstance(input, features.SegmentationMask): - output = F.vertical_flip_segmentation_mask(input) - return features.SegmentationMask.new_like(input, output) - elif isinstance(input, features.BoundingBox): - output = F.vertical_flip_bounding_box(input, format=input.format, image_size=input.image_size) - return features.BoundingBox.new_like(input, output) - elif isinstance(input, PIL.Image.Image): - return F.vertical_flip_image_pil(input) - elif is_simple_tensor(input): - return F.vertical_flip_image_tensor(input) - else: - return input + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.vertical_flip(inpt) class Resize(Transform): @@ -59,27 +36,23 @@ def __init__( self, size: Union[int, Sequence[int]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, ) -> None: super().__init__() self.size = [size] if isinstance(size, int) else list(size) self.interpolation = interpolation + self.max_size = max_size + self.antialias = antialias - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.resize_image_tensor(input, self.size, 
interpolation=self.interpolation) - return features.Image.new_like(input, output) - elif isinstance(input, features.SegmentationMask): - output = F.resize_segmentation_mask(input, self.size) - return features.SegmentationMask.new_like(input, output) - elif isinstance(input, features.BoundingBox): - output = F.resize_bounding_box(input, self.size, image_size=input.image_size) - return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size))) - elif isinstance(input, PIL.Image.Image): - return F.resize_image_pil(input, self.size, interpolation=self.interpolation) - elif is_simple_tensor(input): - return F.resize_image_tensor(input, self.size, interpolation=self.interpolation) - else: - return input + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resize( + inpt, + self.size, + interpolation=self.interpolation, + max_size=self.max_size, + antialias=self.antialias, + ) class CenterCrop(Transform): @@ -87,22 +60,86 @@ def __init__(self, output_size: List[int]): super().__init__() self.output_size = output_size - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.center_crop_image_tensor(input, self.output_size) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return F.center_crop_image_tensor(input, self.output_size) - elif isinstance(input, PIL.Image.Image): - return F.center_crop_image_pil(input, self.output_size) - else: - return input + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.center_crop(inpt, output_size=self.output_size) + + +class RandomCrop(Transform): + def __init__( + self, + size: Union[int, Sequence[int]], + padding: Optional[Union[int, Sequence[int]]] = None, + pad_if_needed: bool = False, + fill: Union[float, Sequence[float]] = 0.0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", + ): + super().__init__() + self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + self._pad_op = Pad(padding=self.padding, fill=self.fill, padding_mode=self.padding_mode) + + def _get_params(self, sample: Any) -> Dict[str, Any]: + # vfdev-5: techically, this op can work on bboxes/segm masks only inputs without image in samples + # What if we have multiple images/bboxes/masks of different sizes ? 
+ # TODO: let's support bbox or mask in samples without image + image = query_image(sample) + _, height, width = get_image_dimensions(image) + out_height, out_width = self.size + + if height + 1 < out_height or width + 1 < out_width: + raise ValueError( + f"Required crop size {(out_height, out_width)} is larger then input image size {(height, width)}" + ) + + if height == out_height and width == out_width: + return dict(top=0, left=0, height=height, width=width) + + i = torch.randint(0, height - out_height + 1, size=(1,)).item() + j = torch.randint(0, width - out_width + 1, size=(1,)).item() + return dict(top=i, left=j, height=out_height, width=out_width) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.crop(inpt, **params) def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - return super().forward(sample) + # TODO: main difficulties implementing this op: + # 1) unstructured inputs and why we need to call: sample = inputs if len(inputs) > 1 else inputs[0] ? + # 2) how to call F.op efficiently on inputs ? + # + # We can make inputs flatten using from torch.utils._pytree import tree_flatten, tree_unflatten + # Such that inputs -> flat_inputs = [obj1, obj2, obj3, ...] + + raise RuntimeError("Not yet implemented") + + params = self._get_params(inputs) + + # sample = inputs if len(inputs) > 1 else inputs[0] + # return apply_recursively(functools.partial(self._transform, params=self._get_params(sample)), sample) + + if self.padding is not None: + self._pad_op.padding = self.padding + inputs = self._pad_op(*inputs) + + # vfdev-5: techically, this op can work on bboxes/segm masks only inputs without image in samples + # What if we have multiple images/bboxes/masks of different sizes ? + # TODO: let's support bbox or mask in samples without image + image = query_image(inputs) + _, height, width = get_image_dimensions(image) + + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img = F.pad(img, padding=padding, fill=self.fill, padding_mode=self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img = F.pad(img, padding=padding, fill=self.fill, padding_mode=self.padding_mode) + + return ... class RandomResizedCrop(Transform): @@ -112,6 +149,7 @@ def __init__( scale: Tuple[float, float] = (0.08, 1.0), ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: Optional[bool] = None, ) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -125,20 +163,16 @@ def __init__( if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("Scale and ratio should be of kind (min, max)") - # Backward compatibility with integer value - if isinstance(interpolation, int): - warnings.warn( - "Argument interpolation should be of type InterpolationMode instead of int. " - "Please, use InterpolationMode enum." 
- ) - interpolation = _interpolation_modes_from_int(interpolation) - self.size = size self.scale = scale self.ratio = ratio self.interpolation = interpolation + self.antialias = antialias def _get_params(self, sample: Any) -> Dict[str, Any]: + # vfdev-5: techically, this op can work on bboxes/segm masks only inputs without image in samples + # What if we have multiple images/bboxes/masks of different sizes ? + # TODO: let's support bbox or mask in samples without image image = query_image(sample) _, height, width = get_image_dimensions(image) area = height * width @@ -177,24 +211,10 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(top=i, left=j, height=h, width=w) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.resized_crop_image_tensor( - input, **params, size=list(self.size), interpolation=self.interpolation - ) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation) - elif isinstance(input, PIL.Image.Image): - return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation) - else: - return input - - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - return super().forward(sample) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resized_crop( + inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias + ) class MultiCropResult(list): @@ -213,19 +233,19 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.five_crop_image_tensor(input, self.size) - return MultiCropResult(features.Image.new_like(input, o) for o in output) - elif is_simple_tensor(input): - return MultiCropResult(F.five_crop_image_tensor(input, self.size)) - elif isinstance(input, PIL.Image.Image): - return MultiCropResult(F.five_crop_image_pil(input, self.size)) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features.Image): + output = F.five_crop_image_tensor(inpt, self.size) + return MultiCropResult(features.Image.new_like(inpt, o) for o in output) + elif isinstance(inpt, PIL.Image.Image): + return MultiCropResult(F.five_crop_image_pil(inpt, self.size)) + elif isinstance(inpt, torch.Tensor): + return MultiCropResult(F.five_crop_image_tensor(inpt, self.size)) else: - return input + return inpt - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] + def forward(self, *inpts: Any) -> Any: + sample = inpts if len(inpts) > 1 else inpts[0] if has_any(sample, features.BoundingBox, features.SegmentationMask): raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") return super().forward(sample) @@ -237,26 +257,26 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.vertical_flip = vertical_flip - def 
_transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip) - return MultiCropResult(features.Image.new_like(input, o) for o in output) - elif is_simple_tensor(input): - return MultiCropResult(F.ten_crop_image_tensor(input, self.size)) - elif isinstance(input, PIL.Image.Image): - return MultiCropResult(F.ten_crop_image_pil(input, self.size)) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features.Image): + output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip) + return MultiCropResult(features.Image.new_like(inpt, o) for o in output) + elif isinstance(inpt, PIL.Image.Image): + return MultiCropResult(F.ten_crop_image_pil(inpt, self.size)) + elif isinstance(inpt, torch.Tensor): + return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size)) else: - return input + return inpt - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] + def forward(self, *inpts: Any) -> Any: + sample = inpts if len(inpts) > 1 else inpts[0] if has_any(sample, features.BoundingBox, features.SegmentationMask): raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") return super().forward(sample) class BatchMultiCrop(Transform): - def forward(self, *inputs: Any) -> Any: + def forward(self, *inpts: Any) -> Any: # This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one # significant difference: # Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from @@ -280,7 +300,7 @@ def apply_recursively(obj: Any) -> Any: else: return obj - return apply_recursively(inputs if len(inputs) > 1 else inputs[0]) + return apply_recursively(inpts if len(inpts) > 1 else inpts[0]) class Pad(Transform): @@ -305,49 +325,13 @@ def __init__( f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" ) + padding = _parse_pad_padding(padding) self.padding = padding self.fill = fill self.padding_mode = padding_mode - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image) or is_simple_tensor(input): - # PyTorch's pad supports only integers on fill. 
So we need to overwrite the colour - output = F.pad_image_tensor(input, params["padding"], fill=0, padding_mode="constant") - - left, top, right, bottom = params["padding"] - fill = torch.tensor(params["fill"], dtype=input.dtype, device=input.device).to().view(-1, 1, 1) - - if top > 0: - output[..., :top, :] = fill - if left > 0: - output[..., :, :left] = fill - if bottom > 0: - output[..., -bottom:, :] = fill - if right > 0: - output[..., :, -right:] = fill - - if isinstance(input, features.Image): - output = features.Image.new_like(input, output) - - return output - elif isinstance(input, PIL.Image.Image): - return F.pad_image_pil( - input, - params["padding"], - fill=tuple(int(v) if input.mode != "F" else v for v in params["fill"]), - padding_mode="constant", - ) - elif isinstance(input, features.BoundingBox): - output = F.pad_bounding_box(input, params["padding"], format=input.format) - - left, top, right, bottom = params["padding"] - height, width = input.image_size - height += top + bottom - width += left + right - - return features.BoundingBox.new_like(input, output, image_size=(height, width)) - else: - return input + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode) class RandomZoomOut(_RandomApplyTransform): @@ -364,6 +348,8 @@ def __init__( if side_range[0] < 1.0 or side_range[0] > side_range[1]: raise ValueError(f"Invalid canvas side range provided {side_range}.") + self._pad_op = Pad(0, padding_mode="constant") + def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) orig_c, orig_h, orig_w = get_image_dimensions(image) @@ -385,6 +371,135 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(padding=padding, fill=fill) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - transform = Pad(**params, padding_mode="constant") - return transform(input) + def forward(self, *inputs: Any) -> Any: + params = self._get_params(inputs) + self._pad_op.padding = params["padding"] + self._pad_op.fill = params["fill"] + return self._pad_op(*inputs) + + +class RandomRotation(Transform): + def __init__( + self, + degrees, + interpolation=InterpolationMode.NEAREST, + expand=False, + fill=0, + center=None, + ) -> None: + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + self.interpolation = interpolation + self.expand = expand + + if fill is None: + fill = 0 + elif not isinstance(fill, (Sequence, numbers.Number)): + raise TypeError("Fill should be either a sequence or a number.") + + self.fill = fill + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + def _get_params(self, sample: Any) -> Dict[str, Any]: + angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item()) + return dict(angle=angle) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.rotate( + inpt, + **params, + interpolation=self.interpolation, + expand=self.expand, + fill=self.fill, + center=self.center, + ) + + +class RandomAffine(Transform): + def __init__( + self, + degrees, + translate=None, + scale=None, + shear=None, + interpolation=InterpolationMode.NEAREST, + fill=0, + center=None, + ) -> None: + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + if translate is not None: + _check_sequence_input(translate, "translate", req_sizes=(2,)) + for t in translate: + if not 
(0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + if scale is not None: + _check_sequence_input(scale, "scale", req_sizes=(2,)) + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) + else: + self.shear = shear + + self.interpolation = interpolation + + if fill is None: + fill = 0 + elif not isinstance(fill, (Sequence, numbers.Number)): + raise TypeError("Fill should be either a sequence or a number.") + + self.fill = fill + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + def _get_params(self, sample: Any) -> Dict[str, Any]: + + # Get image size + # TODO: make it work with bboxes and segm masks + image = query_image(sample) + _, height, width = get_image_dimensions(image) + + angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item()) + if self.translate is not None: + max_dx = float(self.translate[0] * width) + max_dy = float(self.translate[1] * height) + tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) + ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) + translations = (tx, ty) + else: + translations = (0, 0) + + if self.scale is not None: + scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()) + else: + scale = 1.0 + + shear_x = shear_y = 0.0 + if self.shear is not None: + shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()) + if len(self.shear) == 4: + shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()) + + shear = (shear_x, shear_y) + return dict(angle=angle, translations=translations, scale=scale, shear=shear) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.affine( + inpt, + **params, + interpolation=self.interpolation, + fill=self.fill, + center=self.center, + ) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 2a6c7dce516..a8c17577a56 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -7,72 +7,89 @@ from ._augment import ( erase_image_tensor, - mixup_image_tensor, - mixup_one_hot_label, - cutmix_image_tensor, - cutmix_one_hot_label, ) from ._color import ( + adjust_brightness, adjust_brightness_image_tensor, adjust_brightness_image_pil, + adjust_contrast, adjust_contrast_image_tensor, adjust_contrast_image_pil, + adjust_saturation, adjust_saturation_image_tensor, adjust_saturation_image_pil, + adjust_sharpness, adjust_sharpness_image_tensor, adjust_sharpness_image_pil, + adjust_hue, + adjust_hue_image_tensor, + adjust_hue_image_pil, + adjust_gamma, + adjust_gamma_image_tensor, + adjust_gamma_image_pil, + posterize, posterize_image_tensor, posterize_image_pil, + solarize, solarize_image_tensor, solarize_image_pil, + autocontrast, autocontrast_image_tensor, autocontrast_image_pil, + equalize, equalize_image_tensor, equalize_image_pil, + invert, invert_image_tensor, invert_image_pil, - adjust_hue_image_tensor, - adjust_hue_image_pil, - adjust_gamma_image_tensor, - adjust_gamma_image_pil, ) from ._geometry import ( + horizontal_flip, horizontal_flip_bounding_box, horizontal_flip_image_tensor, horizontal_flip_image_pil, horizontal_flip_segmentation_mask, + resize, resize_bounding_box, 
resize_image_tensor, resize_image_pil, resize_segmentation_mask, + center_crop, center_crop_bounding_box, center_crop_segmentation_mask, center_crop_image_tensor, center_crop_image_pil, + resized_crop, resized_crop_bounding_box, resized_crop_image_tensor, resized_crop_image_pil, resized_crop_segmentation_mask, + affine, affine_bounding_box, affine_image_tensor, affine_image_pil, affine_segmentation_mask, + rotate, rotate_bounding_box, rotate_image_tensor, rotate_image_pil, rotate_segmentation_mask, + pad, pad_bounding_box, pad_image_tensor, pad_image_pil, pad_segmentation_mask, + crop, crop_bounding_box, crop_image_tensor, crop_image_pil, crop_segmentation_mask, + perspective, perspective_bounding_box, perspective_image_tensor, perspective_image_pil, perspective_segmentation_mask, + vertical_flip, vertical_flip_image_tensor, vertical_flip_image_pil, vertical_flip_bounding_box, diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 5004ac550dd..3920d1b3065 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,45 +1,13 @@ -from typing import Tuple - -import torch from torchvision.transforms import functional_tensor as _FT erase_image_tensor = _FT.erase -def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: - input = input.clone() - return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) - - -def mixup_image_tensor(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - return _mixup_tensor(image_batch, -4, lam) - - -def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") - - return _mixup_tensor(one_hot_label_batch, -2, lam) - - -def cutmix_image_tensor(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - x1, y1, x2, y2 = box - image_rolled = image_batch.roll(1, -4) - - image_batch = image_batch.clone() - image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] - return image_batch - - -def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") +# TODO: Don't forget to clean up from the primitives kernels those that shouldn't be kernels. 
+# Like the mixup and cutmix stuff - return _mixup_tensor(one_hot_label_batch, -2, lam_adjusted) +# This function is copy-pasted to Image and OneHotLabel and may be refactored +# def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: +# input = input.clone() +# return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index fa632d7df58..f8016b43a36 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,34 +1,171 @@ +from typing import Any + +import PIL.Image +import torch +from torchvision.prototype import features from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP + adjust_brightness_image_tensor = _FT.adjust_brightness adjust_brightness_image_pil = _FP.adjust_brightness + +def adjust_brightness(inpt: Any, brightness_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_brightness(brightness_factor=brightness_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) + else: + return inpt + + adjust_saturation_image_tensor = _FT.adjust_saturation adjust_saturation_image_pil = _FP.adjust_saturation + +def adjust_saturation(inpt: Any, saturation_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_saturation(saturation_factor=saturation_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) + else: + return inpt + + adjust_contrast_image_tensor = _FT.adjust_contrast adjust_contrast_image_pil = _FP.adjust_contrast + +def adjust_contrast(inpt: Any, contrast_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_contrast(contrast_factor=contrast_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) + else: + return inpt + + adjust_sharpness_image_tensor = _FT.adjust_sharpness adjust_sharpness_image_pil = _FP.adjust_sharpness + +def adjust_sharpness(inpt: Any, sharpness_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_sharpness(sharpness_factor=sharpness_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) + else: + return inpt + + +adjust_hue_image_tensor = _FT.adjust_hue +adjust_hue_image_pil = _FP.adjust_hue + + +def adjust_hue(inpt: Any, hue_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_hue(hue_factor=hue_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_hue_image_pil(inpt, hue_factor=hue_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) + else: + return inpt + + +adjust_gamma_image_tensor = _FT.adjust_gamma 
+adjust_gamma_image_pil = _FP.adjust_gamma + + +def adjust_gamma(inpt: Any, gamma: float, gain: float = 1) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_gamma(gamma=gamma, gain=gain) + elif isinstance(inpt, PIL.Image.Image): + return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) + elif isinstance(inpt, torch.Tensor): + return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) + else: + return inpt + + posterize_image_tensor = _FT.posterize posterize_image_pil = _FP.posterize + +def posterize(inpt: Any, bits: int) -> Any: + if isinstance(inpt, features._Feature): + return inpt.posterize(bits=bits) + elif isinstance(inpt, PIL.Image.Image): + return posterize_image_pil(inpt, bits=bits) + elif isinstance(inpt, torch.Tensor): + return posterize_image_tensor(inpt, bits=bits) + else: + return inpt + + solarize_image_tensor = _FT.solarize solarize_image_pil = _FP.solarize + +def solarize(inpt: Any, threshold: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.solarize(threshold=threshold) + elif isinstance(inpt, PIL.Image.Image): + return solarize_image_pil(inpt, threshold=threshold) + elif isinstance(inpt, torch.Tensor): + return solarize_image_tensor(inpt, threshold=threshold) + else: + return inpt + + autocontrast_image_tensor = _FT.autocontrast autocontrast_image_pil = _FP.autocontrast + +def autocontrast(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.autocontrast() + elif isinstance(inpt, PIL.Image.Image): + return autocontrast_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return autocontrast_image_tensor(inpt) + else: + return inpt + + equalize_image_tensor = _FT.equalize equalize_image_pil = _FP.equalize + +def equalize(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.equalize() + elif isinstance(inpt, PIL.Image.Image): + return equalize_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return equalize_image_tensor(inpt) + else: + return inpt + + invert_image_tensor = _FT.invert invert_image_pil = _FP.invert -adjust_hue_image_tensor = _FT.adjust_hue -adjust_hue_image_pil = _FP.adjust_hue -adjust_gamma_image_tensor = _FT.adjust_gamma -adjust_gamma_image_pil = _FP.adjust_gamma +def invert(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.invert() + elif isinstance(inpt, PIL.Image.Image): + return invert_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return invert_image_tensor(inpt) + else: + return inpt diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 95e094ad798..b6704b96328 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,6 +1,6 @@ import numbers import warnings -from typing import Tuple, List, Optional, Sequence, Union +from typing import Any, Tuple, List, Optional, Sequence, Union import PIL.Image import torch @@ -40,12 +40,58 @@ def horizontal_flip_bounding_box( ).view(shape) +def horizontal_flip(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.horizontal_flip() + elif isinstance(inpt, PIL.Image.Image): + return horizontal_flip_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return horizontal_flip_image_tensor(inpt) + else: + return inpt + + +vertical_flip_image_tensor = _FT.vflip +vertical_flip_image_pil = _FP.vflip + + +def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor: + return 
vertical_flip_image_tensor(segmentation_mask) + + +def vertical_flip_bounding_box( + bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int] +) -> torch.Tensor: + shape = bounding_box.shape + + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY + ).view(-1, 4) + + bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]] + + return convert_bounding_box_format( + bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False + ).view(shape) + + +def vertical_flip(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.vertical_flip() + elif isinstance(inpt, PIL.Image.Image): + return vertical_flip_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return vertical_flip_image_tensor(inpt) + else: + return inpt + + def resize_image_tensor( image: torch.Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: bool = False, ) -> torch.Tensor: num_channels, old_height, old_width = get_dimensions_image_tensor(image) new_height, new_width = _compute_output_size((old_height, old_width), size=size, max_size=max_size) @@ -87,28 +133,25 @@ def resize_bounding_box( return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) -vertical_flip_image_tensor = _FT.vflip -vertical_flip_image_pil = _FP.vflip - - -def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor: - return vertical_flip_image_tensor(segmentation_mask) - - -def vertical_flip_bounding_box( - bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int] -) -> torch.Tensor: - shape = bounding_box.shape - - bounding_box = convert_bounding_box_format( - bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY - ).view(-1, 4) - - bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]] - - return convert_bounding_box_format( - bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False - ).view(shape) +def resize( + inpt: Any, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> Any: + if isinstance(inpt, features._Feature): + antialias = False if antialias is None else antialias + return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias) + elif isinstance(inpt, PIL.Image.Image): + if antialias is not None and not antialias: + warnings.warn("Anti-alias option is always applied for PIL Image input. 
Argument antialias is ignored.") + return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) + elif isinstance(inpt, torch.Tensor): + antialias = False if antialias is None else antialias + return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) + else: + return inpt def _affine_parse_args( @@ -323,6 +366,27 @@ def affine_segmentation_mask( ) +def affine( + inpt: Any, + angle: float, + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, +) -> Any: + kwargs = dict(translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center) + if isinstance(inpt, features._Feature): + return inpt.affine(angle, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return affine_image_pil(inpt, angle, **kwargs) + elif isinstance(inpt, torch.Tensor): + return affine_image_tensor(inpt, angle, **kwargs) + else: + return inpt + + def rotate_image_tensor( img: torch.Tensor, angle: float, @@ -402,6 +466,30 @@ def rotate_segmentation_mask( ) +def rotate( + inpt: Any, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, +) -> Any: + kwargs = dict( + interpolation=interpolation, + expand=expand, + fill=fill, + center=center, + ) + if isinstance(inpt, features._Feature): + return inpt.rotate(angle, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return rotate_image_pil(inpt, angle, **kwargs) + elif isinstance(inpt, torch.Tensor): + return rotate_image_tensor(inpt, angle, **kwargs) + else: + return inpt + + pad_image_tensor = _FT.pad pad_image_pil = _FP.pad @@ -436,6 +524,21 @@ def pad_bounding_box( return bounding_box +def pad(inpt: Any, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Any: + kwargs = dict( + fill=fill, + padding_mode=padding_mode, + ) + if isinstance(inpt, features._Feature): + return inpt.pad(padding, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return pad_image_pil(inpt, padding, **kwargs) + elif isinstance(inpt, torch.Tensor): + return pad_image_tensor(inpt, padding, **kwargs) + else: + return inpt + + crop_image_tensor = _FT.crop crop_image_pil = _FP.crop @@ -463,6 +566,17 @@ def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, return crop_image_tensor(img, top, left, height, width) +def crop(inpt: Any, top: int, left: int, height: int, width: int) -> Any: + if isinstance(inpt, features._Feature): + return inpt.crop(top, left, height, width) + elif isinstance(inpt, PIL.Image.Image): + return crop_image_pil(inpt, top, left, height, width) + elif isinstance(inpt, torch.Tensor): + return crop_image_tensor(inpt, top, left, height, width) + else: + return inpt + + def perspective_image_tensor( img: torch.Tensor, perspective_coeffs: List[float], @@ -570,6 +684,23 @@ def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[fl return perspective_image_tensor(img, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST) +def perspective( + inpt: Any, + perspective_coeffs: List[float], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, +) -> Any: + kwargs = dict(interpolation=interpolation, fill=fill) + if isinstance(inpt, features._Feature): + return 
inpt.perspective(perspective_coeffs, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return perspective_image_pil(inpt, perspective_coeffs, **kwargs) + elif isinstance(inpt, torch.Tensor): + return perspective_image_tensor(inpt, perspective_coeffs, **kwargs) + else: + return inpt + + def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: if isinstance(output_size, numbers.Number): return [int(output_size), int(output_size)] @@ -643,6 +774,17 @@ def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: return center_crop_image_tensor(img=segmentation_mask, output_size=output_size) +def center_crop(inpt: Any, output_size: List[int]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.center_crop(output_size) + elif isinstance(inpt, PIL.Image.Image): + return center_crop_image_pil(inpt, output_size) + elif isinstance(inpt, torch.Tensor): + return center_crop_image_tensor(inpt, output_size) + else: + return inpt + + def resized_crop_image_tensor( img: torch.Tensor, top: int, @@ -651,9 +793,10 @@ def resized_crop_image_tensor( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: bool = False, ) -> torch.Tensor: img = crop_image_tensor(img, top, left, height, width) - return resize_image_tensor(img, size, interpolation=interpolation) + return resize_image_tensor(img, size, interpolation=interpolation, antialias=antialias) def resized_crop_image_pil( @@ -694,6 +837,29 @@ def resized_crop_segmentation_mask( return resize_segmentation_mask(mask, size) +def resized_crop( + inpt: Any, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: Optional[bool] = None, +) -> Any: + kwargs = dict(size=size, interpolation=interpolation) + if isinstance(inpt, features._Feature): + antialias = False if antialias is None else antialias + return inpt.resized_crop(top, left, height, width, antialias=antialias, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return resized_crop_image_pil(inpt, top, left, height, width, **kwargs) + elif isinstance(inpt, torch.Tensor): + antialias = False if antialias is None else antialias + return resized_crop_image_tensor(inpt, top, left, height, width, antialias=antialias, **kwargs) + else: + return inpt + + def _parse_five_crop_size(size: List[int]) -> List[int]: if isinstance(size, numbers.Number): size = [int(size), int(size)] diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index acc8d3ae3e1..35618da9339 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -430,16 +430,13 @@ def resize( img: Tensor, size: List[int], interpolation: str = "bilinear", - antialias: Optional[bool] = None, + antialias: bool = False, ) -> Tensor: _assert_image_tensor(img) if isinstance(size, tuple): size = list(size) - if antialias is None: - antialias = False - if antialias and interpolation not in ["bilinear", "bicubic"]: raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
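
Note (illustrative, not part of the patch): the new single-entry dispatchers added under torchvision/prototype/transforms/functional (resize, pad, crop, adjust_brightness, posterize, ...) all follow the same routing: prototype features are handled by their own method, PIL images go to the *_image_pil kernel, plain tensors go to the *_image_tensor kernel, and anything else is returned unchanged. A minimal usage sketch, assuming a build that ships this prototype namespace and that the feature classes implement the methods these dispatchers call:

import torch
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

tensor_image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
pil_image = PIL.Image.new("RGB", (32, 32))
feature_image = features.Image(tensor_image)  # assumes the prototype Image feature wraps a CHW tensor

# One call signature covers every supported input type:
out_tensor = F.adjust_brightness(tensor_image, brightness_factor=1.5)    # torch.Tensor branch
out_pil = F.adjust_brightness(pil_image, brightness_factor=1.5)          # PIL.Image branch
out_feature = F.adjust_brightness(feature_image, brightness_factor=1.5)  # features._Feature branch

# Unrecognized inputs fall through the final `else` and are returned unchanged.
assert F.adjust_brightness("not an image", brightness_factor=1.5) == "not an image"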
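
Note (illustrative, not part of the patch): with functional_tensor.resize now taking antialias: bool = False, the prototype resize dispatcher keeps antialias optional and maps None to False for tensor and feature inputs, while PIL inputs always antialias and only emit a warning when antialias=False is passed explicitly. A small sketch of the resulting default behaviour, under the same assumptions as above:

import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 64, 64)

# antialias is Optional in the dispatcher; None is treated as False for tensor inputs,
# so these two calls should hit the same kernel path and produce identical results.
a = F.resize(img, [32, 32])
b = F.resize(img, [32, 32], antialias=False)
assert torch.allclose(a, b)

# For PIL inputs antialiasing is always applied; passing antialias=False only triggers the
# "Anti-alias option is always applied for PIL Image input" warning and is otherwise ignored.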
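
Note (illustrative, not part of the patch): the refactored transforms in this patch (RandomResizedCrop, Pad, RandomRotation, RandomAffine) reduce _transform to a single call into the matching dispatcher, so per-type handling lives in one place. A hypothetical transform written against the same pattern, assuming Transform is importable from torchvision.prototype.transforms, which the classes in this patch subclass; the name RandomPosterize here is just an example:

from typing import Any, Dict, Tuple

import torch
from torchvision.prototype.transforms import Transform, functional as F


class RandomPosterize(Transform):  # hypothetical example, not part of this patch
    def __init__(self, bits_range: Tuple[int, int] = (2, 8)) -> None:
        super().__init__()
        self.bits_range = bits_range

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        # Draw the parameter once per call so every input in the sample gets the same value.
        bits = int(torch.randint(self.bits_range[0], self.bits_range[1] + 1, ()).item())
        return dict(bits=bits)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # F.posterize routes to the feature method, the PIL kernel, or the tensor kernel,
        # and passes any other input through untouched.
        return F.posterize(inpt, **params)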