Commit ddfee23

port tests for container transforms (#8012)
1 parent 0040fe7 commit ddfee23

4 files changed: +112 additions, -179 deletions


test/test_transforms_v2.py

Lines changed: 0 additions & 29 deletions
@@ -122,35 +122,6 @@ def test_check_transformed_types(self, inpt_type, mocker):
         t(inpt)


-class TestContainers:
-    @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
-    def test_assertions(self, transform_cls):
-        with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
-            transform_cls(transforms.RandomCrop(28))
-
-    @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
-    @pytest.mark.parametrize(
-        "trfms",
-        [
-            [transforms.Pad(2), transforms.RandomCrop(28)],
-            [lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
-            [transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
-        ],
-    )
-    def test_ctor(self, transform_cls, trfms):
-        c = transform_cls(trfms)
-        inpt = torch.rand(1, 3, 32, 32)
-        output = c(inpt)
-        assert isinstance(output, torch.Tensor)
-        assert output.ndim == 4
-
-
-class TestRandomChoice:
-    def test_assertions(self):
-        with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
-            transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1])
-
-
 class TestRandomIoUCrop:
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])

test/test_transforms_v2_consistency.py

Lines changed: 1 addition & 137 deletions
@@ -11,9 +11,7 @@
 import torch
 import torchvision.transforms.v2 as v2_transforms
 from common_utils import assert_close, assert_equal, set_rng_seed
-from torch import nn
 from torchvision import transforms as legacy_transforms, tv_tensors
-from torchvision._utils import sequence_to_str

 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
@@ -71,63 +69,7 @@ def __init__(
 LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
 LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

-CONSISTENCY_CONFIGS = [
-    ConsistencyConfig(
-        v2_transforms.Compose,
-        legacy_transforms.Compose,
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomApply,
-        legacy_transforms.RandomApply,
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomChoice,
-        legacy_transforms.RandomChoice,
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomOrder,
-        legacy_transforms.RandomOrder,
-    ),
-]
-
-
-@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
-def test_signature_consistency(config):
-    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
-    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)
-
-    for param in config.removed_params:
-        legacy_params.pop(param, None)
-
-    missing = legacy_params.keys() - prototype_params.keys()
-    if missing:
-        raise AssertionError(
-            f"The prototype transform does not support the parameters "
-            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
-            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
-            f"the `ConsistencyConfig`."
-        )
-
-    extra = prototype_params.keys() - legacy_params.keys()
-    extra_without_default = {
-        param
-        for param in extra
-        if prototype_params[param].default is inspect.Parameter.empty
-        and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
-    }
-    if extra_without_default:
-        raise AssertionError(
-            f"The prototype transform requires the parameters "
-            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
-            f"not. Please add a default value."
-        )
-
-    legacy_signature = list(legacy_params.keys())
-    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
-    # to the same number of parameters as the legacy one
-    prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]
-
-    assert prototype_signature == legacy_signature
+CONSISTENCY_CONFIGS = []


 def check_call_consistency(
@@ -288,84 +230,6 @@ def test_jit_consistency(config, args_kwargs):
     assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)


-class TestContainerTransforms:
-    """
-    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
-    consistency automatically tests the wrapped transforms consistency.
-
-    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
-    that were already tested for consistency above.
-    """
-
-    def test_compose(self):
-        prototype_transform = v2_transforms.Compose(
-            [
-                v2_transforms.Resize(256),
-                v2_transforms.CenterCrop(224),
-            ]
-        )
-        legacy_transform = legacy_transforms.Compose(
-            [
-                legacy_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
-            ]
-        )
-
-        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
-
-    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
-    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
-    def test_random_apply(self, p, sequence_type):
-        prototype_transform = v2_transforms.RandomApply(
-            sequence_type(
-                [
-                    v2_transforms.Resize(256),
-                    v2_transforms.CenterCrop(224),
-                ]
-            ),
-            p=p,
-        )
-        legacy_transform = legacy_transforms.RandomApply(
-            sequence_type(
-                [
-                    legacy_transforms.Resize(256),
-                    legacy_transforms.CenterCrop(224),
-                ]
-            ),
-            p=p,
-        )
-
-        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
-
-        if sequence_type is nn.ModuleList:
-            # quick and dirty test that it is jit-scriptable
-            scripted = torch.jit.script(prototype_transform)
-            scripted(torch.rand(1, 3, 300, 300))
-
-    # We can't test other values for `p` since the random parameter generation is different
-    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
-    def test_random_choice(self, probabilities):
-        prototype_transform = v2_transforms.RandomChoice(
-            [
-                v2_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
-            ],
-            p=probabilities,
-        )
-        legacy_transform = legacy_transforms.RandomChoice(
-            [
-                legacy_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
-            ],
-            p=probabilities,
-        )
-
-        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
-
-
 class TestToTensorTransforms:
     def test_pil_to_tensor(self):
         prototype_transform = v2_transforms.PILToTensor()

test/test_transforms_v2_refactored.py

Lines changed: 102 additions & 6 deletions
@@ -396,6 +396,8 @@ def check_transform(transform, input, check_v1_compatibility=True, check_sample_
     if check_v1_compatibility:
         _check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))

+    return output
+

 def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
     def wrapper(input, *args, **kwargs):
@@ -1773,7 +1775,7 @@ def test_transform_unknown_fill_error(self):
             transforms.RandomAffine(degrees=0, fill="fill")


-class TestCompose:
+class TestContainerTransforms:
     class BuiltinTransform(transforms.Transform):
         def _transform(self, inpt, params):
             return inpt
@@ -1788,7 +1790,10 @@ def forward(self, image, label):
             return image, label

     @pytest.mark.parametrize(
-        "transform_clss",
+        "transform_cls", [transforms.Compose, functools.partial(transforms.RandomApply, p=1), transforms.RandomOrder]
+    )
+    @pytest.mark.parametrize(
+        "wrapped_transform_clss",
         [
             [BuiltinTransform],
             [PackedInputTransform],
@@ -1803,12 +1808,12 @@ def forward(self, image, label):
         ],
     )
     @pytest.mark.parametrize("unpack", [True, False])
-    def test_packed_unpacked(self, transform_clss, unpack):
-        needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
-        needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
+    def test_packed_unpacked(self, transform_cls, wrapped_transform_clss, unpack):
+        needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in wrapped_transform_clss)
+        needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in wrapped_transform_clss)
         assert not (needs_packed_inputs and needs_unpacked_inputs)

-        transform = transforms.Compose([cls() for cls in transform_clss])
+        transform = transform_cls([cls() for cls in wrapped_transform_clss])

         image = make_image()
         label = 3
@@ -1833,6 +1838,97 @@ def call_transform():
         assert output[0] is image
         assert output[1] is label

+    def test_compose(self):
+        transform = transforms.Compose(
+            [
+                transforms.RandomHorizontalFlip(p=1),
+                transforms.RandomVerticalFlip(p=1),
+            ]
+        )
+
+        input = make_image()
+
+        actual = check_transform(transform, input)
+        expected = F.vertical_flip(F.horizontal_flip(input))
+
+        assert_equal(actual, expected)
+
+    @pytest.mark.parametrize("p", [0.0, 1.0])
+    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
+    def test_random_apply(self, p, sequence_type):
+        transform = transforms.RandomApply(
+            sequence_type(
+                [
+                    transforms.RandomHorizontalFlip(p=1),
+                    transforms.RandomVerticalFlip(p=1),
+                ]
+            ),
+            p=p,
+        )
+
+        # This needs to be a pure tensor (or a PIL image), because otherwise check_transform skips the v1
+        # compatibility check
+        input = make_image_tensor()
+        output = check_transform(transform, input, check_v1_compatibility=issubclass(sequence_type, nn.ModuleList))
+
+        if p == 1:
+            assert_equal(output, F.vertical_flip(F.horizontal_flip(input)))
+        else:
+            assert output is input
+
+    @pytest.mark.parametrize("p", [(0, 1), (1, 0)])
+    def test_random_choice(self, p):
+        transform = transforms.RandomChoice(
+            [
+                transforms.RandomHorizontalFlip(p=1),
+                transforms.RandomVerticalFlip(p=1),
+            ],
+            p=p,
+        )
+
+        input = make_image()
+        output = check_transform(transform, input)
+
+        p_horz, p_vert = p
+        if p_horz:
+            assert_equal(output, F.horizontal_flip(input))
+        else:
+            assert_equal(output, F.vertical_flip(input))
+
+    def test_random_order(self):
+        transform = transforms.RandomOrder(
+            [
+                transforms.RandomHorizontalFlip(p=1),
+                transforms.RandomVerticalFlip(p=1),
+            ]
+        )
+
+        input = make_image()
+
+        actual = check_transform(transform, input)
+        # We can't really check whether the transforms are actually applied in random order. However, horizontal and
+        # vertical flip are commutative. Meaning, even under the assumption that the transform applies them in random
+        # order, we can use a fixed order to compute the expected value.
+        expected = F.vertical_flip(F.horizontal_flip(input))
+
+        assert_equal(actual, expected)
+
+    def test_errors(self):
+        for cls in [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]:
+            with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
+                cls(lambda x: x)
+
+        with pytest.raises(ValueError, match="at least one transform"):
+            transforms.Compose([])
+
+        for p in [-1, 2]:
+            with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")):
+                transforms.RandomApply([lambda x: x], p=p)
+
+        for transforms_, p in [([lambda x: x], []), ([], [1.0])]:
+            with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
+                transforms.RandomChoice(transforms_, p=p)
+

 class TestToDtype:
     @pytest.mark.parametrize(
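Note that test_random_order above leans on horizontal and vertical flips being commutative to get a deterministic expected value. As a quick illustration outside the test harness, this standalone sketch (not part of the commit, assuming only the public torchvision.transforms.v2 functional API) checks that property directly:

import torch
from torchvision.transforms.v2 import functional as F

image = torch.rand(3, 32, 32)

# Horizontal and vertical flips commute, so a fixed application order yields
# the same result as any random ordering of the two transforms.
assert torch.equal(
    F.vertical_flip(F.horizontal_flip(image)),
    F.horizontal_flip(F.vertical_flip(image)),
)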

torchvision/transforms/v2/_container.py

Lines changed: 9 additions & 7 deletions
@@ -100,14 +100,15 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
         return {"transforms": self.transforms, "p": self.p}

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
+        needs_unpacking = len(inputs) > 1

         if torch.rand(1) >= self.p:
-            return sample
+            return inputs if needs_unpacking else inputs[0]

         for transform in self.transforms:
-            sample = transform(sample)
-        return sample
+            outputs = transform(*inputs)
+            inputs = outputs if needs_unpacking else (outputs,)
+        return outputs

     def extra_repr(self) -> str:
         format_string = []
@@ -173,8 +174,9 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
         self.transforms = transforms

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
+        needs_unpacking = len(inputs) > 1
         for idx in torch.randperm(len(self.transforms)):
             transform = self.transforms[idx]
-            sample = transform(sample)
-        return sample
+            outputs = transform(*inputs)
+            inputs = outputs if needs_unpacking else (outputs,)
+        return outputs
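With this change, RandomApply and RandomOrder forward multiple positional inputs to the wrapped transforms unpacked, instead of bundling them into a tuple first, which is what the new packed/unpacked tests exercise. A minimal usage sketch (not part of the commit; the sample call and values are illustrative only):

import torch
from torchvision.transforms import v2

# RandomApply with p=1 always runs the wrapped transforms.
transform = v2.RandomApply([v2.RandomHorizontalFlip(p=1)], p=1)

image = torch.rand(3, 32, 32)
label = 3

# Multiple positional inputs are forwarded unpacked, so non-image inputs such
# as the label pass through untouched alongside the transformed image.
flipped, same_label = transform(image, label)
assert torch.equal(flipped, torch.flip(image, dims=[-1]))
assert same_label == 3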
