From da4d047a2c34938694e16d01614ba1359ecacd69 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Mar 2022 10:28:12 +0100 Subject: [PATCH] port RandomPhotoMetricDistort to prototype transforms --- torchvision/prototype/transforms/__init__.py | 2 +- torchvision/prototype/transforms/_color.py | 68 +++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 647f3937ed5..7b74de2a400 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -4,7 +4,7 @@ from ._augment import RandomErasing, RandomMixup, RandomCutmix from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix -from ._color import ColorJitter +from ._color import ColorJitter, RandomPhotometricDistort from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import ( Resize, diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 7aeec29fda0..cf3ceaf7ede 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -6,8 +6,9 @@ import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F +from torchvision.transforms import functional as _F -from ._utils import is_simple_tensor +from ._utils import is_simple_tensor, get_image_dimensions, query_image T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image) @@ -120,5 +121,70 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: for transform in params["image_transforms"]: input = transform(input) + return input + + +class _RandomChannelShuffle(Transform): + def _get_params(self, sample: Any) -> Dict[str, Any]: + image = query_image(sample) + num_channels, _, _ = get_image_dimensions(image) + return dict(permutation=torch.randperm(num_channels)) + + 
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
    """Reorder the channels of an image (method of ``_RandomChannelShuffle``).

    Args:
        input: A single item of the sample. Only ``features.Image``,
            ``PIL.Image.Image`` and simple tensors are transformed; anything
            else is returned unchanged.
        params: Must contain ``"permutation"``, the channel index
            permutation used to reorder the image.

    Returns:
        The image with permuted channels, in the same container kind as the
        input (feature image, PIL image or plain tensor).
    """
    if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)):
        return input

    image = input
    if isinstance(input, PIL.Image.Image):
        # PIL images cannot be fancy-indexed, so round-trip through a tensor.
        image = _F.pil_to_tensor(image)

    # Channels are indexed as the third-to-last dimension (..., C, H, W).
    output = image[..., params["permutation"], :, :]

    if isinstance(input, features.Image):
        # The result is tagged with ColorSpace.OTHER -- presumably because the
        # shuffled channel order no longer matches the input's color space.
        output = features.Image.new_like(input, output, color_space=features.ColorSpace.OTHER)
    elif isinstance(input, PIL.Image.Image):
        output = _F.to_pil_image(output)

    return output


class RandomPhotometricDistort(Transform):
    """Randomly apply a sequence of photometric distortions.

    Brightness, contrast, saturation and hue jitter as well as a channel
    shuffle are each applied with probability ``p``. Contrast is applied
    either before or after the saturation / hue steps; which of the two is
    decided by a fair coin flip per call.

    Args:
        contrast: Range forwarded to ``ColorJitter(contrast=...)``.
        saturation: Range forwarded to ``ColorJitter(saturation=...)``.
        hue: Range forwarded to ``ColorJitter(hue=...)``.
        brightness: Range forwarded to ``ColorJitter(brightness=...)``.
        p: Probability with which each individual distortion is applied.
    """

    def __init__(
        self,
        contrast: Tuple[float, float] = (0.5, 1.5),
        saturation: Tuple[float, float] = (0.5, 1.5),
        hue: Tuple[float, float] = (-0.05, 0.05),
        brightness: Tuple[float, float] = (0.875, 1.125),
        p: float = 0.5,
    ):
        super().__init__()
        # One single-purpose jitter per property so each can be toggled
        # independently in `_transform`.
        self._brightness = ColorJitter(brightness=brightness)
        self._contrast = ColorJitter(contrast=contrast)
        self._hue = ColorJitter(hue=hue)
        self._saturation = ColorJitter(saturation=saturation)
        self._channel_shuffle = _RandomChannelShuffle()
        self.p = p

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        """Sample the on/off flag of every distortion step.

        Returns:
            A dict with one boolean flag per step (``"brightness"``,
            ``"contrast1"``, ``"saturation"``, ``"hue"``, ``"contrast2"``,
            ``"channel_shuffle"``), each true with probability ``p``, plus
            ``"contrast_before"`` which is true with probability 0.5.
        """
        return dict(
            zip(
                ["brightness", "contrast1", "saturation", "hue", "contrast2", "channel_shuffle"],
                torch.rand(6) < self.p,
            ),
            contrast_before=torch.rand(()) < 0.5,
        )

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        """Apply the enabled distortions to ``input`` in a fixed order.

        Order: brightness, contrast (if ``contrast_before``), saturation,
        hue, contrast (if not ``contrast_before``), channel shuffle.

        Args:
            input: The item to distort; it is passed through the enabled
                sub-transforms one after another.
            params: Flags as produced by ``_get_params``.

        Returns:
            The result of chaining all enabled sub-transforms.
        """
        if params["brightness"]:
            input = self._brightness(input)
        # "contrast1"/"contrast2" are separate draws; only the one matching
        # the sampled position (before/after) is honoured.
        if params["contrast1"] and params["contrast_before"]:
            input = self._contrast(input)
        if params["saturation"]:
            input = self._saturation(input)
        # Fixed: this step previously re-checked "saturation" and applied the
        # saturation jitter a second time, so hue was never distorted.
        if params["hue"]:
            input = self._hue(input)
        if params["contrast2"] and not params["contrast_before"]:
            input = self._contrast(input)
        if params["channel_shuffle"]:
            input = self._channel_shuffle(input)
        return input