From cb19fc7527eeba8d7be433bfff2dc40904800301 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 1 Mar 2022 10:09:28 +0100
Subject: [PATCH 1/9] fix grayscale to RGB for batches

---
 torchvision/prototype/transforms/functional/_meta.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 6ecb5aff257..5062c266959 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -58,7 +58,9 @@ def convert_bounding_box_format(


 def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
-    return grayscale.expand(3, 1, 1)
+    repeats = [1] * grayscale.ndim
+    repeats[-3] = 3
+    return grayscale.repeat(repeats)


 def convert_image_color_space_tensor(
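
A minimal sketch (not part of the patch) of what the repeat-based conversion above does; the shapes are assumed example values. The old `expand(3, 1, 1)` call only works for an unbatched 1x1 input, while building the repeat list from `ndim` covers any number of leading dimensions:

    import torch

    grayscale = torch.rand(2, 1, 32, 32)  # assumed example: a batch of single-channel images

    repeats = [1] * grayscale.ndim  # [1, 1, 1, 1]
    repeats[-3] = 3                 # repeat the channel dimension three times
    rgb = grayscale.repeat(repeats)

    print(rgb.shape)  # torch.Size([2, 3, 32, 32])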
From d0e49be7bfa321ef79e0d227b40a39f20cb91c79 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 1 Mar 2022 10:26:17 +0100
Subject: [PATCH 2/9] make unsupported types in auto augment a parameter

---
 torchvision/prototype/transforms/_auto_augment.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 892162fa296..01704757277 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union
+from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union, Type

 import PIL.Image
 import torch
@@ -39,8 +39,10 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]

-    def _check_unsupported(self, input: Any) -> None:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+    def _check_unsupported(
+        self, input: Any, *, types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask)
+    ) -> None:
+        if isinstance(input, types):
             raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")

     def _extract_image(

From e391eb1f1fe6ec42a4d1b7d71e19e05ea2923240 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 1 Mar 2022 10:28:00 +0100
Subject: [PATCH 3/9] make auto augment kwargs explicit

---
 .../prototype/transforms/_auto_augment.py | 36 ++++++++++++++-----
 1 file changed, 28 insertions(+), 8 deletions(-)

diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 01704757277..95b17e31480 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -223,8 +223,13 @@ class AutoAugment(_AutoAugmentBase):
         "Invert": (lambda num_bins, image_size: None, False),
     }

-    def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ) -> None:
+        super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
         self._policies = self._get_policies(policy)

@@ -371,8 +376,16 @@ class RandAugment(_AutoAugmentBase):
         "Equalize": (lambda num_bins, image_size: None, False),
     }

-    def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        *,
+        num_ops: int = 2,
+        magnitude: int = 9,
+        num_magnitude_bins: int = 31,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ) -> None:
+        super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
         self.magnitude = magnitude
         self.num_magnitude_bins = num_magnitude_bins
@@ -425,8 +438,14 @@ class TrivialAugmentWide(_AutoAugmentBase):
         "Equalize": (lambda num_bins, image_size: None, False),
     }

-    def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any):
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        *,
+        num_magnitude_bins: int = 31,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ):
+        super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins

     def forward(self, *inputs: Any) -> Any:
@@ -482,9 +501,10 @@ def __init__(
         chain_depth: int = -1,
         alpha: float = 1.0,
         all_ops: bool = True,
-        **kwargs: Any,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
         if not (1 <= severity <= self._PARAMETER_MAX):
             raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
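
For illustration only (not part of the patch): with the explicit signatures above, the auto augment transforms are constructed as shown below. The argument values are assumed examples, and the imports assume the prototype package re-exports the names used inside the diff:

    from torchvision.prototype.transforms import AutoAugment, AutoAugmentPolicy, InterpolationMode

    augment = AutoAugment(
        policy=AutoAugmentPolicy.IMAGENET,
        interpolation=InterpolationMode.BILINEAR,  # assumed example value
        fill=[0.5, 0.5, 0.5],                      # assumed example value
    )

Previously both `interpolation` and `fill` travelled through `**kwargs` to the base class constructor.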
From 79ac4e0a69eda014c89a19602359e2b597481c2f Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 1 Mar 2022 10:30:30 +0100
Subject: [PATCH 4/9] add missing error message

---
 torchvision/prototype/transforms/_meta.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index 3675e1d8ada..09d2892769c 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -59,7 +59,10 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return features.Image.new_like(input, output, color_space=self.color_space)
         elif isinstance(input, torch.Tensor):
             if self.old_color_space is None:
-                raise RuntimeError("")
+                raise RuntimeError(
+                    f"In order to convert vanilla tensor images, `{type(self).__name__}(...)` "
+                    f"needs to be constructed with the `old_color_space=...` parameter."
+                )

             return F.convert_image_color_space_tensor(
                 input, old_color_space=self.old_color_space, new_color_space=self.color_space

From 613fe8c545c6a1eb9d25e9f0ef79a3c52afde3e0 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 1 Mar 2022 15:04:10 +0100
Subject: [PATCH 5/9] add support for specifying probabilities on RandomChoice

---
 torchvision/prototype/transforms/_container.py | 18 +++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py
index bd20d0c701a..1e93ec3b6c0 100644
--- a/torchvision/prototype/transforms/_container.py
+++ b/torchvision/prototype/transforms/_container.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional, List

 import torch

@@ -37,14 +37,26 @@ def extra_repr(self) -> str:


 class RandomChoice(Transform):
-    def __init__(self, *transforms: Transform) -> None:
+    def __init__(self, *transforms: Transform, probabilities: Optional[List[float]] = None) -> None:
+        if probabilities is None:
+            probabilities = [1] * len(transforms)
+        elif len(probabilities) != len(transforms):
+            raise ValueError(
+                f"The number of probabilities doesn't match the number of transforms: "
+                f"{len(probabilities)} != {len(transforms)}"
+            )
+
         super().__init__()
+
         self.transforms = transforms
         for idx, transform in enumerate(transforms):
             self.add_module(str(idx), transform)

+        total = sum(probabilities)
+        self.probabilities = [p / total for p in probabilities]
+
     def forward(self, *inputs: Any) -> Any:
-        idx = int(torch.randint(len(self.transforms), size=()))
+        idx = int(torch.multinomial(torch.tensor(self.probabilities), 1))
         transform = self.transforms[idx]
         return transform(*inputs)
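
A usage sketch for the new parameter (not part of the patch; the chosen transforms are arbitrary examples from the same prototype namespace). The weights do not need to sum to one, they are normalized in `__init__`, and `forward` then draws the index with `torch.multinomial`:

    from torchvision.prototype import transforms

    random_aug = transforms.RandomChoice(
        transforms.HorizontalFlip(),
        transforms.RandomErasing(),
        probabilities=[3, 1],  # stored internally as [0.75, 0.25]
    )

Passing a list whose length does not match the number of transforms raises a ValueError.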
From 8d5774923a0b7fbf9251643873c953510dcc6f72 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 1 Mar 2022 15:13:07 +0100
Subject: [PATCH 6/9] remove TODO for deprecating p on random transforms

---
 torchvision/prototype/transforms/_augment.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 5862c6a06dc..7ff353a5274 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -33,7 +33,6 @@ def __init__(
             raise ValueError("Scale should be between 0 and 1")
         if p < 0 or p > 1:
             raise ValueError("Random erasing probability should be between 0 and 1")
-        # TODO: deprecate p in favor of wrapping the transform in a RandomApply
         self.p = p
         self.scale = scale
         self.ratio = ratio

From 7f97132af0dadbfa8eafb1062cb750f2678ad4f5 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 1 Mar 2022 15:31:03 +0100
Subject: [PATCH 7/9] streamline sample type checking

---
 torchvision/prototype/transforms/_augment.py  | 37 +++++++++++++------
 .../prototype/transforms/_auto_augment.py     | 14 ++++---
 torchvision/prototype/transforms/_geometry.py | 22 +++++++----
 torchvision/prototype/transforms/_utils.py    | 14 ++++++-
 4 files changed, 61 insertions(+), 26 deletions(-)

diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 7ff353a5274..b9b1341f5b1 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -7,7 +7,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F

-from ._utils import query_image, get_image_dimensions
+from ._utils import query_image, get_image_dimensions, has_all, has_any


 class RandomErasing(Transform):
@@ -87,9 +87,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(zip("ijhwv", (i, j, h, w, v)))

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.erase_image_tensor(input, **params)
             return features.Image.new_like(input, output)
         elif isinstance(input, torch.Tensor):
@@ -98,8 +96,11 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return input

     def forward(self, *inputs: Any) -> Any:
-        if torch.rand(1) >= self.p:
-            return inputs if len(inputs) > 1 else inputs[0]
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        elif torch.rand(1) >= self.p:
+            return sample

         return super().forward(*inputs)

@@ -114,9 +115,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.mixup_image_tensor(input, **params)
             return features.Image.new_like(input, output)
         elif isinstance(input, features.OneHotLabel):
@@ -125,6 +124,14 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         else:
             return input

+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        if not has_all(sample, features.Image, features.OneHotLabel):
+            raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
+        return super().forward(sample)
+

 class RandomCutmix(Transform):
     def __init__(self, *, alpha: float) -> None:
@@ -156,9 +163,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(box=box, lam_adjusted=lam_adjusted)

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.cutmix_image_tensor(input, box=params["box"])
             return features.Image.new_like(input, output)
         elif isinstance(input, features.OneHotLabel):
@@ -166,3 +171,11 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return features.OneHotLabel.new_like(input, output)
         else:
             return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        if not has_all(sample, features.Image, features.OneHotLabel):
+            raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
+        return super().forward(sample)
diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 95b17e31480..63c22066e66 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -5,10 +5,10 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
-from torchvision.prototype.utils._internal import query_recursively
+from torchvision.prototype.utils._internal import query_recursively, sequence_to_str
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image

-from ._utils import get_image_dimensions
+from ._utils import get_image_dimensions, has_any

 K = TypeVar("K")
 V = TypeVar("V")
@@ -40,10 +40,11 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
         return key, dct[key]

     def _check_unsupported(
-        self, input: Any, *, types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask)
+        self, sample: Any, *, types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask)
     ) -> None:
-        if isinstance(input, types):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+        if has_any(sample, *types):
+            names = sequence_to_str([t.__name__ for t in types], separate_last="and ")
+            raise TypeError(f"Inputs of type {names} are not supported by {type(self).__name__}()")

     def _extract_image(
         self, sample: Any
@@ -54,9 +55,10 @@ def fn(
             if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
                 return id, input

-            self._check_unsupported(input)
             return None

+        self._check_unsupported(sample)
+
         images = list(query_recursively(fn, sample))
         if not images:
             raise TypeError("Found no image in the sample.")
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index c58f26a0e06..4bc3c14070f 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -8,7 +8,7 @@
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

-from ._utils import query_image, get_image_dimensions
+from ._utils import query_image, get_image_dimensions, has_any


 class HorizontalFlip(Transform):
@@ -61,9 +61,7 @@ def __init__(self, output_size: List[int]):
         self.output_size = output_size

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.center_crop_image_tensor(input, self.output_size)
             return features.Image.new_like(input, output)
         elif isinstance(input, torch.Tensor):
@@ -73,6 +71,12 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         else:
             return input

+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        return super().forward(sample)
+

 class RandomResizedCrop(Transform):
     def __init__(
@@ -147,9 +151,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(top=i, left=j, height=h, width=w)

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.resized_crop_image_tensor(
                 input, **params, size=list(self.size), interpolation=self.interpolation
             )
@@ -160,3 +162,9 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
         else:
             return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        return super().forward(sample)
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index f5107908e41..74cbd84a64e 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union, Type, Iterator

 import PIL.Image
 import torch
@@ -34,3 +34,15 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
     else:
         raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
     return channels, height, width
+
+
+def _extract_types(sample: Any) -> Iterator[Type]:
+    return query_recursively(lambda id, input: type(input), sample)
+
+
+def has_any(sample: Any, *types: Type) -> bool:
+    return any(issubclass(type, types) for type in _extract_types(sample))
+
+
+def has_all(sample: Any, *types: Type) -> bool:
+    return not bool(set(types) - set(_extract_types(sample)))
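
Roughly, the semantics of the two new helpers are as follows (standalone sketch with plain Python types; the actual implementation above walks the sample with `query_recursively` and is used with the prototype feature types):

    from typing import Any, Iterator, Type

    def _flatten(sample: Any) -> Iterator[Any]:
        # stand-in for query_recursively: walk nested tuples/lists/dicts
        if isinstance(sample, (tuple, list)):
            for item in sample:
                yield from _flatten(item)
        elif isinstance(sample, dict):
            for item in sample.values():
                yield from _flatten(item)
        else:
            yield sample

    def has_any(sample: Any, *types: Type) -> bool:
        # at least one item in the sample is an instance of one of the given types
        return any(isinstance(item, types) for item in _flatten(sample))

    def has_all(sample: Any, *types: Type) -> bool:
        # every given type occurs somewhere in the sample
        return not set(types) - {type(item) for item in _flatten(sample)}

    sample = ("image-stand-in", {"target": 1})
    print(has_any(sample, str))         # True
    print(has_all(sample, str, int))    # True
    print(has_all(sample, str, float))  # False

This is what lets RandomErasing, RandomMixup and RandomCutmix above reject whole samples containing bounding boxes or segmentation masks up front, and lets the mixup/cutmix transforms require both an Image and a OneHotLabel.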
From 13c0d08503f36ddf3afdefa66908b52d46931170 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 1 Mar 2022 15:59:40 +0100
Subject: [PATCH 8/9] address comments

---
 torchvision/prototype/transforms/_augment.py |  6 ++---
 .../prototype/transforms/_auto_augment.py    | 25 ++++++++-----------
 2 files changed, 13 insertions(+), 18 deletions(-)

diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index b9b1341f5b1..5d4dc92a11b 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -102,7 +102,7 @@ def forward(self, *inputs: Any) -> Any:
         elif torch.rand(1) >= self.p:
             return sample

-        return super().forward(*inputs)
+        return super().forward(sample)


 class RandomMixup(Transform):
@@ -128,7 +128,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
         if has_any(sample, features.BoundingBox, features.SegmentationMask):
             raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
-        if not has_all(sample, features.Image, features.OneHotLabel):
+        elif not has_all(sample, features.Image, features.OneHotLabel):
             raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
         return super().forward(sample)

@@ -176,6 +176,6 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
         if has_any(sample, features.BoundingBox, features.SegmentationMask):
             raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
-        if not has_all(sample, features.Image, features.OneHotLabel):
+        elif not has_all(sample, features.Image, features.OneHotLabel):
             raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
         return super().forward(sample)
diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 63c22066e66..e2a5ad11908 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -5,10 +5,10 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
-from torchvision.prototype.utils._internal import query_recursively, sequence_to_str
+from torchvision.prototype.utils._internal import query_recursively
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image

-from ._utils import get_image_dimensions, has_any
+from ._utils import get_image_dimensions

 K = TypeVar("K")
 V = TypeVar("V")
@@ -39,25 +39,20 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]

-    def _check_unsupported(
-        self, sample: Any, *, types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask)
-    ) -> None:
-        if has_any(sample, *types):
-            names = sequence_to_str([t.__name__ for t in types], separate_last="and ")
-            raise TypeError(f"Inputs of type {names} are not supported by {type(self).__name__}()")
-
-    def _extract_image(
-        self, sample: Any
+    def _extract_image(
+        self,
+        sample: Any,
+        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
     ) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
         def fn(
             id: Tuple[Any, ...], input: Any
         ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
             if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
                 return id, input
-
-            return None
-
-        self._check_unsupported(sample)
+            elif isinstance(input, unsupported_types):
+                raise TypeError(f"Inputs of type {type(input).__name__} are not supported by {type(self).__name__}()")
+            else:
+                return None

         images = list(query_recursively(fn, sample))
         if not images:
@@ -503,7 +498,7 @@ def __init__(
         chain_depth: int = -1,
         alpha: float = 1.0,
         all_ops: bool = True,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         fill: Optional[List[float]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
From 6e06820823b419f58366713dbda0c124595cc01e Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 1 Mar 2022 16:10:29 +0100
Subject: [PATCH 9/9] split image_size into height and width in auto augment

---
 .../prototype/transforms/_auto_augment.py | 132 ++++++++++--------
 1 file changed, 72 insertions(+), 60 deletions(-)

diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index e2a5ad11908..cb4e5979102 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -199,25 +199,31 @@ def _apply_image_transform(

 class AutoAugment(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "TranslateX": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
+            True,
+        ),
+        "TranslateY": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
+            True,
+        ),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
+            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
-        "Invert": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
+        "Invert": (lambda num_bins, height, width: None, False),
     }

     def __init__(
@@ -335,7 +341,7 @@ def forward(self, *inputs: Any) -> Any:

             magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

-            magnitudes = magnitudes_fn(10, (height, width))
+            magnitudes = magnitudes_fn(10, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[magnitude_idx])
                 if signed and torch.rand(()) <= 0.5:
@@ -352,25 +358,31 @@

 class RandAugment(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
-        "Identity": (lambda num_bins, image_size: None, False),
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
+        "Identity": (lambda num_bins, height, width: None, False),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "TranslateX": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
+            True,
+        ),
+        "TranslateY": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
+            True,
+        ),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
+            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
     }

     def __init__(
@@ -397,7 +409,7 @@ def forward(self, *inputs: Any) -> Any:
         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

-            magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
+            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
                 if signed and torch.rand(()) <= 0.5:
@@ -414,25 +426,25 @@

 class TrivialAugmentWide(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
-        "Identity": (lambda num_bins, image_size: None, False),
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True),
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
+        "Identity": (lambda num_bins, height, width: None, False),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
+        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
+            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
            .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
     }

     def __init__(
@@ -454,7 +466,7 @@ def forward(self, *inputs: Any) -> Any:

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

-        magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
+        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
         if magnitudes is not None:
             magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
             if signed and torch.rand(()) <= 0.5:
@@ -468,27 +480,27 @@

 class AugMix(_AutoAugmentBase):
     _PARTIAL_AUGMENTATION_SPACE = {
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
+        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
+            lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
     }
-    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, Tuple[int, int]], Optional[torch.Tensor]], bool]] = {
+    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
         **_PARTIAL_AUGMENTATION_SPACE,
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
     }

     def __init__(
@@ -550,7 +562,7 @@ def forward(self, *inputs: Any) -> Any:
             for _ in range(depth):
                 transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

-                magnitudes = magnitudes_fn(self._PARAMETER_MAX, (height, width))
+                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                 if magnitudes is not None:
                     magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                     if signed and torch.rand(()) <= 0.5:
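
To illustrate the new callable signature (not part of the patch; the bin count and image size are assumed example values), each entry in the augmentation spaces now receives the height and width separately instead of an `image_size` tuple:

    import torch

    # the "TranslateX" entry from the AutoAugment space after this patch
    translate_x = lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins)

    magnitudes = translate_x(31, 224, 320)  # assumed 31 bins for a 224x320 image
    print(float(magnitudes[-1]))            # ~145.0, i.e. 150 / 331 * 320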