Commit 3a9aca1

Added RandomPerspective and tests (#6284)
- Replaced real image creation with mocks in the other tests
1 parent 378f3c3 commit 3a9aca1

File tree

test/test_prototype_transforms.py
test/test_prototype_transforms_functional.py
torchvision/prototype/transforms/__init__.py
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/functional/_geometry.py

5 files changed: +157 −20 lines


test/test_prototype_transforms.py

Lines changed: 90 additions & 18 deletions

@@ -350,7 +350,7 @@ def test__transform(self, padding, fill, padding_mode, mocker):
         transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)

         fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
-        inpt = mocker.MagicMock(spec=torch.Tensor)
+        inpt = mocker.MagicMock(spec=features.Image)
         _ = transform(inpt)

         fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@@ -369,11 +369,12 @@ def test_assertions(self):

     @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
     @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
-    def test__get_params(self, fill, side_range):
+    def test__get_params(self, fill, side_range, mocker):
         transform = transforms.RandomZoomOut(fill=fill, side_range=side_range)

-        image = features.Image(torch.rand(1, 3, 32, 32))
-        c, h, w = image.shape[-3:]
+        image = mocker.MagicMock(spec=features.Image)
+        c = image.num_channels = 3
+        h, w = image.image_size = (24, 32)

         params = transform._get_params(image)

@@ -387,19 +388,22 @@ def test__get_params(self, fill, side_range):
     @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
     @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
     def test__transform(self, fill, side_range, mocker):
-        image = features.Image(torch.rand(1, 3, 32, 32))
+        inpt = mocker.MagicMock(spec=features.Image)
+        inpt.num_channels = 3
+        inpt.image_size = (24, 32)
+
         transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)

         fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
         # vfdev-5, Feature Request: let's store params as Transform attribute
         # This could be also helpful for users
         torch.manual_seed(12)
-        _ = transform(image)
+        _ = transform(inpt)
         torch.manual_seed(12)
         torch.rand(1)  # random apply changes random state
-        params = transform._get_params(image)
+        params = transform._get_params(inpt)

-        fn.assert_called_once_with(image, **params)
+        fn.assert_called_once_with(inpt, **params)


 class TestRandomRotation:
@@ -449,7 +453,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
         assert transform.degrees == [float(-degrees), float(degrees)]

         fn = mocker.patch("torchvision.prototype.transforms.functional.rotate")
-        inpt = mocker.MagicMock(spec=torch.Tensor)
+        inpt = mocker.MagicMock(spec=features.Image)
         # vfdev-5, Feature Request: let's store params as Transform attribute
         # This could be also helpful for users
         torch.manual_seed(12)
@@ -504,9 +508,11 @@ def test_assertions(self):
     @pytest.mark.parametrize("translate", [None, [0.1, 0.2]])
     @pytest.mark.parametrize("scale", [None, [0.7, 1.2]])
     @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]])
-    def test__get_params(self, degrees, translate, scale, shear):
-        image = features.Image(torch.rand(1, 3, 32, 32))
-        h, w = image.shape[-2:]
+    def test__get_params(self, degrees, translate, scale, shear, mocker):
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+        h, w = image.image_size

         transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear)
         params = transform._get_params(image)
@@ -564,7 +570,10 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker):
         assert transform.degrees == [float(-degrees), float(degrees)]

         fn = mocker.patch("torchvision.prototype.transforms.functional.affine")
-        inpt = features.Image(torch.rand(1, 3, 32, 32))
+        inpt = mocker.MagicMock(spec=features.Image)
+        inpt.num_channels = 3
+        inpt.image_size = (24, 32)
+
         # vfdev-5, Feature Request: let's store params as Transform attribute
         # This could be also helpful for users
         torch.manual_seed(12)
@@ -592,9 +601,11 @@ def test_assertions(self):
         with pytest.raises(ValueError, match="Padding mode should be either"):
             transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")

-    def test__get_params(self):
-        image = features.Image(torch.rand(1, 3, 32, 32))
-        h, w = image.shape[-2:]
+    def test__get_params(self, mocker):
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+        h, w = image.image_size

         transform = transforms.RandomCrop([10, 10])
         params = transform._get_params(image)
@@ -614,7 +625,10 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
             output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
         )

-        inpt = features.Image(torch.rand(1, 3, 32, 32))
+        inpt = mocker.MagicMock(spec=features.Image)
+        inpt.num_channels = 3
+        inpt.image_size = (32, 32)
+
         expected = mocker.MagicMock(spec=features.Image)
         expected.num_channels = 3
         if isinstance(padding, int):
@@ -696,7 +710,10 @@ def test__transform(self, kernel_size, sigma, mocker):
         assert transform.sigma == (sigma, sigma)

         fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur")
-        inpt = features.Image(torch.rand(1, 3, 32, 32))
+        inpt = mocker.MagicMock(spec=features.Image)
+        inpt.num_channels = 3
+        inpt.image_size = (24, 32)
+
         # vfdev-5, Feature Request: let's store params as Transform attribute
         # This could be also helpful for users
         torch.manual_seed(12)
@@ -730,3 +747,58 @@ def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
             fn.assert_called_once_with(inpt, **kwargs)
         else:
             fn.call_count == 0
+
+
+class TestRandomPerspective:
+    def test_assertions(self):
+        with pytest.raises(ValueError, match="Argument distortion_scale value should be between 0 and 1"):
+            transforms.RandomPerspective(distortion_scale=-1.0)
+
+        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
+            transforms.RandomPerspective(0.5, fill="abc")
+
+    def test__get_params(self, mocker):
+        dscale = 0.5
+        transform = transforms.RandomPerspective(dscale)
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+
+        params = transform._get_params(image)
+
+        h, w = image.image_size
+        assert len(params["startpoints"]) == 4
+        for x, y in params["startpoints"]:
+            assert x in (0, w - 1)
+            assert y in (0, h - 1)
+
+        assert len(params["endpoints"]) == 4
+        for (x, y), name in zip(params["endpoints"], ["tl", "tr", "br", "bl"]):
+            if "t" in name:
+                assert 0 <= y <= int(dscale * h // 2), (x, y, name)
+            if "b" in name:
+                assert h - int(dscale * h // 2) - 1 <= y <= h, (x, y, name)
+            if "l" in name:
+                assert 0 <= x <= int(dscale * w // 2), (x, y, name)
+            if "r" in name:
+                assert w - int(dscale * w // 2) - 1 <= x <= w, (x, y, name)
+
+    @pytest.mark.parametrize("distortion_scale", [0.1, 0.7])
+    def test__transform(self, distortion_scale, mocker):
+        interpolation = InterpolationMode.BILINEAR
+        fill = 12
+        transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
+
+        fn = mocker.patch("torchvision.prototype.transforms.functional.perspective")
+        inpt = mocker.MagicMock(spec=features.Image)
+        inpt.num_channels = 3
+        inpt.image_size = (24, 32)
+        # vfdev-5, Feature Request: let's store params as Transform attribute
+        # This could be also helpful for users
+        torch.manual_seed(12)
+        _ = transform(inpt)
+        torch.manual_seed(12)
+        torch.rand(1)  # random apply changes random state
+        params = transform._get_params(inpt)
+
+        fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
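Note: the recurring edit in this file is what the commit message describes: tests that only exercise _get_params or check dispatch no longer allocate a real features.Image(torch.rand(...)) but stub it with mocker.MagicMock(spec=features.Image). A minimal sketch of the pattern, assuming the pytest-mock mocker fixture and the imports already used in this file (the test name is illustrative):

def test__get_params_with_mock(mocker):
    # spec=features.Image makes the mock pass isinstance checks for
    # features.Image while rejecting attributes the class does not define.
    image = mocker.MagicMock(spec=features.Image)
    image.num_channels = 3
    image.image_size = (24, 32)  # (height, width)

    # _get_params only reads metadata, never pixel data, so no tensor
    # allocation is needed.
    params = transforms.RandomPerspective(0.5)._get_params(image)
    assert len(params["startpoints"]) == len(params["endpoints"]) == 4

The spec argument is what keeps the stub honest: a typo such as image.image_szie raises an AttributeError instead of silently returning another mock.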

test/test_prototype_transforms_functional.py

Lines changed: 3 additions & 1 deletion

@@ -599,9 +599,11 @@ def test_scriptable(kernel):
         and all(
             feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"}
         )
-        and name not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate"}
+        and name
+        not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate", "perspective"}
         # We skip 'crop' due to missing 'height' and 'width'
         # We skip 'rotate' due to non implemented yet expand=True case for bboxes
+        # We skip 'perspective' as it requires different input args than perspective_image_tensor etc
     ],
 )
 def test_functional_mid_level(func):
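Note: the new skip entry is needed because the mid-level perspective no longer shares the (input, perspective_coeffs) signature of its kernels, so the generic mid-level test cannot call it the way it calls perspective_image_tensor. A sketch of the two call shapes, assuming the prototype API at this commit:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

img = features.Image(torch.rand(1, 3, 24, 32))
startpoints = [[0, 0], [31, 0], [31, 23], [0, 23]]
endpoints = [[2, 1], [29, 3], [30, 22], [1, 20]]

# Mid-level op: corner points in, coefficients derived internally.
out = F.perspective(img, startpoints, endpoints)
# Kernels still take the eight precomputed coefficients directly, e.g.
# F.perspective_image_tensor(tensor, perspective_coeffs, ...).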

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@
     RandomZoomOut,
     RandomRotation,
     RandomAffine,
+    RandomPerspective,
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda
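Note: with the export added, the transform is reachable from the public prototype namespace. A minimal usage sketch, assuming the prototype package layout at this commit (p=1.0 forces the random apply so the effect is always visible):

import torch
from torchvision.prototype import features, transforms

transform = transforms.RandomPerspective(distortion_scale=0.5, fill=0, p=1.0)
img = features.Image(torch.rand(3, 24, 32))
out = transform(img)  # perspective-warped features.Image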

torchvision/prototype/transforms/_geometry.py

Lines changed: 58 additions & 0 deletions

@@ -292,6 +292,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         bottom = canvas_height - (top + orig_h)
         padding = [left, top, right, bottom]

+        # vfdev-5: Can we put that into pad_image_tensor ?
         fill = self.fill
         if not isinstance(fill, collections.abc.Sequence):
             fill = [fill] * orig_c
@@ -493,3 +494,60 @@ def forward(self, *inputs: Any) -> Any:
         flat_inputs, spec = tree_flatten(sample)
         out_flat_inputs = self._forward(flat_inputs)
         return tree_unflatten(out_flat_inputs, spec)
+
+
+class RandomPerspective(_RandomApplyTransform):
+    def __init__(
+        self,
+        distortion_scale: float,
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        p: float = 0.5,
+    ) -> None:
+        super().__init__(p=p)
+
+        _check_fill_arg(fill)
+        if not (0 <= distortion_scale <= 1):
+            raise ValueError("Argument distortion_scale value should be between 0 and 1")
+
+        self.distortion_scale = distortion_scale
+        self.interpolation = interpolation
+        self.fill = fill
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        # Get image size
+        # TODO: make it work with bboxes and segm masks
+        image = query_image(sample)
+        _, height, width = get_image_dimensions(image)
+
+        distortion_scale = self.distortion_scale
+
+        half_height = height // 2
+        half_width = width // 2
+        topleft = [
+            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
+            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
+        ]
+        topright = [
+            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
+            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
+        ]
+        botright = [
+            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
+            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
+        ]
+        botleft = [
+            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
+            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
+        ]
+        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
+        endpoints = [topleft, topright, botright, botleft]
+        return dict(startpoints=startpoints, endpoints=endpoints)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.perspective(
+            inpt,
+            **params,
+            fill=self.fill,
+            interpolation=self.interpolation,
+        )
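Note: to make the sampling concrete, each displaced corner is drawn uniformly from a rectangle of roughly distortion_scale times half the image extent, anchored at the matching image corner. A standalone sketch of the bounds for distortion_scale=0.5 on a 24x32 image (this mirrors the arithmetic above; it is not an import from torchvision):

# Offsets match the torch.randint bounds in _get_params above.
distortion_scale, height, width = 0.5, 24, 32
dx = int(distortion_scale * (width // 2))   # 8: x offsets land in [0, 8]
dy = int(distortion_scale * (height // 2))  # 6: y offsets land in [0, 6]

# The top-left endpoint is sampled from [0, dx] x [0, dy]; the top-right
# from [width - dx - 1, width - 1] x [0, dy]; the bottom corners mirror
# these ranges in y. The startpoints stay pinned to the image corners.
print(dx, dy)  # -> 8 6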

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 5 additions & 1 deletion

@@ -11,6 +11,7 @@
     _get_inverse_affine_matrix,
     InterpolationMode,
     _compute_output_size,
+    _get_perspective_coeffs,
 )

 from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil
@@ -765,10 +766,13 @@ def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[float]):

 def perspective(
     inpt: DType,
-    perspective_coeffs: List[float],
+    startpoints: List[List[int]],
+    endpoints: List[List[int]],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
 ) -> DType:
+    perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
+
     if isinstance(inpt, features._Feature):
         return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, PIL.Image.Image):
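Note: the dispatch itself is unchanged; the only new step is converting the corner points to the eight coefficients of the projective transform before handing off to the per-type kernels. A short sketch of that conversion, assuming _get_perspective_coeffs is the helper imported above from the stable torchvision.transforms.functional module:

from torchvision.transforms.functional import _get_perspective_coeffs

startpoints = [[0, 0], [31, 0], [31, 23], [0, 23]]
endpoints = [[3, 2], [28, 1], [30, 21], [2, 22]]
coeffs = _get_perspective_coeffs(startpoints, endpoints)
assert len(coeffs) == 8  # a..h of the 3x3 homography, last entry fixed to 1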

0 commit comments

Comments
 (0)