Skip to content

Commit 378f3c3

Browse files
authored
[proto] Added random color transforms and tests (#6275)
* Added random color transforms and tests * Disable smoke test for RandomSolarize, RandomAdjustSharpness
1 parent 5bb8178 commit 378f3c3

File tree

4 files changed

+129
-5
lines changed

4 files changed

+129
-5
lines changed

test/test_prototype_transforms.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ class TestSmoke:
8383
transforms.RandomRotation(degrees=(-45, 45)),
8484
transforms.RandomAffine(degrees=(-45, 45)),
8585
transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
86+
# TODO: Something wrong with input data setup. Let's fix that
87+
# transforms.RandomEqualize(),
88+
# transforms.RandomInvert(),
89+
# transforms.RandomPosterize(bits=4),
90+
# transforms.RandomSolarize(threshold=0.5),
91+
# transforms.RandomAdjustSharpness(sharpness_factor=0.5),
8692
)
8793
def test_common(self, transform, input):
8894
transform(input)
@@ -699,3 +705,28 @@ def test__transform(self, kernel_size, sigma, mocker):
699705
params = transform._get_params(inpt)
700706

701707
fn.assert_called_once_with(inpt, **params)
708+
709+
710+
class TestRandomColorOp:
711+
@pytest.mark.parametrize("p", [0.0, 1.0])
712+
@pytest.mark.parametrize(
713+
"transform_cls, func_op_name, kwargs",
714+
[
715+
(transforms.RandomEqualize, "equalize", {}),
716+
(transforms.RandomInvert, "invert", {}),
717+
(transforms.RandomAutocontrast, "autocontrast", {}),
718+
(transforms.RandomPosterize, "posterize", {"bits": 4}),
719+
(transforms.RandomSolarize, "solarize", {"threshold": 0.5}),
720+
(transforms.RandomAdjustSharpness, "adjust_sharpness", {"sharpness_factor": 0.5}),
721+
],
722+
)
723+
def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
724+
transform = transform_cls(p=p, **kwargs)
725+
726+
fn = mocker.patch(f"torchvision.prototype.transforms.functional.{func_op_name}")
727+
inpt = mocker.MagicMock(spec=features.Image)
728+
_ = transform(inpt)
729+
if p > 0.0:
730+
fn.assert_called_once_with(inpt, **kwargs)
731+
else:
732+
fn.call_count == 0

test/test_prototype_transforms_functional.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,57 @@ def gaussian_blur_image_tensor():
515515
yield SampleInput(image, kernel_size=kernel_size, sigma=sigma)
516516

517517

518+
@register_kernel_info_from_sample_inputs_fn
519+
def equalize_image_tensor():
520+
for image in make_images(extra_dims=(), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
521+
if image.dtype != torch.uint8:
522+
continue
523+
yield SampleInput(image)
524+
525+
526+
@register_kernel_info_from_sample_inputs_fn
527+
def invert_image_tensor():
528+
for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
529+
yield SampleInput(image)
530+
531+
532+
@register_kernel_info_from_sample_inputs_fn
533+
def posterize_image_tensor():
534+
for image, bits in itertools.product(
535+
make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)),
536+
[1, 4, 8],
537+
):
538+
if image.dtype != torch.uint8:
539+
continue
540+
yield SampleInput(image, bits=bits)
541+
542+
543+
@register_kernel_info_from_sample_inputs_fn
544+
def solarize_image_tensor():
545+
for image, threshold in itertools.product(
546+
make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)),
547+
[0.1, 0.5, 127.0],
548+
):
549+
if image.is_floating_point() and threshold > 1.0:
550+
continue
551+
yield SampleInput(image, threshold=threshold)
552+
553+
554+
@register_kernel_info_from_sample_inputs_fn
555+
def autocontrast_image_tensor():
556+
for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
557+
yield SampleInput(image)
558+
559+
560+
@register_kernel_info_from_sample_inputs_fn
561+
def adjust_sharpness_image_tensor():
562+
for image, sharpness_factor in itertools.product(
563+
make_images(extra_dims=((4,),), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)),
564+
[0.1, 0.5],
565+
):
566+
yield SampleInput(image, sharpness_factor=sharpness_factor)
567+
568+
518569
@pytest.mark.parametrize(
519570
"kernel",
520571
[

torchvision/prototype/transforms/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,16 @@
44

55
from ._augment import RandomErasing, RandomMixup, RandomCutmix
66
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
7-
from ._color import ColorJitter, RandomPhotometricDistort, RandomEqualize
7+
from ._color import (
8+
ColorJitter,
9+
RandomPhotometricDistort,
10+
RandomEqualize,
11+
RandomInvert,
12+
RandomPosterize,
13+
RandomSolarize,
14+
RandomAutocontrast,
15+
RandomAdjustSharpness,
16+
)
817
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
918
from ._geometry import (
1019
Resize,
@@ -27,5 +36,4 @@
2736

2837
from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip
2938

30-
# TODO: add RandomPerspective, RandomInvert, RandomPosterize, RandomSolarize,
31-
# RandomAdjustSharpness, RandomAutocontrast, ElasticTransform
39+
# TODO: add RandomPerspective, ElasticTransform

torchvision/prototype/transforms/_color.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,42 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
151151

152152

153153
class RandomEqualize(_RandomApplyTransform):
154-
def __init__(self, p: float = 0.5):
154+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
155+
return F.equalize(inpt)
156+
157+
158+
class RandomInvert(_RandomApplyTransform):
159+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
160+
return F.invert(inpt)
161+
162+
163+
class RandomPosterize(_RandomApplyTransform):
164+
def __init__(self, bits: int, p: float = 0.5) -> None:
155165
super().__init__(p=p)
166+
self.bits = bits
156167

157168
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
158-
return F.equalize(inpt)
169+
return F.posterize(inpt, bits=self.bits)
170+
171+
172+
class RandomSolarize(_RandomApplyTransform):
173+
def __init__(self, threshold: float, p: float = 0.5) -> None:
174+
super().__init__(p=p)
175+
self.threshold = threshold
176+
177+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
178+
return F.solarize(inpt, threshold=self.threshold)
179+
180+
181+
class RandomAutocontrast(_RandomApplyTransform):
182+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
183+
return F.autocontrast(inpt)
184+
185+
186+
class RandomAdjustSharpness(_RandomApplyTransform):
187+
def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
188+
super().__init__(p=p)
189+
self.sharpness_factor = sharpness_factor
190+
191+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
192+
return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor)

0 commit comments

Comments
 (0)