Skip to content

Commit df6918c

Browse files
authored
[proto] Added few transforms tests, part 1 (#6262)
* Added supported/unsupported data checks in the tests for cutmix/mixup * Added RandomRotation, RandomAffine transforms tests * Added tests for RandomZoomOut, Pad * Update test_prototype_transforms.py
1 parent 9effc4c commit df6918c

File tree

3 files changed

+273
-8
lines changed

3 files changed

+273
-8
lines changed

test/test_prototype_transforms.py

Lines changed: 264 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66
from test_prototype_transforms_functional import (
77
make_images,
88
make_bounding_boxes,
9+
make_bounding_box,
910
make_one_hot_labels,
11+
make_label,
12+
make_segmentation_mask,
1013
)
1114
from torchvision.prototype import transforms, features
12-
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
15+
from torchvision.transforms.functional import to_pil_image, pil_to_tensor, InterpolationMode
1316

1417

1518
def make_vanilla_tensor_images(*args, **kwargs):
@@ -106,6 +109,20 @@ def test_common(self, transform, input):
106109
def test_mixup_cutmix(self, transform, input):
107110
transform(input)
108111

112+
# add other data that should bypass and wont raise any error
113+
input_copy = dict(input)
114+
input_copy["path"] = "/path/to/somewhere"
115+
input_copy["num"] = 1234
116+
transform(input_copy)
117+
118+
# Check if we raise an error if sample contains bbox or mask or label
119+
err_msg = "does not support bounding boxes, segmentation masks and plain labels"
120+
input_copy = dict(input)
121+
for unsup_data in [make_label(), make_bounding_box(format="XYXY"), make_segmentation_mask()]:
122+
input_copy["unsupported"] = unsup_data
123+
with pytest.raises(TypeError, match=err_msg):
124+
transform(input_copy)
125+
109126
@parametrize(
110127
[
111128
(
@@ -303,3 +320,249 @@ def test_features_bounding_box(self, p):
303320
assert_equal(expected, actual)
304321
assert actual.format == expected.format
305322
assert actual.image_size == expected.image_size
323+
324+
325+
class TestPad:
326+
def test_assertions(self):
327+
with pytest.raises(TypeError, match="Got inappropriate padding arg"):
328+
transforms.Pad("abc")
329+
330+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
331+
transforms.Pad([-0.7, 0, 0.7])
332+
333+
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
334+
transforms.Pad(12, fill="abc")
335+
336+
with pytest.raises(ValueError, match="Padding mode should be either"):
337+
transforms.Pad(12, padding_mode="abc")
338+
339+
@pytest.mark.parametrize("padding", [1, (1, 2), [1, 2, 3, 4]])
340+
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
341+
@pytest.mark.parametrize("padding_mode", ["constant", "edge"])
342+
def test__transform(self, padding, fill, padding_mode, mocker):
343+
transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
344+
345+
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
346+
inpt = mocker.MagicMock(spec=torch.Tensor)
347+
_ = transform(inpt)
348+
349+
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
350+
351+
352+
class TestRandomZoomOut:
353+
def test_assertions(self):
354+
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
355+
transforms.RandomZoomOut(fill="abc")
356+
357+
with pytest.raises(TypeError, match="should be a sequence of length"):
358+
transforms.RandomZoomOut(0, side_range=0)
359+
360+
with pytest.raises(ValueError, match="Invalid canvas side range"):
361+
transforms.RandomZoomOut(0, side_range=[4.0, 1.0])
362+
363+
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
364+
@pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
365+
def test__get_params(self, fill, side_range):
366+
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range)
367+
368+
image = features.Image(torch.rand(1, 3, 32, 32))
369+
c, h, w = image.shape[-3:]
370+
371+
params = transform._get_params(image)
372+
373+
assert params["fill"] == (fill if not isinstance(fill, int) else [fill] * c)
374+
assert len(params["padding"]) == 4
375+
assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
376+
assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h
377+
assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w
378+
assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h
379+
380+
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
381+
@pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
382+
def test__transform(self, fill, side_range, mocker):
383+
image = features.Image(torch.rand(1, 3, 32, 32))
384+
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
385+
386+
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
387+
# vfdev-5, Feature Request: let's store params as Transform attribute
388+
# This could be also helpful for users
389+
torch.manual_seed(12)
390+
_ = transform(image)
391+
torch.manual_seed(12)
392+
torch.rand(1) # random apply changes random state
393+
params = transform._get_params(image)
394+
395+
fn.assert_called_once_with(image, **params)
396+
397+
398+
class TestRandomRotation:
399+
def test_assertions(self):
400+
with pytest.raises(ValueError, match="is a single number, it must be positive"):
401+
transforms.RandomRotation(-0.7)
402+
403+
for d in [[-0.7], [-0.7, 0, 0.7]]:
404+
with pytest.raises(ValueError, match="degrees should be a sequence of length 2"):
405+
transforms.RandomRotation(d)
406+
407+
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
408+
transforms.RandomRotation(12, fill="abc")
409+
410+
with pytest.raises(TypeError, match="center should be a sequence of length"):
411+
transforms.RandomRotation(12, center=12)
412+
413+
with pytest.raises(ValueError, match="center should be a sequence of length"):
414+
transforms.RandomRotation(12, center=[1, 2, 3])
415+
416+
def test__get_params(self):
417+
angle_bound = 34
418+
transform = transforms.RandomRotation(angle_bound)
419+
420+
params = transform._get_params(None)
421+
assert -angle_bound <= params["angle"] <= angle_bound
422+
423+
angle_bounds = [12, 34]
424+
transform = transforms.RandomRotation(angle_bounds)
425+
426+
params = transform._get_params(None)
427+
assert angle_bounds[0] <= params["angle"] <= angle_bounds[1]
428+
429+
@pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)])
430+
@pytest.mark.parametrize("expand", [False, True])
431+
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
432+
@pytest.mark.parametrize("center", [None, [2.0, 3.0]])
433+
def test__transform(self, degrees, expand, fill, center, mocker):
434+
interpolation = InterpolationMode.BILINEAR
435+
transform = transforms.RandomRotation(
436+
degrees, interpolation=interpolation, expand=expand, fill=fill, center=center
437+
)
438+
439+
if isinstance(degrees, (tuple, list)):
440+
assert transform.degrees == [float(degrees[0]), float(degrees[1])]
441+
else:
442+
assert transform.degrees == [float(-degrees), float(degrees)]
443+
444+
fn = mocker.patch("torchvision.prototype.transforms.functional.rotate")
445+
inpt = mocker.MagicMock(spec=torch.Tensor)
446+
# vfdev-5, Feature Request: let's store params as Transform attribute
447+
# This could be also helpful for users
448+
torch.manual_seed(12)
449+
_ = transform(inpt)
450+
torch.manual_seed(12)
451+
params = transform._get_params(inpt)
452+
453+
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
454+
455+
456+
class TestRandomAffine:
457+
def test_assertions(self):
458+
with pytest.raises(ValueError, match="is a single number, it must be positive"):
459+
transforms.RandomAffine(-0.7)
460+
461+
for d in [[-0.7], [-0.7, 0, 0.7]]:
462+
with pytest.raises(ValueError, match="degrees should be a sequence of length 2"):
463+
transforms.RandomAffine(d)
464+
465+
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
466+
transforms.RandomAffine(12, fill="abc")
467+
468+
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
469+
transforms.RandomAffine(12, fill="abc")
470+
471+
for kwargs in [
472+
{"center": 12},
473+
{"translate": 12},
474+
{"scale": 12},
475+
]:
476+
with pytest.raises(TypeError, match="should be a sequence of length"):
477+
transforms.RandomAffine(12, **kwargs)
478+
479+
for kwargs in [{"center": [1, 2, 3]}, {"translate": [1, 2, 3]}, {"scale": [1, 2, 3]}]:
480+
with pytest.raises(ValueError, match="should be a sequence of length"):
481+
transforms.RandomAffine(12, **kwargs)
482+
483+
with pytest.raises(ValueError, match="translation values should be between 0 and 1"):
484+
transforms.RandomAffine(12, translate=[-1.0, 2.0])
485+
486+
with pytest.raises(ValueError, match="scale values should be positive"):
487+
transforms.RandomAffine(12, scale=[-1.0, 2.0])
488+
489+
with pytest.raises(ValueError, match="is a single number, it must be positive"):
490+
transforms.RandomAffine(12, shear=-10)
491+
492+
for s in [[-0.7], [-0.7, 0, 0.7]]:
493+
with pytest.raises(ValueError, match="shear should be a sequence of length 2"):
494+
transforms.RandomAffine(12, shear=s)
495+
496+
@pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)])
497+
@pytest.mark.parametrize("translate", [None, [0.1, 0.2]])
498+
@pytest.mark.parametrize("scale", [None, [0.7, 1.2]])
499+
@pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]])
500+
def test__get_params(self, degrees, translate, scale, shear):
501+
image = features.Image(torch.rand(1, 3, 32, 32))
502+
h, w = image.shape[-2:]
503+
504+
transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear)
505+
params = transform._get_params(image)
506+
507+
if not isinstance(degrees, (list, tuple)):
508+
assert -degrees <= params["angle"] <= degrees
509+
else:
510+
assert degrees[0] <= params["angle"] <= degrees[1]
511+
512+
if translate is not None:
513+
assert -translate[0] * w <= params["translations"][0] <= translate[0] * w
514+
assert -translate[1] * h <= params["translations"][1] <= translate[1] * h
515+
else:
516+
assert params["translations"] == (0, 0)
517+
518+
if scale is not None:
519+
assert scale[0] <= params["scale"] <= scale[1]
520+
else:
521+
assert params["scale"] == 1.0
522+
523+
if shear is not None:
524+
if isinstance(shear, float):
525+
assert -shear <= params["shear"][0] <= shear
526+
assert params["shear"][1] == 0.0
527+
elif len(shear) == 2:
528+
assert shear[0] <= params["shear"][0] <= shear[1]
529+
assert params["shear"][1] == 0.0
530+
else:
531+
assert shear[0] <= params["shear"][0] <= shear[1]
532+
assert shear[2] <= params["shear"][1] <= shear[3]
533+
else:
534+
assert params["shear"] == (0, 0)
535+
536+
@pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)])
537+
@pytest.mark.parametrize("translate", [None, [0.1, 0.2]])
538+
@pytest.mark.parametrize("scale", [None, [0.7, 1.2]])
539+
@pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]])
540+
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
541+
@pytest.mark.parametrize("center", [None, [2.0, 3.0]])
542+
def test__transform(self, degrees, translate, scale, shear, fill, center, mocker):
543+
interpolation = InterpolationMode.BILINEAR
544+
transform = transforms.RandomAffine(
545+
degrees,
546+
translate=translate,
547+
scale=scale,
548+
shear=shear,
549+
interpolation=interpolation,
550+
fill=fill,
551+
center=center,
552+
)
553+
554+
if isinstance(degrees, (tuple, list)):
555+
assert transform.degrees == [float(degrees[0]), float(degrees[1])]
556+
else:
557+
assert transform.degrees == [float(-degrees), float(degrees)]
558+
559+
fn = mocker.patch("torchvision.prototype.transforms.functional.affine")
560+
inpt = features.Image(torch.rand(1, 3, 32, 32))
561+
# vfdev-5, Feature Request: let's store params as Transform attribute
562+
# This could be also helpful for users
563+
torch.manual_seed(12)
564+
_ = transform(inpt)
565+
torch.manual_seed(12)
566+
params = transform._get_params(inpt)
567+
568+
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)

torchvision/prototype/transforms/_geometry.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,16 +236,16 @@ def __init__(
236236
if not isinstance(padding, (numbers.Number, tuple, list)):
237237
raise TypeError("Got inappropriate padding arg")
238238

239+
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
240+
raise ValueError(
241+
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
242+
)
243+
239244
_check_fill_arg(fill)
240245

241246
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
242247
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
243248

244-
if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
245-
raise ValueError(
246-
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
247-
)
248-
249249
self.padding = padding
250250
self.fill = fill
251251
self.padding_mode = padding_mode
@@ -258,14 +258,16 @@ class RandomZoomOut(_RandomApplyTransform):
258258
def __init__(
259259
self,
260260
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
261-
side_range: Tuple[float, float] = (1.0, 4.0),
261+
side_range: Sequence[float] = (1.0, 4.0),
262262
p: float = 0.5,
263263
) -> None:
264264
super().__init__(p=p)
265265

266266
_check_fill_arg(fill)
267267
self.fill = fill
268268

269+
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
270+
269271
self.side_range = side_range
270272
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
271273
raise ValueError(f"Invalid canvas side range provided {side_range}.")

torchvision/transforms/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1855,7 +1855,7 @@ def _check_sequence_input(x, name, req_sizes):
18551855
if not isinstance(x, Sequence):
18561856
raise TypeError(f"{name} should be a sequence of length {msg}.")
18571857
if len(x) not in req_sizes:
1858-
raise ValueError(f"{name} should be sequence of length {msg}.")
1858+
raise ValueError(f"{name} should be a sequence of length {msg}.")
18591859

18601860

18611861
def _setup_angle(x, name, req_sizes=(2,)):

0 commit comments

Comments
 (0)