diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 5862c6a06dc..5d4dc92a11b 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -7,7 +7,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F

-from ._utils import query_image, get_image_dimensions
+from ._utils import query_image, get_image_dimensions, has_all, has_any


 class RandomErasing(Transform):
@@ -33,7 +33,6 @@ def __init__(
             raise ValueError("Scale should be between 0 and 1")
         if p < 0 or p > 1:
             raise ValueError("Random erasing probability should be between 0 and 1")
-        # TODO: deprecate p in favor of wrapping the transform in a RandomApply
         self.p = p
         self.scale = scale
         self.ratio = ratio
@@ -88,9 +87,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(zip("ijhwv", (i, j, h, w, v)))

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.erase_image_tensor(input, **params)
             return features.Image.new_like(input, output)
         elif isinstance(input, torch.Tensor):
@@ -99,10 +96,13 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return input

     def forward(self, *inputs: Any) -> Any:
-        if torch.rand(1) >= self.p:
-            return inputs if len(inputs) > 1 else inputs[0]
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        elif torch.rand(1) >= self.p:
+            return sample

-        return super().forward(*inputs)
+        return super().forward(sample)


 class RandomMixup(Transform):
@@ -115,9 +115,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.mixup_image_tensor(input, **params)
             return features.Image.new_like(input, output)
         elif isinstance(input, features.OneHotLabel):
@@ -126,6 +124,14 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         else:
             return input

+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        elif not has_all(sample, features.Image, features.OneHotLabel):
+            raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
+        return super().forward(sample)
+

 class RandomCutmix(Transform):
     def __init__(self, *, alpha: float) -> None:
@@ -157,9 +163,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(box=box, lam_adjusted=lam_adjusted)

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.cutmix_image_tensor(input, box=params["box"])
             return features.Image.new_like(input, output)
         elif isinstance(input, features.OneHotLabel):
@@ -167,3 +171,11 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return features.OneHotLabel.new_like(input, output)
         else:
             return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        elif not has_all(sample, features.Image, features.OneHotLabel):
+            raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
+        return super().forward(sample)
diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 892162fa296..cb4e5979102 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union
+from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union, Type

 import PIL.Image
 import torch
@@ -39,21 +39,20 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]

-    def _check_unsupported(self, input: Any) -> None:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-
     def _extract_image(
-        self, sample: Any
+        self,
+        sample: Any,
+        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
     ) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
         def fn(
             id: Tuple[Any, ...], input: Any
         ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
             if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
                 return id, input
-
-            self._check_unsupported(input)
-            return None
+            elif isinstance(input, unsupported_types):
+                raise TypeError(f"Inputs of type {type(input).__name__} are not supported by {type(self).__name__}()")
+            else:
+                return None

         images = list(query_recursively(fn, sample))
         if not images:
@@ -200,29 +199,40 @@ def _apply_image_transform(

 class AutoAugment(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "TranslateX": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
+            True,
+        ),
+        "TranslateY": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
+            True,
+        ),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
+            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
-        "Invert": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
+        "Invert": (lambda num_bins, height, width: None, False),
     }

-    def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ) -> None:
+        super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
         self._policies = self._get_policies(policy)

@@ -331,7 +341,7 @@ def forward(self, *inputs: Any) -> Any:

             magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

-            magnitudes = magnitudes_fn(10, (height, width))
+            magnitudes = magnitudes_fn(10, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[magnitude_idx])
                 if signed and torch.rand(()) <= 0.5:
@@ -348,29 +358,43 @@ def forward(self, *inputs: Any) -> Any:

 class RandAugment(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
-        "Identity": (lambda num_bins, image_size: None, False),
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
+        "Identity": (lambda num_bins, height, width: None, False),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "TranslateX": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
+            True,
+        ),
+        "TranslateY": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
+            True,
+        ),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
+            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
     }

-    def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        *,
+        num_ops: int = 2,
+        magnitude: int = 9,
+        num_magnitude_bins: int = 31,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ) -> None:
+        super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
         self.magnitude = magnitude
         self.num_magnitude_bins = num_magnitude_bins
@@ -385,7 +409,7 @@ def forward(self, *inputs: Any) -> Any:

         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
-            magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
+            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
                 if signed and torch.rand(()) <= 0.5:
@@ -402,29 +426,35 @@ def forward(self, *inputs: Any) -> Any:

 class TrivialAugmentWide(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
-        "Identity": (lambda num_bins, image_size: None, False),
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True),
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
+        "Identity": (lambda num_bins, height, width: None, False),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
+        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
+            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
     }

-    def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any):
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        *,
+        num_magnitude_bins: int = 31,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ):
+        super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins

     def forward(self, *inputs: Any) -> Any:
@@ -436,7 +466,7 @@ def forward(self, *inputs: Any) -> Any:

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

-        magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
+        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
         if magnitudes is not None:
             magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
             if signed and torch.rand(()) <= 0.5:
@@ -450,27 +480,27 @@ def forward(self, *inputs: Any) -> Any:

 class AugMix(_AutoAugmentBase):
     _PARTIAL_AUGMENTATION_SPACE = {
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
+        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
+            lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
     }
-    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, Tuple[int, int]], Optional[torch.Tensor]], bool]] = {
+    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
         **_PARTIAL_AUGMENTATION_SPACE,
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
     }

     def __init__(
@@ -480,9 +510,10 @@ def __init__(
         chain_depth: int = -1,
         alpha: float = 1.0,
         all_ops: bool = True,
-        **kwargs: Any,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        fill: Optional[List[float]] = None,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
         if not (1 <= severity <= self._PARAMETER_MAX):
             raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
@@ -531,7 +562,7 @@ def forward(self, *inputs: Any) -> Any:

             for _ in range(depth):
                 transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
-                magnitudes = magnitudes_fn(self._PARAMETER_MAX, (height, width))
+                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                 if magnitudes is not None:
                     magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                     if signed and torch.rand(()) <= 0.5:
diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py
index bd20d0c701a..1e93ec3b6c0 100644
--- a/torchvision/prototype/transforms/_container.py
+++ b/torchvision/prototype/transforms/_container.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional, List

 import torch

@@ -37,14 +37,26 @@ def extra_repr(self) -> str:

 class RandomChoice(Transform):
-    def __init__(self, *transforms: Transform) -> None:
+    def __init__(self, *transforms: Transform, probabilities: Optional[List[float]] = None) -> None:
+        if probabilities is None:
+            probabilities = [1] * len(transforms)
+        elif len(probabilities) != len(transforms):
+            raise ValueError(
+                f"The number of probabilities doesn't match the number of transforms: "
+                f"{len(probabilities)} != {len(transforms)}"
+            )
+
         super().__init__()
+
         self.transforms = transforms
         for idx, transform in enumerate(transforms):
             self.add_module(str(idx), transform)

+        total = sum(probabilities)
+        self.probabilities = [p / total for p in probabilities]
+
     def forward(self, *inputs: Any) -> Any:
-        idx = int(torch.randint(len(self.transforms), size=()))
+        idx = int(torch.multinomial(torch.tensor(self.probabilities), 1))
         transform = self.transforms[idx]
         return transform(*inputs)

diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index c58f26a0e06..4bc3c14070f 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -8,7 +8,7 @@
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

-from ._utils import query_image, get_image_dimensions
+from ._utils import query_image, get_image_dimensions, has_any


 class HorizontalFlip(Transform):
@@ -61,9 +61,7 @@ def __init__(self, output_size: List[int]):
         self.output_size = output_size

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.center_crop_image_tensor(input, self.output_size)
             return features.Image.new_like(input, output)
         elif isinstance(input, torch.Tensor):
@@ -73,6 +71,12 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         else:
             return input

+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        return super().forward(sample)
+

 class RandomResizedCrop(Transform):
     def __init__(
@@ -147,9 +151,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(top=i, left=j, height=h, width=w)

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.resized_crop_image_tensor(
                 input, **params, size=list(self.size), interpolation=self.interpolation
             )
@@ -160,3 +162,9 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
         else:
             return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        return super().forward(sample)
diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index 3675e1d8ada..09d2892769c 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -59,7 +59,10 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return features.Image.new_like(input, output, color_space=self.color_space)
         elif isinstance(input, torch.Tensor):
             if self.old_color_space is None:
-                raise RuntimeError("")
+                raise RuntimeError(
+                    f"In order to convert vanilla tensor images, `{type(self).__name__}(...)` "
+                    f"needs to be constructed with the `old_color_space=...` parameter."
+                )

             return F.convert_image_color_space_tensor(
                 input, old_color_space=self.old_color_space, new_color_space=self.color_space
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index f5107908e41..74cbd84a64e 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union, Type, Iterator

 import PIL.Image
 import torch
@@ -34,3 +34,15 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
     else:
         raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
     return channels, height, width
+
+
+def _extract_types(sample: Any) -> Iterator[Type]:
+    return query_recursively(lambda id, input: type(input), sample)
+
+
+def has_any(sample: Any, *types: Type) -> bool:
+    return any(issubclass(type, types) for type in _extract_types(sample))
+
+
+def has_all(sample: Any, *types: Type) -> bool:
+    return not bool(set(types) - set(_extract_types(sample)))
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 6ecb5aff257..5062c266959 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -58,7 +58,9 @@ def convert_bounding_box_format(


 def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
-    return grayscale.expand(3, 1, 1)
+    repeats = [1] * grayscale.ndim
+    repeats[-3] = 3
+    return grayscale.repeat(repeats)


 def convert_image_color_space_tensor(