diff --git a/test/common_utils.py b/test/common_utils.py index 5936ae1f713..546893b0cb3 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -144,7 +144,7 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu return batch_tensor -assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) +assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0, check_stride=False) def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None): diff --git a/test/test_transforms.py b/test/test_transforms.py index c5cc80ef87e..363808ac4e8 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -205,23 +205,23 @@ def test_to_tensor(self, channels): input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255) img = transforms.ToPILImage()(input_data) output = trans(img) - torch.testing.assert_close(output, input_data) + torch.testing.assert_close(output, input_data, check_stride=False) ndarray = np_rng.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) output = trans(ndarray) expected_output = ndarray.transpose((2, 0, 1)) / 255.0 - torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False) + torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False, check_stride=False) ndarray = np_rng.rand(height, width, channels).astype(np.float32) output = trans(ndarray) expected_output = ndarray.transpose((2, 0, 1)) - torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False) + torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False, check_stride=False) # separate test for mode '1' PIL images input_data = torch.ByteTensor(1, height, width).bernoulli_() img = transforms.ToPILImage()(input_data.mul(255)).convert('1') output = trans(img) - torch.testing.assert_close(input_data, output, check_dtype=False) + torch.testing.assert_close(input_data, output, check_dtype=False, 
check_stride=False) def test_to_tensor_errors(self): height, width = 4, 4 @@ -261,7 +261,7 @@ def test_pil_to_tensor(self, channels): input_data = torch.ByteTensor(channels, height, width).random_(0, 255) img = transforms.ToPILImage()(input_data) output = trans(img) - torch.testing.assert_close(input_data, output) + torch.testing.assert_close(input_data, output, check_stride=False) input_data = np_rng.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) img = transforms.ToPILImage()(input_data) @@ -273,13 +273,13 @@ def test_pil_to_tensor(self, channels): img = transforms.ToPILImage()(input_data) # CHW -> HWC and (* 255).byte() output = trans(img) # HWC -> CHW expected_output = (input_data * 255).byte() - torch.testing.assert_close(output, expected_output) + torch.testing.assert_close(output, expected_output, check_stride=False) # separate test for mode '1' PIL images input_data = torch.ByteTensor(1, height, width).bernoulli_() img = transforms.ToPILImage()(input_data.mul(255)).convert('1') output = trans(img).view(torch.uint8).bool().to(torch.uint8) - torch.testing.assert_close(input_data, output) + torch.testing.assert_close(input_data, output, check_stride=False) def test_pil_to_tensor_errors(self): height, width = 4, 4 @@ -424,10 +424,10 @@ def test_pad(self): h_padded = result[:, :padding, :] w_padded = result[:, :, :padding] torch.testing.assert_close( - h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps + h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps, check_stride=False, ) torch.testing.assert_close( - w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps + w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps, check_stride=False, ) pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), transforms.ToPILImage()(img)) @@ -528,9 +528,9 @@ def test_randomness(fn, trans, config, p): num_samples = 250 counts = 0 for _ in range(num_samples): 
- tranformation = trans(p=p, **config) - tranformation.__repr__() - out = tranformation(img) + transformation = trans(p=p, **config) + transformation.__repr__() + out = transformation(img) if out == inv_img: counts += 1 @@ -583,7 +583,7 @@ def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_outpu img = transform(img_data) assert img.mode == expected_mode - torch.testing.assert_close(expected_output, to_tensor(img).numpy()) + torch.testing.assert_close(expected_output, to_tensor(img).numpy(), check_stride=False) def test_1_channel_float_tensor_to_pil_image(self): img_data = torch.Tensor(1, 4, 4).uniform_() @@ -621,7 +621,7 @@ def test_2_channel_ndarray_to_pil_image(self, expected_mode): assert img.mode == expected_mode split = img.split() for i in range(2): - torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i])) + torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False) def test_2_channel_ndarray_to_pil_image_error(self): img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy() @@ -725,7 +725,7 @@ def test_3_channel_ndarray_to_pil_image(self, expected_mode): assert img.mode == expected_mode split = img.split() for i in range(3): - torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i])) + torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False) def test_3_channel_ndarray_to_pil_image_error(self): img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy() @@ -782,7 +782,7 @@ def test_4_channel_ndarray_to_pil_image(self, expected_mode): assert img.mode == expected_mode split = img.split() for i in range(4): - torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i])) + torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False) def test_4_channel_ndarray_to_pil_image_error(self): img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy() @@ -1532,7 +1532,7 @@ def test_random_crop(): t = transforms.RandomCrop(48) img 
= torch.ones(3, 32, 32) - with pytest.raises(ValueError, match=r"Required crop size .+ is larger then input image size .+"): + with pytest.raises(ValueError, match=r"Required crop size .+ is larger than input image size .+"): t(img) @@ -1659,7 +1659,7 @@ def test_random_erasing(): img = torch.ones(3, 128, 128) t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.)) - y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) + y, x, h, w, v = t.get_params_transform(img, t.scale, t.ratio, [t.value, ]) aspect_ratio = h / w # Add some tolerance due to the rounding and int conversion used in the transform tol = 0.05 @@ -1669,7 +1669,7 @@ def test_random_erasing(): random.seed(42) trial = 1000 for _ in range(trial): - y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) + y, x, h, w, v = t.get_params_transform(img, t.scale, t.ratio, [t.value, ]) aspect_ratios.append(h / w) count_bigger_then_ones = len([1 for aspect_ratio in aspect_ratios if aspect_ratio > 1]) @@ -1730,7 +1730,7 @@ def test_randomperspective(): to_pil_image = transforms.ToPILImage() img = to_pil_image(img) perp = transforms.RandomPerspective() - startpoints, endpoints = perp.get_params(width, height, 0.5) + startpoints, endpoints = perp.get_start_endpoints(width, height, 0.5) tr_img = F.perspective(img, startpoints, endpoints) tr_img2 = F.to_tensor(F.perspective(tr_img, endpoints, startpoints)) tr_img = F.to_tensor(tr_img) @@ -1767,7 +1767,7 @@ def test_randomperspective_fill(mode): pixel = (pixel,) assert pixel == tuple([fill] * num_bands) - startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5) + startpoints, endpoints = transforms.RandomPerspective.get_start_endpoints(width, height, 0.5) tr_img = F.perspective(img_conv, startpoints, endpoints, fill=fill) pixel = tr_img.getpixel((0, 0)) @@ -2062,7 +2062,7 @@ def test_random_affine(): t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40]) for _ in 
range(100): - angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, + angle, translations, scale, shear = t.get_params(img, t.degrees, t.translate, t.scale, t.shear, img_size=img.size) assert -10 < angle < 10 assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, ("{} vs {}" @@ -2094,5 +2094,60 @@ def test_random_affine(): assert t.interpolation == transforms.InterpolationMode.BILINEAR +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") +@pytest.mark.parametrize('trans, config', [ + (transforms.RandomInvert, {}), + (transforms.RandomPosterize, {"bits": 4}), + (transforms.RandomSolarize, {"threshold": 192}), + (transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}), + (transforms.RandomAutocontrast, {}), + (transforms.RandomEqualize, {})]) +@pytest.mark.parametrize('p', (.5, .7)) +def test_reset_randomness(trans, config, p): + random_state = random.getstate() + random.seed(42) + img = transforms.ToPILImage()(torch.rand(3, 16, 18)) + + num_samples = 250 + counts = 0 + for _ in range(num_samples): + transformation = trans(p=p, **config, reset_auto=False) + transformation.__repr__() + out1 = transformation(img) + assert out1 == transformation(img) + transformation.wipeout_() + out2 = transformation(img) + if out1 == out2: + counts += 1 + + p_repeat = p**2 + (1 - p)**2 + p_value = stats.binom_test(counts, num_samples, p=p_repeat) + random.setstate(random_state) + assert p_value > 0.0001, f'got counts={counts} for num_samples={num_samples}' + + +@pytest.mark.parametrize('trans, config', [ + (transforms.RandomCrop, {'size': 10}), + (transforms.RandomOrder, {"transforms": + [transforms.GaussianBlur(kernel_size=3, reset_auto=False), + transforms.RandomCrop(size=10, reset_auto=False)]}), + (transforms.RandomResizedCrop, {'size': 10}), + (transforms.ColorJitter, {}), + (transforms.RandomRotation, {'degrees': 120}), + (transforms.RandomAffine, {'degrees': 120, 'translate': (0.1, 0.1)}), + 
(transforms.RandomErasing, {}), + (transforms.GaussianBlur, {"kernel_size": 3})]) +def test_grouptransform(trans, config): + num_samples = 250 + for i in range(num_samples): + t = transforms.GroupTransform(trans(**config, reset_auto=False)) + assert t.stochastic + img = torch.arange(1024, dtype=torch.float).view(1, 32, 32).expand(3, 32, 32).contiguous() + mask = img[:1] + imgs = (img, mask) + imgs_out = t(imgs) + torch.testing.assert_close(imgs_out[0][0], imgs_out[1][0], rtol=1e-6, atol=1e-6, check_stride=False) + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 954d5f5f064..656168326ab 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -6,6 +6,7 @@ from typing import Tuple, List, Optional import torch +from torch import nn from torch import Tensor try: @@ -17,7 +18,7 @@ from .functional import InterpolationMode, _interpolation_modes_from_int -__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", +__all__ = ["Compose", "GroupTransform", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", @@ -25,7 +26,47 @@ "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] -class Compose: +class Transform(nn.Module): + stochastic = False + + def __init__(self, reset_auto: bool = True) -> None: + super().__init__() + self.reset_auto = reset_auto + self._initialized = False + + @property + def initialized(self): + return self._initialized + + def _call(self, input, *params): + raise NotImplementedError() + + 
def get_params(self, input): + return tuple() + + def reset_(self, input): + params = self.get_params(input) + if not isinstance(params, tuple): + self.params = (params,) + else: + self.params = params + self._initialized = True + return self.params + + def wipeout_(self): + self._initialized = False + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.initialized: + self.reset_(input) + params = self.params + output = self._call(input, *params) + if self.stochastic and self.reset_auto: + self.wipeout_() + return output + + +class Compose(Transform): """Composes several transforms together. This transform does not support torchscript. Please, see the note below. @@ -52,10 +93,37 @@ class Compose: """ - def __init__(self, transforms): + def __init__(self, transforms, reset_auto=True): + super().__init__(reset_auto=reset_auto) + if not isinstance(transforms, nn.Module) and ( + all(isinstance(t, nn.Module) for t in transforms) + ): + transforms = nn.Sequential(*transforms) + elif not all(isinstance(t, Transform) for t in transforms): + warnings.warn( + "All transforms should be of type torchvision.transforms.Transform. " + "Custom typed transforms will be forbidden in future releases." 
+ ) self.transforms = transforms + for t in transforms: + assert isinstance( + t, Transform + ), f"class {type(t)} must inherit from torchvision.transforms.Transform" + + @property + def stochastic(self): + return any(t.stochastic for t in self.transforms if isinstance(t, Transform)) + + @property + def initialized(self): + return all(t.initialized for t in self.transforms if isinstance(t, Transform)) + + def wipeout_(self): + for t in self.transforms: + if isinstance(t, Transform): + t.wipeout_() - def __call__(self, img): + def _call(self, img): for t in self.transforms: img = t(img) return img @@ -69,7 +137,34 @@ def __repr__(self): return format_string -class ToTensor: +class GroupTransform(Transform): + def __init__(self, transform, reset_auto=True): + assert isinstance( + transform, Transform + ), "GroupTransform only accepts transforms of type Transform." + assert not transform.stochastic or not transform.reset_auto + super().__init__(reset_auto=reset_auto) + self.transform = transform + + @property + def stochastic(self): + return self.transform.stochastic + + @property + def initialized(self): + return self.transform.initialized + + def wipeout_(self): + return self.transform.wipeout_() + + def _call(self, imgs): + imgs = [self.transform(img) for img in imgs] + if self.reset_auto: + self.transform.wipeout_() + return imgs + + +class ToTensor(Transform): """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript. Converts a PIL Image or numpy.ndarray (H x W x C) in the range @@ -86,7 +181,7 @@ class ToTensor: .. _references: https://github.com/pytorch/vision/tree/master/references/segmentation """ - def __call__(self, pic): + def _call(self, pic): """ Args: pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 
@@ -97,16 +192,16 @@ def __call__(self, pic): return F.to_tensor(pic) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + '()' -class PILToTensor: +class PILToTensor(Transform): """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript. Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). """ - def __call__(self, pic): + def _call(self, pic): """ Args: pic (PIL Image): Image to be converted to tensor. @@ -117,10 +212,10 @@ def __call__(self, pic): return F.pil_to_tensor(pic) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + '()' -class ConvertImageDtype(torch.nn.Module): +class ConvertImageDtype(Transform): """Convert a tensor image to the given ``dtype`` and scale the values accordingly This function does not support PIL Image. @@ -139,15 +234,15 @@ class ConvertImageDtype(torch.nn.Module): of the integer ``dtype``. """ - def __init__(self, dtype: torch.dtype) -> None: - super().__init__() + def __init__(self, dtype: torch.dtype, reset_auto: bool=True) -> None: + super().__init__(reset_auto=reset_auto) self.dtype = dtype - def forward(self, image): + def _call(self, image): return F.convert_image_dtype(image, self.dtype) -class ToPILImage: +class ToPILImage(Transform): """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript. Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape @@ -164,10 +259,12 @@ class ToPILImage: .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes """ - def __init__(self, mode=None): + + def __init__(self, mode=None, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.mode = mode - def __call__(self, pic): + def _call(self, pic): """ Args: pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 
@@ -186,7 +283,7 @@ def __repr__(self): return format_string -class Normalize(torch.nn.Module): +class Normalize(Transform): """Normalize a tensor image with mean and standard deviation. This transform does not support PIL Image. Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` @@ -210,7 +307,7 @@ def __init__(self, mean, std, inplace=False): self.std = std self.inplace = inplace - def forward(self, tensor: Tensor) -> Tensor: + def _call(self, tensor: Tensor) -> Tensor: """ Args: tensor (Tensor): Tensor image to be normalized. @@ -224,7 +321,7 @@ def __repr__(self): return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) -class Resize(torch.nn.Module): +class Resize(Transform): """Resize the input image to the given size. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -288,7 +385,7 @@ def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None self.interpolation = interpolation self.antialias = antialias - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be scaled. @@ -297,7 +394,7 @@ def forward(self, img): PIL Image or Tensor: Rescaled image. """ return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias) - + def __repr__(self): interpolate_str = self.interpolation.value return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format( @@ -314,7 +411,7 @@ def __init__(self, *args, **kwargs): super(Scale, self).__init__(*args, **kwargs) -class CenterCrop(torch.nn.Module): +class CenterCrop(Transform): """Crops the given image at the center. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. 
@@ -330,7 +427,7 @@ def __init__(self, size): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be cropped. @@ -344,7 +441,7 @@ def __repr__(self): return self.__class__.__name__ + '(size={0})'.format(self.size) -class Pad(torch.nn.Module): +class Pad(Transform): """Pad the given image on all sides with the given "pad" value. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, @@ -396,12 +493,12 @@ def __init__(self, padding, fill=0, padding_mode="constant"): if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))) - + self.padding = padding self.fill = fill self.padding_mode = padding_mode - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be padded. @@ -414,9 +511,9 @@ def forward(self, img): def __repr__(self): return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ format(self.padding, self.fill, self.padding_mode) + - -class Lambda: +class Lambda(Transform): """Apply a user-defined lambda as a transform. This transform does not support torchscript. 
Args: @@ -424,32 +521,44 @@ class Lambda: """ def __init__(self, lambd): + super().__init__() if not callable(lambd): raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__))) self.lambd = lambd - def __call__(self, img): + def _call(self, img): return self.lambd(img) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + '()' -class RandomTransforms: +class RandomTransforms(Transform): + stochastic = True """Base class for a list of transformations with randomness Args: transforms (sequence): list of transformations """ - def __init__(self, transforms): + def __init__(self, transforms, reset_auto=True): + super().__init__(reset_auto=reset_auto) if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence") + if not all(reset_auto == t.reset_auto for t in transforms): + raise Exception( + "RandomTransforms must have the same reset_auto attribute as the provided transforms" + ) self.transforms = transforms - def __call__(self, *args, **kwargs): + def _call(self, *args, **kwargs): raise NotImplementedError() + def wipeout_(self): + super().wipeout_() + for t in self.transforms: + t.wipeout_() + def __repr__(self): format_string = self.__class__.__name__ + '(' for t in self.transforms: @@ -459,7 +568,8 @@ def __repr__(self): return format_string -class RandomApply(torch.nn.Module): +class RandomApply(Transform): + stochastic = True """Apply randomly a list of transformations with a given probability. .. 
note:: @@ -479,16 +589,23 @@ class RandomApply(torch.nn.Module): p (float): probability """ - def __init__(self, transforms, p=0.5): - super().__init__() + def __init__(self, transforms, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) + if isinstance(transforms, (list, tuple)): + transforms = nn.Sequential(*transforms) + if not isinstance(transforms, nn.Module): + raise TypeError("transforms should be of type [List, Tuple, nn.Module]") self.transforms = transforms self.p = p - def forward(self, img): - if self.p < torch.rand(1): + def get_params(self, *args): + r = torch.rand(1) + return r + + def _call(self, img, r): + if self.p < r: return img - for t in self.transforms: - img = t(img) + img = self.transforms(img) return img def __repr__(self): @@ -502,25 +619,36 @@ def __repr__(self): class RandomOrder(RandomTransforms): + stochastic = True """Apply a list of transformations in a random order. This transform does not support torchscript. """ - def __call__(self, img): + + def get_params(self, *args): order = list(range(len(self.transforms))) random.shuffle(order) + return order + + def _call(self, img, order): for i in order: img = self.transforms[i](img) return img class RandomChoice(RandomTransforms): + stochastic = True """Apply single transformation randomly picked from a list. This transform does not support torchscript. """ - def __call__(self, img): + + def get_params(self, *args): t = random.choice(self.transforms) + return t + + def _call(self, img, t): return t(img) -class RandomCrop(torch.nn.Module): +class RandomCrop(Transform): + stochastic = True """Crop the given image at a random location. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... 
means an arbitrary number of leading dimensions, @@ -564,8 +692,7 @@ class RandomCrop(torch.nn.Module): will result in [2, 1, 1, 2, 3, 4, 4, 3] """ - @staticmethod - def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: + def get_params(self, img: Tensor) -> Tuple[int, int, int, int]: """Get parameters for ``crop`` for a random crop. Args: @@ -576,11 +703,11 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. """ w, h = F._get_image_size(img) - th, tw = output_size + th, tw = self.size if h + 1 < th or w + 1 < tw: raise ValueError( - "Required crop size {} is larger then input image size {}".format((th, tw), (h, w)) + "Required crop size {} is larger than input image size {}".format((th, tw), (h, w)) ) if w == tw and h == th: @@ -590,8 +717,16 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int j = torch.randint(0, w - tw + 1, size=(1, )).item() return i, j, th, tw - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): - super().__init__() + def __init__( + self, + size, + padding=None, + pad_if_needed=False, + fill=0, + padding_mode="constant", + reset_auto=True, + ): + super().__init__(reset_auto=reset_auto) self.size = tuple(_setup_size( size, error_msg="Please provide only two dimensions (h, w) for size." @@ -601,15 +736,11 @@ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode self.pad_if_needed = pad_if_needed self.fill = fill self.padding_mode = padding_mode + self.register_forward_pre_hook(RandomCrop._pad) - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - PIL Image or Tensor: Cropped image. 
- """ + def _pad(self, img_tuple): + assert len(img_tuple) == 1 + img = img_tuple[0] if self.padding is not None: img = F.pad(img, self.padding, self.fill, self.padding_mode) @@ -622,8 +753,16 @@ def forward(self, img): if self.pad_if_needed and height < self.size[0]: padding = [0, self.size[0] - height] img = F.pad(img, padding, self.fill, self.padding_mode) + return img - i, j, h, w = self.get_params(img, self.size) + def _call(self, img, i, j, h, w): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ return F.crop(img, i, j, h, w) @@ -631,7 +770,8 @@ def __repr__(self): return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) -class RandomHorizontalFlip(torch.nn.Module): +class RandomHorizontalFlip(Transform): + stochastic = True """Horizontally flip the given image randomly with a given probability. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading @@ -641,11 +781,14 @@ class RandomHorizontalFlip(torch.nn.Module): p (float): probability of the image being flipped. Default value is 0.5 """ - def __init__(self, p=0.5): - super().__init__() + def __init__(self, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, *args): + return torch.rand(1) + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be flipped. @@ -653,7 +796,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly flipped image. """ - if torch.rand(1) < self.p: + if r < self.p: return F.hflip(img) return img @@ -661,7 +804,8 @@ def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) -class RandomVerticalFlip(torch.nn.Module): +class RandomVerticalFlip(Transform): + stochastic = True """Vertically flip the given image randomly with a given probability. 
If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading @@ -671,11 +815,14 @@ class RandomVerticalFlip(torch.nn.Module): p (float): probability of the image being flipped. Default value is 0.5 """ - def __init__(self, p=0.5): - super().__init__() + def __init__(self, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, *args): + return torch.rand(1) + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be flipped. @@ -683,7 +830,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly flipped image. """ - if torch.rand(1) < self.p: + if r < self.p: return F.vflip(img) return img @@ -691,7 +838,8 @@ def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) -class RandomPerspective(torch.nn.Module): +class RandomPerspective(Transform): + stochastic = True """Performs a random perspective transformation of the given image with a given probability. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -708,8 +856,15 @@ class RandomPerspective(torch.nn.Module): image. Default is ``0``. If given a number, the value is used for all bands respectively. """ - def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0): - super().__init__() + def __init__( + self, + distortion_scale=0.5, + p=0.5, + interpolation=InterpolationMode.BILINEAR, + fill=0, + reset_auto=True, + ): + super().__init__(reset_auto=reset_auto) self.p = p # Backward compatibility with integer value @@ -730,7 +885,18 @@ def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode. 
self.fill = fill - def forward(self, img): + def get_params(self, img): + r = torch.rand(1) + if r < self.p: + width, height = F._get_image_size(img) + startpoints, endpoints = self.get_start_endpoints( + width, height, self.distortion_scale + ) + else: + startpoints, endpoints = None, None + return r, startpoints, endpoints + + def _call(self, img, r, startpoints, endpoints): """ Args: img (PIL Image or Tensor): Image to be Perspectively transformed. @@ -746,14 +912,14 @@ def forward(self, img): else: fill = [float(f) for f in fill] - if torch.rand(1) < self.p: - width, height = F._get_image_size(img) - startpoints, endpoints = self.get_params(width, height, self.distortion_scale) + if r < self.p: return F.perspective(img, startpoints, endpoints, self.interpolation, fill) return img @staticmethod - def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]: + def get_start_endpoints( + width: int, height: int, distortion_scale: float + ) -> Tuple[List[List[int]], List[List[int]]]: """Get parameters for ``perspective`` for a random perspective transform. Args: @@ -791,7 +957,8 @@ def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) -class RandomResizedCrop(torch.nn.Module): +class RandomResizedCrop(Transform): + stochastic = True """Crop a random portion of image and resize it to a given size. If the image is torch Tensor, it is expected @@ -820,9 +987,18 @@ class RandomResizedCrop(torch.nn.Module): """ - def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. 
/ 3.), interpolation=InterpolationMode.BILINEAR): - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation=InterpolationMode.BILINEAR, + reset_auto=True, + ): + super().__init__(reset_auto=reset_auto) + self.size = _setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + ) if not isinstance(scale, Sequence): raise TypeError("Scale should be a sequence") @@ -843,9 +1019,8 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat self.scale = scale self.ratio = ratio - @staticmethod def get_params( - img: Tensor, scale: List[float], ratio: List[float] + self, img: Tensor, scale: List[float] = [], ratio: List[float] = [] ) -> Tuple[int, int, int, int]: """Get parameters for ``crop`` for a random sized crop. @@ -858,6 +1033,10 @@ def get_params( tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ + if not len(scale): + scale = self.scale + if not len(ratio): + ratio = self.ratio width, height = F._get_image_size(img) area = height * width @@ -891,7 +1070,7 @@ def get_params( j = (width - w) // 2 return i, j, h, w - def forward(self, img): + def _call(self, img, i, j, h, w): """ Args: img (PIL Image or Tensor): Image to be cropped and resized. @@ -899,7 +1078,6 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly cropped and resized image. """ - i, j, h, w = self.get_params(img, self.scale, self.ratio) return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) def __repr__(self): @@ -921,7 +1099,7 @@ def __init__(self, *args, **kwargs): super(RandomSizedCrop, self).__init__(*args, **kwargs) -class FiveCrop(torch.nn.Module): +class FiveCrop(Transform): """Crop the given image into four corners and the central crop. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... 
means an arbitrary number of leading @@ -951,9 +1129,11 @@ class FiveCrop(torch.nn.Module): def __init__(self, size): super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.size = _setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + ) - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be cropped. @@ -967,7 +1147,7 @@ def __repr__(self): return self.__class__.__name__ + '(size={0})'.format(self.size) -class TenCrop(torch.nn.Module): +class TenCrop(Transform): """Crop the given image into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default). If the image is torch Tensor, it is expected @@ -999,10 +1179,12 @@ class TenCrop(torch.nn.Module): def __init__(self, size, vertical_flip=False): super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.size = _setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + ) self.vertical_flip = vertical_flip - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be cropped. @@ -1016,7 +1198,7 @@ def __repr__(self): return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) -class LinearTransformation(torch.nn.Module): +class LinearTransformation(Transform): """Transform a tensor image with a square transformation matrix and a mean_vector computed offline. This transform does not support PIL Image. @@ -1053,7 +1235,7 @@ def __init__(self, transformation_matrix, mean_vector): self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector - def forward(self, tensor: Tensor) -> Tensor: + def _call(self, tensor: Tensor) -> Tensor: """ Args: tensor (Tensor): Tensor image to be whitened. 
@@ -1064,13 +1246,17 @@ def forward(self, tensor: Tensor) -> Tensor: shape = tensor.shape n = shape[-3] * shape[-2] * shape[-1] if n != self.transformation_matrix.shape[0]: - raise ValueError("Input tensor and transformation matrix have incompatible shape." + - "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + - "{}".format(self.transformation_matrix.shape[0])) + raise ValueError( + "Input tensor and transformation matrix have incompatible shape." + + "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + + "{}".format(self.transformation_matrix.shape[0]) + ) if tensor.device.type != self.mean_vector.device.type: - raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. " - "Got {} vs {}".format(tensor.device, self.mean_vector.device)) + raise ValueError( + "Input tensor should be on the same device as transformation matrix and mean vector. " + "Got {} vs {}".format(tensor.device, self.mean_vector.device) + ) flat_tensor = tensor.view(-1, n) - self.mean_vector transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) @@ -1084,7 +1270,8 @@ def __repr__(self): return format_string -class ColorJitter(torch.nn.Module): +class ColorJitter(Transform): + stochastic = True """Randomly change the brightness, contrast, saturation and hue of an image. If the image is torch Tensor, it is expected to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1105,13 +1292,14 @@ class ColorJitter(torch.nn.Module): Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 
""" - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): - super().__init__() + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.brightness = self._check_input(brightness, 'brightness') self.contrast = self._check_input(contrast, 'contrast') self.saturation = self._check_input(saturation, 'saturation') - self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), - clip_first_on_zero=False) + self.hue = self._check_input( + hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False + ) @torch.jit.unused def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): @@ -1133,12 +1321,16 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs value = None return value - @staticmethod - def get_params(brightness: Optional[List[float]], - contrast: Optional[List[float]], - saturation: Optional[List[float]], - hue: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: + def get_params( + self, + img, + brightness: Optional[List[float]] = [], + contrast: Optional[List[float]] = [], + saturation: Optional[List[float]] = [], + hue: Optional[List[float]] = [], + ) -> Tuple[ + Tensor, Optional[float], Optional[float], Optional[float], Optional[float] + ]: """Get the parameters for the randomized transform to be applied on image. Args: @@ -1155,6 +1347,14 @@ def get_params(brightness: Optional[List[float]], tuple: The parameters used to apply the randomized transform along with their random order. 
""" + if not len(brightness): + brightness = self.brightness + if not len(contrast): + contrast = self.contrast + if not len(saturation): + saturation = self.saturation + if not len(hue): + hue = self.hue fn_idx = torch.randperm(4) b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) @@ -1164,7 +1364,15 @@ def get_params(brightness: Optional[List[float]], return fn_idx, b, c, s, h - def forward(self, img): + def _call( + self, + img, + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ): """ Args: img (PIL Image or Tensor): Input image. @@ -1172,8 +1380,6 @@ def forward(self, img): Returns: PIL Image or Tensor: Color jittered image. """ - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: @@ -1196,7 +1402,8 @@ def __repr__(self): return format_string -class RandomRotation(torch.nn.Module): +class RandomRotation(Transform): + stochastic = True """Rotate the image by angle. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1225,9 +1432,16 @@ class RandomRotation(torch.nn.Module): """ def __init__( - self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, resample=None + self, + degrees, + interpolation=InterpolationMode.NEAREST, + expand=False, + center=None, + fill=0, + resample=None, + reset_auto=True, ): - super().__init__() + super().__init__(reset_auto=reset_auto) if resample is not None: warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. 
Please, use interpolation instead" @@ -1259,17 +1473,18 @@ def __init__( self.fill = fill - @staticmethod - def get_params(degrees: List[float]) -> float: + def get_params(self, img, degrees: List[float] = []) -> float: """Get parameters for ``rotate`` for a random rotation. Returns: float: angle parameter to be passed to ``rotate`` for random rotation. """ + if not len(degrees): + degrees = self.degrees angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) return angle - def forward(self, img): + def _call(self, img, angle): """ Args: img (PIL Image or Tensor): Image to be rotated. @@ -1283,7 +1498,6 @@ def forward(self, img): fill = [float(fill)] * F._get_image_num_channels(img) else: fill = [float(f) for f in fill] - angle = self.get_params(self.degrees) return F.rotate(img, angle, self.resample, self.expand, self.center, fill) @@ -1300,7 +1514,8 @@ def __repr__(self): return format_string -class RandomAffine(torch.nn.Module): +class RandomAffine(Transform): + stochastic = True """Random affine transformation of the image keeping center invariant. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1337,10 +1552,18 @@ class RandomAffine(torch.nn.Module): """ def __init__( - self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, - fillcolor=None, resample=None + self, + degrees, + translate=None, + scale=None, + shear=None, + interpolation=InterpolationMode.NEAREST, + fill=0, + fillcolor=None, + resample=None, + reset_auto=True, ): - super().__init__() + super().__init__(reset_auto=reset_auto) if resample is not None: warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. 
Please, use interpolation instead" @@ -1391,19 +1614,31 @@ def __init__( self.fillcolor = self.fill = fill - @staticmethod def get_params( - degrees: List[float], - translate: Optional[List[float]], - scale_ranges: Optional[List[float]], - shears: Optional[List[float]], - img_size: List[int] + self, + img, + degrees: List[float] = [], + translate: Optional[List[float]] = [], + scale_ranges: Optional[List[float]] = [], + shears: Optional[List[float]] = [], + img_size: List[int] = [], ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: """Get parameters for affine transformation Returns: params to be passed to the affine transformation """ + if not len(degrees): + degrees = self.degrees + if not len(translate): + translate = self.translate + if not len(scale_ranges): + scale_ranges = self.scale + if not len(shears): + shears = self.shear + if not len(img_size): + img_size = F._get_image_size(img) + angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) if translate is not None: max_dx = float(translate[0] * img_size[0]) @@ -1429,7 +1664,7 @@ def get_params( return angle, translations, scale, shear - def forward(self, img): + def _call(self, img, angle, translations, scale, shear): """ img (PIL Image or Tensor): Image to be transformed. @@ -1443,11 +1678,15 @@ def forward(self, img): else: fill = [float(f) for f in fill] - img_size = F._get_image_size(img) - - ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) - - return F.affine(img, *ret, interpolation=self.interpolation, fill=fill) + return F.affine( + img, + angle, + translations, + scale, + shear, + interpolation=self.interpolation, + fill=fill, + ) def __repr__(self): s = '{name}(degrees={degrees}' @@ -1467,7 +1706,7 @@ def __repr__(self): return s.format(name=self.__class__.__name__, **d) -class Grayscale(torch.nn.Module): +class Grayscale(Transform): """Convert image to grayscale. 
If the image is torch Tensor, it is expected to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions @@ -1487,7 +1726,7 @@ def __init__(self, num_output_channels=1): super().__init__() self.num_output_channels = num_output_channels - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be converted to grayscale. @@ -1501,7 +1740,8 @@ def __repr__(self): return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) -class RandomGrayscale(torch.nn.Module): +class RandomGrayscale(Transform): + stochastic = True """Randomly convert image to grayscale with a probability of p (default 0.1). If the image is torch Tensor, it is expected to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions @@ -1517,11 +1757,15 @@ class RandomGrayscale(torch.nn.Module): """ - def __init__(self, p=0.1): - super().__init__() + def __init__(self, p=0.1, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1) + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be converted to grayscale. @@ -1530,7 +1774,7 @@ def forward(self, img): PIL Image or Tensor: Randomly grayscaled image. """ num_output_channels = F._get_image_num_channels(img) - if torch.rand(1) < self.p: + if r < self.p: return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) return img @@ -1538,8 +1782,9 @@ def __repr__(self): return self.__class__.__name__ + '(p={0})'.format(self.p) -class RandomErasing(torch.nn.Module): - """ Randomly selects a rectangle region in an torch Tensor image and erases its pixels. +class RandomErasing(Transform): + stochastic = True + """Randomly selects a rectangle region in an torch Tensor image and erases its pixels. This transform does not support PIL Image. 'Random Erasing Data Augmentation' by Zhong et al. 
See https://arxiv.org/abs/1708.04896 @@ -1565,8 +1810,16 @@ class RandomErasing(torch.nn.Module): >>> ]) """ - def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): - super().__init__() + def __init__( + self, + p=0.5, + scale=(0.02, 0.33), + ratio=(0.3, 3.3), + value=0, + inplace=False, + reset_auto=True, + ): + super().__init__(reset_auto=reset_auto) if not isinstance(value, (numbers.Number, str, tuple, list)): raise TypeError("Argument value should be either a number or str or a sequence") if isinstance(value, str) and value != "random": @@ -1588,9 +1841,12 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace self.value = value self.inplace = inplace - @staticmethod - def get_params( - img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None + def get_params_transform( + self, + img: Tensor, + scale: Tuple[float, float], + ratio: Tuple[float, float], + value: Optional[List[float]] = None, ) -> Tuple[int, int, int, int, Tensor]: """Get parameters for ``erase`` for a random erasing. @@ -1632,15 +1888,9 @@ def get_params( # Return original image return 0, 0, img_h, img_w, img - def forward(self, img): - """ - Args: - img (Tensor): Tensor image to be erased. - - Returns: - img (Tensor): Erased Tensor image. - """ - if torch.rand(1) < self.p: + def get_params(self, img): + r = torch.rand(1) + if r < self.p: # cast self.value to script acceptable type if isinstance(self.value, (int, float)): @@ -1658,7 +1908,21 @@ def forward(self, img): "{} (number of input channels)".format(img.shape[-3]) ) - x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value) + x, y, h, w, v = self.get_params_transform( + img, scale=self.scale, ratio=self.ratio, value=value + ) + return r, x, y, h, w, v + return r, None, None, None, None, None + + def _call(self, img, r, x, y, h, w, v): + """ + Args: + img (Tensor): Tensor image to be erased. 
+ + Returns: + img (Tensor): Erased Tensor image. + """ + if r < self.p: return F.erase(img, x, y, h, w, v, self.inplace) return img @@ -1670,8 +1934,8 @@ def __repr__(self): s += 'inplace={})'.format(self.inplace) return self.__class__.__name__ + s - -class GaussianBlur(torch.nn.Module): +class GaussianBlur(Transform): + stochastic = True """Blurs image with randomly chosen Gaussian blur. If the image is torch Tensor, it is expected to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1688,8 +1952,8 @@ class GaussianBlur(torch.nn.Module): """ - def __init__(self, kernel_size, sigma=(0.1, 2.0)): - super().__init__() + def __init__(self, kernel_size, sigma=(0.1, 2.0), reset_auto=True): + super().__init__(reset_auto=reset_auto) self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") for ks in self.kernel_size: if ks <= 0 or ks % 2 == 0: @@ -1700,15 +1964,16 @@ def __init__(self, kernel_size, sigma=(0.1, 2.0)): raise ValueError("If sigma is a single number, it must be positive.") sigma = (sigma, sigma) elif isinstance(sigma, Sequence) and len(sigma) == 2: - if not 0. < sigma[0] <= sigma[1]: + if not 0.0 < sigma[0] <= sigma[1]: raise ValueError("sigma values should be positive and of the form (min, max).") else: raise ValueError("sigma should be a single number or a list/tuple with length 2.") self.sigma = sigma - @staticmethod - def get_params(sigma_min: float, sigma_max: float) -> float: + def get_params( + self, img, sigma_min: float = -1.0, sigma_max: float = -1.0 + ) -> float: """Choose sigma for random gaussian blurring. Args: @@ -1718,9 +1983,13 @@ def get_params(sigma_min: float, sigma_max: float) -> float: Returns: float: Standard deviation to be passed to calculate kernel for gaussian blurring. 
""" + if sigma_min == -1.0: + sigma_min = self.sigma[0] + if sigma_max == -1.0: + sigma_max = self.sigma[1] return torch.empty(1).uniform_(sigma_min, sigma_max).item() - def forward(self, img: Tensor) -> Tensor: + def _call(self, img: Tensor, sigma: float) -> Tensor: """ Args: img (PIL Image or Tensor): image to be blurred. @@ -1728,7 +1997,6 @@ def forward(self, img: Tensor) -> Tensor: Returns: PIL Image or Tensor: Gaussian blurred image """ - sigma = self.get_params(self.sigma[0], self.sigma[1]) return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]) def __repr__(self): @@ -1769,7 +2037,8 @@ def _setup_angle(x, name, req_sizes=(2, )): return [float(d) for d in x] -class RandomInvert(torch.nn.Module): +class RandomInvert(Transform): + stochastic = True """Inverts the colors of the given image randomly with a given probability. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. @@ -1779,11 +2048,15 @@ class RandomInvert(torch.nn.Module): p (float): probability of the image being color inverted. Default value is 0.5 """ - def __init__(self, p=0.5): - super().__init__() + def __init__(self, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be inverted. @@ -1791,7 +2064,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly color inverted image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.invert(img) return img @@ -1799,7 +2072,8 @@ def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) -class RandomPosterize(torch.nn.Module): +class RandomPosterize(Transform): + stochastic = True """Posterize the image randomly with a given probability by reducing the number of bits for each color channel. 
If the image is torch Tensor, it should be of type torch.uint8, and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1810,12 +2084,16 @@ class RandomPosterize(torch.nn.Module): p (float): probability of the image being color inverted. Default value is 0.5 """ - def __init__(self, bits, p=0.5): - super().__init__() + def __init__(self, bits, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.bits = bits self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be posterized. @@ -1823,7 +2101,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly posterized image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.posterize(img, self.bits) return img @@ -1831,7 +2109,8 @@ def __repr__(self): return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) -class RandomSolarize(torch.nn.Module): +class RandomSolarize(Transform): + stochastic = True """Solarize the image randomly with a given probability by inverting all pixel values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. @@ -1842,12 +2121,16 @@ class RandomSolarize(torch.nn.Module): p (float): probability of the image being color inverted. Default value is 0.5 """ - def __init__(self, threshold, p=0.5): - super().__init__() + def __init__(self, threshold, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.threshold = threshold self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be solarized. @@ -1855,7 +2138,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly solarized image. 
""" - if torch.rand(1).item() < self.p: + if r < self.p: return F.solarize(img, self.threshold) return img @@ -1863,7 +2146,7 @@ def __repr__(self): return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) -class RandomAdjustSharpness(torch.nn.Module): +class RandomAdjustSharpness(Transform): """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor, it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1874,12 +2157,16 @@ class RandomAdjustSharpness(torch.nn.Module): p (float): probability of the image being color inverted. Default value is 0.5 """ - def __init__(self, sharpness_factor, p=0.5): - super().__init__() + def __init__(self, sharpness_factor, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.sharpness_factor = sharpness_factor self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be sharpened. @@ -1887,7 +2174,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly sharpened image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.adjust_sharpness(img, self.sharpness_factor) return img @@ -1895,7 +2182,7 @@ def __repr__(self): return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) -class RandomAutocontrast(torch.nn.Module): +class RandomAutocontrast(Transform): """Autocontrast the pixels of the given image randomly with a given probability. If the image is torch Tensor, it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1905,11 +2192,15 @@ class RandomAutocontrast(torch.nn.Module): p (float): probability of the image being autocontrasted. 
Default value is 0.5 """ - def __init__(self, p=0.5): - super().__init__() + def __init__(self, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be autocontrasted. @@ -1917,7 +2208,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly autocontrasted image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.autocontrast(img) return img @@ -1925,7 +2216,7 @@ def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) -class RandomEqualize(torch.nn.Module): +class RandomEqualize(Transform): """Equalize the histogram of the given image randomly with a given probability. If the image is torch Tensor, it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1935,11 +2226,15 @@ class RandomEqualize(torch.nn.Module): p (float): probability of the image being equalized. Default value is 0.5 """ - def __init__(self, p=0.5): - super().__init__() + def __init__(self, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be equalized. @@ -1947,7 +2242,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly equalized image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.equalize(img) return img