diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index d6987f6b71b..ca187aa5af5 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -83,6 +83,12 @@ class TestSmoke: transforms.RandomRotation(degrees=(-45, 45)), transforms.RandomAffine(degrees=(-45, 45)), transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True), + # TODO: Something wrong with input data setup. Let's fix that + # transforms.RandomEqualize(), + # transforms.RandomInvert(), + # transforms.RandomPosterize(bits=4), + # transforms.RandomSolarize(threshold=0.5), + # transforms.RandomAdjustSharpness(sharpness_factor=0.5), ) def test_common(self, transform, input): transform(input) @@ -699,3 +705,28 @@ def test__transform(self, kernel_size, sigma, mocker): params = transform._get_params(inpt) fn.assert_called_once_with(inpt, **params) + + +class TestRandomColorOp: + @pytest.mark.parametrize("p", [0.0, 1.0]) + @pytest.mark.parametrize( + "transform_cls, func_op_name, kwargs", + [ + (transforms.RandomEqualize, "equalize", {}), + (transforms.RandomInvert, "invert", {}), + (transforms.RandomAutocontrast, "autocontrast", {}), + (transforms.RandomPosterize, "posterize", {"bits": 4}), + (transforms.RandomSolarize, "solarize", {"threshold": 0.5}), + (transforms.RandomAdjustSharpness, "adjust_sharpness", {"sharpness_factor": 0.5}), + ], + ) + def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker): + transform = transform_cls(p=p, **kwargs) + + fn = mocker.patch(f"torchvision.prototype.transforms.functional.{func_op_name}") + inpt = mocker.MagicMock(spec=features.Image) + _ = transform(inpt) + if p > 0.0: + fn.assert_called_once_with(inpt, **kwargs) + else: + fn.assert_not_called() diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index e39eb4b6632..e2d5ff2d24d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ 
-515,6 +515,57 @@ def gaussian_blur_image_tensor(): yield SampleInput(image, kernel_size=kernel_size, sigma=sigma) +@register_kernel_info_from_sample_inputs_fn +def equalize_image_tensor(): + for image in make_images(extra_dims=(), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): + if image.dtype != torch.uint8: + continue + yield SampleInput(image) + + +@register_kernel_info_from_sample_inputs_fn +def invert_image_tensor(): + for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): + yield SampleInput(image) + + +@register_kernel_info_from_sample_inputs_fn +def posterize_image_tensor(): + for image, bits in itertools.product( + make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)), + [1, 4, 8], + ): + if image.dtype != torch.uint8: + continue + yield SampleInput(image, bits=bits) + + +@register_kernel_info_from_sample_inputs_fn +def solarize_image_tensor(): + for image, threshold in itertools.product( + make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)), + [0.1, 0.5, 127.0], + ): + if image.is_floating_point() and threshold > 1.0: + continue + yield SampleInput(image, threshold=threshold) + + +@register_kernel_info_from_sample_inputs_fn +def autocontrast_image_tensor(): + for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): + yield SampleInput(image) + + +@register_kernel_info_from_sample_inputs_fn +def adjust_sharpness_image_tensor(): + for image, sharpness_factor in itertools.product( + make_images(extra_dims=((4,),), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)), + [0.1, 0.5], + ): + yield SampleInput(image, sharpness_factor=sharpness_factor) + + @pytest.mark.parametrize( "kernel", [ diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index f77b36d4643..42984847412 100644 --- a/torchvision/prototype/transforms/__init__.py +++ 
b/torchvision/prototype/transforms/__init__.py @@ -4,7 +4,16 @@ from ._augment import RandomErasing, RandomMixup, RandomCutmix from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix -from ._color import ColorJitter, RandomPhotometricDistort, RandomEqualize +from ._color import ( + ColorJitter, + RandomPhotometricDistort, + RandomEqualize, + RandomInvert, + RandomPosterize, + RandomSolarize, + RandomAutocontrast, + RandomAdjustSharpness, +) from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import ( Resize, @@ -27,5 +36,4 @@ from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip -# TODO: add RandomPerspective, RandomInvert, RandomPosterize, RandomSolarize, -# RandomAdjustSharpness, RandomAutocontrast, ElasticTransform +# TODO: add RandomPerspective, ElasticTransform diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 85e22aaeb1a..7fd198161c9 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -151,8 +151,42 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomEqualize(_RandomApplyTransform): - def __init__(self, p: float = 0.5): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.equalize(inpt) + + +class RandomInvert(_RandomApplyTransform): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.invert(inpt) + + +class RandomPosterize(_RandomApplyTransform): + def __init__(self, bits: int, p: float = 0.5) -> None: super().__init__(p=p) + self.bits = bits def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.equalize(inpt) + return F.posterize(inpt, bits=self.bits) + + +class RandomSolarize(_RandomApplyTransform): + def __init__(self, threshold: float, p: float = 0.5) -> None: + super().__init__(p=p) + self.threshold = threshold + + def _transform(self, inpt: 
Any, params: Dict[str, Any]) -> Any: + return F.solarize(inpt, threshold=self.threshold) + + +class RandomAutocontrast(_RandomApplyTransform): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.autocontrast(inpt) + + +class RandomAdjustSharpness(_RandomApplyTransform): + def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: + super().__init__(p=p) + self.sharpness_factor = sharpness_factor + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor)