
Commit 1a9ff0d

Port remaining transforms tests (#7954)
1 parent 997384c commit 1a9ff0d

9 files changed: +730 -1947 lines

test/test_transforms_v2.py

Lines changed: 0 additions & 104 deletions
@@ -272,57 +272,6 @@ def test_common(self, transform, adapter, container_type, image_or_video, de_ser
         )
         assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)

-    @parametrize(
-        [
-            (
-                transform,
-                itertools.chain.from_iterable(
-                    fn(
-                        color_spaces=[
-                            "GRAY",
-                            "RGB",
-                        ],
-                        dtypes=[torch.uint8],
-                        extra_dims=[(), (4,)],
-                        **(dict(num_frames=[3]) if fn is make_videos else dict()),
-                    )
-                    for fn in [
-                        make_images,
-                        make_vanilla_tensor_images,
-                        make_pil_images,
-                        make_videos,
-                    ]
-                ),
-            )
-            for transform in (
-                transforms.RandAugment(),
-                transforms.TrivialAugmentWide(),
-                transforms.AutoAugment(),
-                transforms.AugMix(),
-            )
-        ]
-    )
-    def test_auto_augment(self, transform, input):
-        transform(input)
-
-    @parametrize(
-        [
-            (
-                transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
-                itertools.chain.from_iterable(
-                    fn(color_spaces=["RGB"], dtypes=[torch.float32])
-                    for fn in [
-                        make_images,
-                        make_vanilla_tensor_images,
-                        make_videos,
-                    ]
-                ),
-            ),
-        ]
-    )
-    def test_normalize(self, transform, input):
-        transform(input)
-

     @pytest.mark.parametrize(
         "flat_inputs",
@@ -385,40 +334,6 @@ def was_applied(output, inpt):
         assert transform.was_applied(output, input)


-class TestElasticTransform:
-    def test_assertions(self):
-
-        with pytest.raises(TypeError, match="alpha should be a number or a sequence of numbers"):
-            transforms.ElasticTransform({})
-
-        with pytest.raises(ValueError, match="alpha is a sequence its length should be 1 or 2"):
-            transforms.ElasticTransform([1.0, 2.0, 3.0])
-
-        with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"):
-            transforms.ElasticTransform(1.0, {})
-
-        with pytest.raises(ValueError, match="sigma is a sequence its length should be 1 or 2"):
-            transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0])
-
-        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
-            transforms.ElasticTransform(1.0, 2.0, fill="abc")
-
-    def test__get_params(self):
-        alpha = 2.0
-        sigma = 3.0
-        transform = transforms.ElasticTransform(alpha, sigma)
-
-        h, w = size = (24, 32)
-        image = make_image(size)
-
-        params = transform._get_params([image])
-
-        displacement = params["displacement"]
-        assert displacement.shape == (1, h, w, 2)
-        assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
-        assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()
-
-
 class TestTransform:
     @pytest.mark.parametrize(
         "inpt_type",
@@ -705,25 +620,6 @@ def test__get_params(self):
         assert min_size <= size < max_size


-class TestUniformTemporalSubsample:
-    @pytest.mark.parametrize(
-        "inpt",
-        [
-            torch.zeros(10, 3, 8, 8),
-            torch.zeros(1, 10, 3, 8, 8),
-            tv_tensors.Video(torch.zeros(1, 10, 3, 8, 8)),
-        ],
-    )
-    def test__transform(self, inpt):
-        num_samples = 5
-        transform = transforms.UniformTemporalSubsample(num_samples)
-
-        output = transform(inpt)
-        assert type(output) is type(inpt)
-        assert output.shape[-4] == num_samples
-        assert output.dtype == inpt.dtype
-
-
 @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
 @pytest.mark.parametrize("label_type", (torch.Tensor, int))
 @pytest.mark.parametrize("dataset_return_type", (dict, tuple))
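A note on the removed TestElasticTransform above: it covered the constructor's argument validation and the private _get_params helper, asserting that the sampled displacement field has shape (1, H, W, 2) and stays within ±alpha scaled by 1/W along the x channel and 1/H along the y channel. Below is a minimal standalone sketch of that bound check, assuming the public torchvision.transforms.v2 API and calling the private helper exactly as the deleted test did; it is illustrative only, not code from this commit.

# Illustrative sketch only -- not code from this commit. It mirrors the bound
# check in the removed TestElasticTransform.test__get_params.
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms

alpha, sigma = 2.0, 3.0
transform = transforms.ElasticTransform(alpha, sigma)

h, w = 24, 32
image = tv_tensors.Image(torch.zeros(3, h, w, dtype=torch.uint8))

# _get_params is a private helper; the deleted test called it the same way.
params = transform._get_params([image])
displacement = params["displacement"]

# Shape is (1, H, W, 2); x displacements are bounded by alpha / W and
# y displacements by alpha / H, as asserted by the removed test.
assert displacement.shape == (1, h, w, 2)
assert displacement[0, ..., 0].abs().le(alpha / w).all()
assert displacement[0, ..., 1].abs().le(alpha / h).all()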

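Likewise, the removed TestUniformTemporalSubsample exercised plain tensors, batched tensors, and tv_tensors.Video inputs. The following self-contained sketch restates the same behavioural check as a single parametrized test; it is illustrative only and not the ported test this commit adds elsewhere.

# Illustrative sketch only -- not the test added by this commit. It restates
# the behaviour the removed TestUniformTemporalSubsample verified.
import pytest
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms


@pytest.mark.parametrize(
    "inpt",
    [
        torch.zeros(10, 3, 8, 8),
        torch.zeros(1, 10, 3, 8, 8),
        tv_tensors.Video(torch.zeros(1, 10, 3, 8, 8)),
    ],
)
def test_uniform_temporal_subsample(inpt):
    num_samples = 5
    transform = transforms.UniformTemporalSubsample(num_samples)

    output = transform(inpt)

    # Input type and dtype are preserved; the temporal dimension (dim -4)
    # is reduced to the requested number of frames.
    assert type(output) is type(inpt)
    assert output.shape[-4] == num_samples
    assert output.dtype == inpt.dtype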
test/test_transforms_v2_consistency.py

Lines changed: 0 additions & 130 deletions
@@ -72,34 +72,6 @@ def __init__(
 LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

 CONSISTENCY_CONFIGS = [
-    ConsistencyConfig(
-        v2_transforms.Normalize,
-        legacy_transforms.Normalize,
-        [
-            ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
-        ],
-        supports_pil=False,
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
-    ),
-    ConsistencyConfig(
-        v2_transforms.FiveCrop,
-        legacy_transforms.FiveCrop,
-        [
-            ArgsKwargs(18),
-            ArgsKwargs((18, 13)),
-        ],
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
-    ),
-    ConsistencyConfig(
-        v2_transforms.TenCrop,
-        legacy_transforms.TenCrop,
-        [
-            ArgsKwargs(18),
-            ArgsKwargs((18, 13)),
-            ArgsKwargs(18, vertical_flip=True),
-        ],
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
-    ),
     *[
         ConsistencyConfig(
             v2_transforms.LinearTransformation,
@@ -147,65 +119,6 @@ def __init__(
         # images given that the transform does nothing but call it anyway.
         supports_pil=False,
     ),
-    ConsistencyConfig(
-        v2_transforms.RandomEqualize,
-        legacy_transforms.RandomEqualize,
-        [
-            ArgsKwargs(p=0),
-            ArgsKwargs(p=1),
-        ],
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomInvert,
-        legacy_transforms.RandomInvert,
-        [
-            ArgsKwargs(p=0),
-            ArgsKwargs(p=1),
-        ],
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomPosterize,
-        legacy_transforms.RandomPosterize,
-        [
-            ArgsKwargs(p=0, bits=5),
-            ArgsKwargs(p=1, bits=1),
-            ArgsKwargs(p=1, bits=3),
-        ],
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomSolarize,
-        legacy_transforms.RandomSolarize,
-        [
-            ArgsKwargs(p=0, threshold=0.5),
-            ArgsKwargs(p=1, threshold=0.3),
-            ArgsKwargs(p=1, threshold=0.99),
-        ],
-    ),
-    *[
-        ConsistencyConfig(
-            v2_transforms.RandomAutocontrast,
-            legacy_transforms.RandomAutocontrast,
-            [
-                ArgsKwargs(p=0),
-                ArgsKwargs(p=1),
-            ],
-            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
-            closeness_kwargs=ckw,
-        )
-        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
-    ],
-    ConsistencyConfig(
-        v2_transforms.RandomAdjustSharpness,
-        legacy_transforms.RandomAdjustSharpness,
-        [
-            ArgsKwargs(p=0, sharpness_factor=0.5),
-            ArgsKwargs(p=1, sharpness_factor=0.2),
-            ArgsKwargs(p=1, sharpness_factor=0.99),
-        ],
-        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
-    ),
     ConsistencyConfig(
         v2_transforms.PILToTensor,
         legacy_transforms.PILToTensor,
@@ -230,22 +143,6 @@ def __init__(
         v2_transforms.RandomOrder,
         legacy_transforms.RandomOrder,
     ),
-    ConsistencyConfig(
-        v2_transforms.AugMix,
-        legacy_transforms.AugMix,
-    ),
-    ConsistencyConfig(
-        v2_transforms.AutoAugment,
-        legacy_transforms.AutoAugment,
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandAugment,
-        legacy_transforms.RandAugment,
-    ),
-    ConsistencyConfig(
-        v2_transforms.TrivialAugmentWide,
-        legacy_transforms.TrivialAugmentWide,
-    ),
 ]


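Each ConsistencyConfig entry deleted above encoded a v1-versus-v2 comparison: build both transforms with the same ArgsKwargs and check that their outputs agree on generated images. Here is a hand-written sketch of what, for example, the removed Normalize entry verified; it is illustrative only, and the image shape is an arbitrary choice, not taken from the test suite.

# Illustrative sketch only -- a direct v1/v2 comparison in the spirit of the
# removed ConsistencyConfig(v2_transforms.Normalize, legacy_transforms.Normalize, ...).
import torch
from torchvision import transforms as legacy_transforms
from torchvision.transforms import v2 as v2_transforms

mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
image = torch.rand(3, 16, 16)  # float input, as in the removed config

legacy_out = legacy_transforms.Normalize(mean=mean, std=std)(image)
v2_out = v2_transforms.Normalize(mean=mean, std=std)(image)

# The two implementations are expected to agree.
torch.testing.assert_close(v2_out, legacy_out)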
@@ -753,36 +650,9 @@ def test_common(self, t_ref, t, data_kwargs):
         (legacy_F.pil_to_tensor, {}),
         (legacy_F.convert_image_dtype, {}),
         (legacy_F.to_pil_image, {}),
-        (legacy_F.normalize, {}),
-        (legacy_F.resize, {"interpolation"}),
-        (legacy_F.pad, {"padding", "fill"}),
-        (legacy_F.crop, {}),
-        (legacy_F.center_crop, {}),
-        (legacy_F.resized_crop, {"interpolation"}),
-        (legacy_F.hflip, {}),
-        (legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
-        (legacy_F.vflip, {}),
-        (legacy_F.five_crop, {}),
-        (legacy_F.ten_crop, {}),
-        (legacy_F.adjust_brightness, {}),
-        (legacy_F.adjust_contrast, {}),
-        (legacy_F.adjust_saturation, {}),
-        (legacy_F.adjust_hue, {}),
-        (legacy_F.adjust_gamma, {}),
-        (legacy_F.rotate, {"center", "fill", "interpolation"}),
-        (legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
         (legacy_F.to_grayscale, {}),
         (legacy_F.rgb_to_grayscale, {}),
         (legacy_F.to_tensor, {}),
-        (legacy_F.erase, {}),
-        (legacy_F.gaussian_blur, {}),
-        (legacy_F.invert, {}),
-        (legacy_F.posterize, {}),
-        (legacy_F.solarize, {}),
-        (legacy_F.adjust_sharpness, {}),
-        (legacy_F.autocontrast, {}),
-        (legacy_F.equalize, {}),
-        (legacy_F.elastic_transform, {"fill", "interpolation"}),
     ],
 )
 def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
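The last hunk trims the parametrization for test_dispatcher_signature_consistency; each remaining tuple pairs a legacy functional with the set of parameters (e.g. {"interpolation"} for resize) that, going by the name_only_params argument, are matched by name only rather than by full signature. Below is a rough sketch of the underlying idea, assuming the v2 functional namespace torchvision.transforms.v2.functional and using resize purely as an example; it is not the actual test implementation.

# Illustrative sketch only -- not the real test_dispatcher_signature_consistency.
# The idea: apart from the input argument, every parameter of the legacy
# functional should still be exposed by its v2 counterpart.
import inspect
from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as F


def param_names(fn):
    # Skip the first parameter (the image/tensor input), whose name differs
    # between the legacy and v2 APIs.
    return list(inspect.signature(fn).parameters)[1:]


legacy_names = param_names(legacy_F.resize)
v2_names = param_names(F.resize)

# Every legacy parameter should still exist in the v2 functional.
assert set(legacy_names) <= set(v2_names)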
