diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index b995101c3c7..efd54bf6c54 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -3,7 +3,7 @@ import random import warnings from collections.abc import Sequence -from typing import Tuple, List, Optional +from typing import Tuple, List, Optional, Union import torch from PIL import Image @@ -1018,8 +1018,9 @@ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + @staticmethod @torch.jit.unused - def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + def _check_input(value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): if isinstance(value, numbers.Number): if value < 0: raise ValueError("If {} is a single number, it must be non negative.".format(name)) @@ -1040,7 +1041,8 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs @staticmethod @torch.jit.unused - def get_params(brightness, contrast, saturation, hue): + def get_params(brightness: Union[float, tuple], contrast: Union[float, tuple], + saturation: Union[float, tuple], hue: Union[float, tuple]): """Get a randomized transform to be applied on image. Arguments are same as that of __init__. @@ -1052,18 +1054,26 @@ def get_params(brightness, contrast, saturation, hue): transforms = [] if brightness is not None: + if isinstance(brightness, float): + ColorJitter._check_input(brightness, 'brightness') brightness_factor = random.uniform(brightness[0], brightness[1]) transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) if contrast is not None: + if isinstance(contrast, float): + ColorJitter._check_input(contrast, 'contrast') contrast_factor = random.uniform(contrast[0], contrast[1]) transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) if saturation is not None: + if isinstance(saturation, float): + ColorJitter._check_input(saturation, 'saturation') saturation_factor = random.uniform(saturation[0], saturation[1]) transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) if hue is not None: + if isinstance(hue, float): + ColorJitter._check_input(hue, 'hue') hue_factor = random.uniform(hue[0], hue[1]) transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))