From 749f8170bbab992adf21b9a1fc57a5b56c565182 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 1 Sep 2023 11:52:05 +0200 Subject: [PATCH 01/14] port AA tests --- test/test_transforms_v2_consistency.py | 97 ---------------- test/test_transforms_v2_refactored.py | 149 +++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 97 deletions(-) diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index 1f47eb2117f..b09dc30ef8d 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -748,30 +748,6 @@ def test_randaug(self, inpt, interpolation, mocker): assert_close(expected_output, output, atol=1, rtol=0.1) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - ], - ) - @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) - def test_randaug_jit(self, interpolation, fill): - inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) - t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill) - t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill) - - tt_ref = torch.jit.script(t_ref) - tt = torch.jit.script(t) - - torch.manual_seed(12) - expected_output = tt_ref(inpt) - - torch.manual_seed(12) - scripted_output = tt(inpt) - - assert_equal(scripted_output, expected_output) - @pytest.mark.parametrize( "inpt", [ @@ -824,30 +800,6 @@ def test_trivial_aug(self, inpt, interpolation, mocker): assert_close(expected_output, output, atol=1, rtol=0.1) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - ], - ) - @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) - def test_trivial_aug_jit(self, interpolation, fill): - inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) - t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill) - t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill) - - tt_ref = torch.jit.script(t_ref) - tt = torch.jit.script(t) - - torch.manual_seed(12) - expected_output = tt_ref(inpt) - - torch.manual_seed(12) - scripted_output = tt(inpt) - - assert_equal(scripted_output, expected_output) - @pytest.mark.parametrize( "inpt", [ @@ -901,31 +853,6 @@ def test_augmix(self, inpt, interpolation, mocker): assert_equal(expected_output, output) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - ], - ) - @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) - def test_augmix_jit(self, interpolation, fill): - inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) - - t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill) - t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill) - - tt_ref = torch.jit.script(t_ref) - tt = torch.jit.script(t) - - torch.manual_seed(12) - expected_output = tt_ref(inpt) - - torch.manual_seed(12) - scripted_output = tt(inpt) - - assert_equal(scripted_output, expected_output) - @pytest.mark.parametrize( "inpt", [ @@ -955,30 +882,6 @@ def test_aa(self, inpt, interpolation): assert_equal(expected_output, output) - @pytest.mark.parametrize( - "interpolation", - [ - 
v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - ], - ) - def test_aa_jit(self, interpolation): - inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) - aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet") - t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation) - t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation) - - tt_ref = torch.jit.script(t_ref) - tt = torch.jit.script(t) - - torch.manual_seed(12) - expected_output = tt_ref(inpt) - - torch.manual_seed(12) - scripted_output = tt(inpt) - - assert_equal(scripted_output, expected_output) - def import_transforms_from_references(reference): HERE = Path(__file__).parent diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index f8a47c7cf39..02d2a60d4a8 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2901,3 +2901,152 @@ def test__get_params(self, sigma): else: assert sigma[0] <= params["sigma"][0] <= sigma[1] assert sigma[0] <= params["sigma"][1] <= sigma[1] + + +class TestAutoAugmentTransforms: + def check_transform(self, transform, input): + # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 and v2 + # outputs without complicated and brittle mocking and monkeypatching. Thus, we only run smoke tests for the + # eager and scripted v1 transform here in addition to the non-v1 compatibility tests of check_transform + check_transform(transform, input, check_v1_compatibility=False) + + if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image): + return + + v1_transform = transform._v1_transform_cls(**transform._extract_params_for_v1_transform()) + + v1_transform(input) + + if isinstance(input, PIL.Image.Image): + return + + _script(v1_transform)(input) + + @param_value_parametrization( + policy=list(transforms.AutoAugmentPolicy), + interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR], + fill=EXHAUSTIVE_TYPE_FILLS, + ) + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_transform_auto_augment(self, param, value, make_input, dtype, device): + if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): + pytest.skip( + "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " + "will degenerate to that anyway." 
+ ) + + self.check_transform(transforms.AutoAugment(**{param: value}), make_input()) + + def _reference_auto_augment_image_shear(self, image, *, transform_id, magnitude, interpolation): + def shear(pil_image): + if transform_id == "ShearX": + matrix = (1, magnitude, 0, 0, 1, 0) + elif transform_id == "ShearY": + matrix = (1, 0, 0, magnitude, 1, 0) + return pil_image.transform( + pil_image.size, PIL.Image.AFFINE, matrix, resample=pil_modes_mapping[interpolation] + ) + + if isinstance(image, PIL.Image.Image): + return shear(image) + else: + return F.to_image(shear(F.to_pil_image(image))) + + @pytest.mark.parametrize("transform_id", ["ShearX", "ShearY"]) + @pytest.mark.parametrize("magnitude", [0.0, 0.3, -0.2]) + @pytest.mark.parametrize( + "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR] + ) + @pytest.mark.parametrize("input_type", ["Tensor", "PIL"]) + def test_auto_augment_image_shear_correctness(self, transform_id, magnitude, interpolation, input_type): + # We check that torchvision's implementation of shear is equivalent + # to official CIFAR10 autoaugment implementation: + # https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L290 + + image = make_image(dtype=torch.uint8, device="cpu") + if input_type == "PIL": + image = F.to_pil_image(image) + + actual = transforms.AutoAugment()._apply_image_or_video_transform( + image, + transform_id=transform_id, + magnitude=magnitude, + interpolation=interpolation, + fill={type(image): 0}, + ) + expected = self._reference_auto_augment_image_shear( + image, transform_id=transform_id, magnitude=magnitude, interpolation=interpolation + ) + + if input_type == "PIL": + assert_equal(F.to_image(actual), F.to_image(expected)) + else: + mae = (actual.float() - expected.float()).abs().mean() + assert mae < 12 if interpolation is transforms.InterpolationMode.NEAREST else 5 + + @param_value_parametrization( + num_ops=[1, 2, 3], + magnitude=[1, 10, 30], + interpolation=[ + transforms.InterpolationMode.NEAREST, + transforms.InterpolationMode.BILINEAR, + ], + fill=EXHAUSTIVE_TYPE_FILLS, + ) + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_transform_rand_augment(self, param, value, make_input, dtype, device): + if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): + pytest.skip( + "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " + "will degenerate to that anyway." + ) + + self.check_transform(transforms.RandAugment(**{param: value}), make_input()) + + @param_value_parametrization( + num_magnitude_bins=[1, 30, 50], + interpolation=[ + transforms.InterpolationMode.NEAREST, + transforms.InterpolationMode.BILINEAR, + ], + fill=EXHAUSTIVE_TYPE_FILLS, + ) + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_transform_trivial_augment_wide(self, param, value, make_input, dtype, device): + if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): + pytest.skip( + "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " + "will degenerate to that anyway." 
+ ) + + self.check_transform(transforms.TrivialAugmentWide(**{param: value}), make_input()) + + @param_value_parametrization( + severity=[1, 5, 10], + mixture_width=[1, 2, 3], + chain_depth=[-1, 1, 3], + alpha=[0.5, 1.0, 2.0], + all_ops=[True, False], + interpolation=[ + transforms.InterpolationMode.NEAREST, + transforms.InterpolationMode.BILINEAR, + ], + fill=EXHAUSTIVE_TYPE_FILLS, + ) + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_transform_aug_mix(self, param, value, make_input, dtype, device): + if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): + pytest.skip( + "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " + "will degenerate to that anyway." + ) + + self.check_transform(transforms.AugMix(**{param: value}), make_input()) From 93cd592d5f3c5ed8ec8115749a590ad43dea8b17 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 1 Sep 2023 14:54:36 +0200 Subject: [PATCH 02/14] random seed based on parametrization --- test/test_transforms_v2_refactored.py | 31 ++++++++++++++++----------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 02d2a60d4a8..90969c6583d 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2905,22 +2905,29 @@ def test__get_params(self, sigma): class TestAutoAugmentTransforms: def check_transform(self, transform, input): - # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 and v2 - # outputs without complicated and brittle mocking and monkeypatching. Thus, we only run smoke tests for the - # eager and scripted v1 transform here in addition to the non-v1 compatibility tests of check_transform - check_transform(transform, input, check_v1_compatibility=False) + v1_params = transform._extract_params_for_v1_transform() + v1_transform = transform._v1_transform_cls(**transform._extract_params_for_v1_transform()) - if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image): - return + with freeze_rng_state(): + # By default every test starts from the same random seed. This leads to minimal coverage of the sampled ops + # inside the transform. To avoid calling the transform multiple times, we build a reproducible random seed + # from the parametrization. Thus, we get better coverage without increasing the runtime. + torch.manual_seed(hash(pickle.dumps(v1_params))) - v1_transform = transform._v1_transform_cls(**transform._extract_params_for_v1_transform()) + # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 + # and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks + # here and run smoke tests below. 
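A standalone sketch, outside the patch, of the seeding idea introduced in this commit: `torch.manual_seed(hash(pickle.dumps(v1_params)))` gives every parametrized test case its own reproducible RNG stream, so different cases exercise different sampled ops without re-running the transform. The parameter dicts below are hypothetical stand-ins for whatever `_extract_params_for_v1_transform()` returns.

```python
import pickle

import torch


def seed_from_params(params):
    # pickle yields a deterministic byte string for the parametrization within
    # a run; hash() folds it into an integer that torch.manual_seed() accepts.
    return hash(pickle.dumps(params))


# Hypothetical parametrizations standing in for _extract_params_for_v1_transform().
torch.manual_seed(seed_from_params({"num_ops": 2, "magnitude": 9}))
a = torch.randint(0, 10, (4,))

torch.manual_seed(seed_from_params({"num_ops": 2, "magnitude": 9}))
b = torch.randint(0, 10, (4,))

torch.manual_seed(seed_from_params({"num_ops": 3, "magnitude": 1}))
c = torch.randint(0, 10, (4,))

assert torch.equal(a, b)  # same parametrization, identical sampling
print(a, c)  # a different parametrization almost surely seeds a different stream
```

One caveat worth keeping in mind: `hash()` of bytes is salted per interpreter process unless `PYTHONHASHSEED` is pinned, so the derived seed is stable within a single test session but not necessarily across sessions, which is enough for spreading coverage across the parametrization.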
+ check_transform(transform, input, check_v1_compatibility=False) - v1_transform(input) + if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image): + return - if isinstance(input, PIL.Image.Image): - return + v1_transform(input) + + if isinstance(input, PIL.Image.Image): + return - _script(v1_transform)(input) + _script(v1_transform)(input) @param_value_parametrization( policy=list(transforms.AutoAugmentPolicy), @@ -3008,7 +3015,7 @@ def test_transform_rand_augment(self, param, value, make_input, dtype, device): self.check_transform(transforms.RandAugment(**{param: value}), make_input()) @param_value_parametrization( - num_magnitude_bins=[1, 30, 50], + num_magnitude_bins=[2, 30, 50], interpolation=[ transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR, From fe1d477316141831ace7080755350b38599e3af9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 1 Sep 2023 14:56:55 +0200 Subject: [PATCH 03/14] adapt fill --- test/test_transforms_v2_refactored.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 90969c6583d..0f8d171737d 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2944,6 +2944,9 @@ def test_transform_auto_augment(self, param, value, make_input, dtype, device): "will degenerate to that anyway." ) + if param == "fill": + value = adapt_fill(value, dtype=dtype) + self.check_transform(transforms.AutoAugment(**{param: value}), make_input()) def _reference_auto_augment_image_shear(self, image, *, transform_id, magnitude, interpolation): @@ -3012,6 +3015,9 @@ def test_transform_rand_augment(self, param, value, make_input, dtype, device): "will degenerate to that anyway." ) + if param == "fill": + value = adapt_fill(value, dtype=dtype) + self.check_transform(transforms.RandAugment(**{param: value}), make_input()) @param_value_parametrization( @@ -3032,6 +3038,9 @@ def test_transform_trivial_augment_wide(self, param, value, make_input, dtype, d "will degenerate to that anyway." ) + if param == "fill": + value = adapt_fill(value, dtype=dtype) + self.check_transform(transforms.TrivialAugmentWide(**{param: value}), make_input()) @param_value_parametrization( @@ -3056,4 +3065,7 @@ def test_transform_aug_mix(self, param, value, make_input, dtype, device): "will degenerate to that anyway." ) + if param == "fill": + value = adapt_fill(value, dtype=dtype) + self.check_transform(transforms.AugMix(**{param: value}), make_input()) From 1d654a0f8365c9b99a7e21272ec2910cb1816e9a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 1 Sep 2023 16:58:27 +0200 Subject: [PATCH 04/14] expand correctness tests --- test/test_transforms_v2_refactored.py | 68 ++++++++++++++++----------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 0f8d171737d..7418af42514 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2909,9 +2909,9 @@ def check_transform(self, transform, input): v1_transform = transform._v1_transform_cls(**transform._extract_params_for_v1_transform()) with freeze_rng_state(): - # By default every test starts from the same random seed. This leads to minimal coverage of the sampled ops - # inside the transform. To avoid calling the transform multiple times, we build a reproducible random seed - # from the parametrization. 
Thus, we get better coverage without increasing the runtime. + # By default every test starts from the same random seed. This leads to minimal coverage of the sampling + # that happens inside forward(). To avoid calling the transform multiple times to achieve this coverage, we + # build a reproducible random seed from the parametrization. torch.manual_seed(hash(pickle.dumps(v1_params))) # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 @@ -2949,52 +2949,64 @@ def test_transform_auto_augment(self, param, value, make_input, dtype, device): self.check_transform(transforms.AutoAugment(**{param: value}), make_input()) - def _reference_auto_augment_image_shear(self, image, *, transform_id, magnitude, interpolation): - def shear(pil_image): - if transform_id == "ShearX": - matrix = (1, magnitude, 0, 0, 1, 0) - elif transform_id == "ShearY": - matrix = (1, 0, 0, magnitude, 1, 0) - return pil_image.transform( - pil_image.size, PIL.Image.AFFINE, matrix, resample=pil_modes_mapping[interpolation] - ) + def _reference_shear_translate(self, image, *, transform_id, magnitude, interpolation, fill): + if isinstance(image, PIL.Image.Image): + input = image + else: + input = F.to_pil_image(image) + + matrix = { + "ShearX": (1, magnitude, 0, 0, 1, 0), + "ShearY": (1, 0, 0, magnitude, 1, 0), + "TranslateX": (1, 0, -int(magnitude), 0, 1, 0), + "TranslateY": (1, 0, 0, 0, 1, -int(magnitude)), + }[transform_id] + + output = input.transform( + input.size, PIL.Image.AFFINE, matrix, resample=pil_modes_mapping[interpolation], fill=fill + ) if isinstance(image, PIL.Image.Image): - return shear(image) + return output else: - return F.to_image(shear(F.to_pil_image(image))) + return F.to_image(output) - @pytest.mark.parametrize("transform_id", ["ShearX", "ShearY"]) - @pytest.mark.parametrize("magnitude", [0.0, 0.3, -0.2]) + @pytest.mark.parametrize("transform_id", ["ShearX", "ShearY", "TranslateX", "TranslateY"]) + @pytest.mark.parametrize("magnitude", [0.3, -0.2, 0.0]) @pytest.mark.parametrize( "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR] ) - @pytest.mark.parametrize("input_type", ["Tensor", "PIL"]) - def test_auto_augment_image_shear_correctness(self, transform_id, magnitude, interpolation, input_type): - # We check that torchvision's implementation of shear is equivalent - # to official CIFAR10 autoaugment implementation: - # https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L290 + @pytest.mark.parametrize("fill", CORRECTNESS_FILLS) + @pytest.mark.parametrize("input_type", ["PIL"]) # "Tensor", + def test_correctness_shear_translate(self, transform_id, magnitude, interpolation, fill, input_type): + # ShearX/Y and TranslateX/Y are the only ops that are native to the AA transforms. They are modeled after the + # reference implementation: + # https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L362 + # All other ops are checked in their respective dedicated tests. 
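The `_reference_shear_translate` helper added above relies on PIL's affine convention, which is easy to get backwards. The standalone sketch below (not part of the patch) shows that `Image.transform(size, Image.AFFINE, (a, b, c, d, e, f))` maps each output pixel `(x, y)` to the input position `(a*x + b*y + c, d*x + e*y + f)`, an inverse mapping, which is why translating content to the right by `magnitude` pixels uses `c = -int(magnitude)` in the matrices above; likewise the `b` entry of the ShearX matrix couples the sampled x coordinate to y, producing the shear.

```python
import numpy as np
import PIL.Image

# 5x5 grayscale test image whose pixel values encode their own position.
image = PIL.Image.fromarray(np.arange(25, dtype=np.uint8).reshape(5, 5))

# TranslateX by 2 pixels: output pixel x samples input pixel x - 2, hence c = -2.
translated = image.transform(image.size, PIL.Image.AFFINE, (1, 0, -2, 0, 1, 0))
# ShearX with magnitude 0.3, the same matrix layout as in the reference above.
sheared = image.transform(image.size, PIL.Image.AFFINE, (1, 0.3, 0, 0, 1, 0))

original = np.asarray(image)
shifted = np.asarray(translated)

# Columns 2..4 of the output now hold columns 0..2 of the input, i.e. the content
# moved right; columns 0..1 fall outside the input and get the default fill of 0.
assert (shifted[:, 2:] == original[:, :3]).all()
print(np.asarray(sheared))  # each output row y samples the input shifted by 0.3*y along x
```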
image = make_image(dtype=torch.uint8, device="cpu") if input_type == "PIL": image = F.to_pil_image(image) + if "Translate" in transform_id: + # For TranslateX/Y magnitude is a value in pixels + magnitude *= min(F.get_size(image)) + actual = transforms.AutoAugment()._apply_image_or_video_transform( image, transform_id=transform_id, magnitude=magnitude, interpolation=interpolation, - fill={type(image): 0}, + fill={type(image): fill}, ) - expected = self._reference_auto_augment_image_shear( - image, transform_id=transform_id, magnitude=magnitude, interpolation=interpolation + expected = self._reference_shear_translate( + image, transform_id=transform_id, magnitude=magnitude, interpolation=interpolation, fill=fill ) if input_type == "PIL": - assert_equal(F.to_image(actual), F.to_image(expected)) - else: - mae = (actual.float() - expected.float()).abs().mean() - assert mae < 12 if interpolation is transforms.InterpolationMode.NEAREST else 5 + actual, expected = F.to_image(actual), F.to_image(expected) + + assert_close(actual, expected, rtol=0, atol=1) @param_value_parametrization( num_ops=[1, 2, 3], From 6726e3ee508cba8ef16459ef784bd4e3514a4e58 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 1 Sep 2023 17:08:46 +0200 Subject: [PATCH 05/14] fix --- test/test_transforms_v2_refactored.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 7418af42514..c13a2fc7039 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2910,8 +2910,8 @@ def check_transform(self, transform, input): with freeze_rng_state(): # By default every test starts from the same random seed. This leads to minimal coverage of the sampling - # that happens inside forward(). To avoid calling the transform multiple times to achieve this coverage, we - # build a reproducible random seed from the parametrization. + # that happens inside forward(). To avoid calling the transform multiple times to achieve higher coverage, + # we build a reproducible random seed from the parametrization. torch.manual_seed(hash(pickle.dumps(v1_params))) # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 @@ -2977,7 +2977,7 @@ def _reference_shear_translate(self, image, *, transform_id, magnitude, interpol "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR] ) @pytest.mark.parametrize("fill", CORRECTNESS_FILLS) - @pytest.mark.parametrize("input_type", ["PIL"]) # "Tensor", + @pytest.mark.parametrize("input_type", ["Tensor", "PIL"]) def test_correctness_shear_translate(self, transform_id, magnitude, interpolation, fill, input_type): # ShearX/Y and TranslateX/Y are the only ops that are native to the AA transforms. 
They are modeled after the # reference implementation: @@ -3006,7 +3006,11 @@ def test_correctness_shear_translate(self, transform_id, magnitude, interpolatio if input_type == "PIL": actual, expected = F.to_image(actual), F.to_image(expected) - assert_close(actual, expected, rtol=0, atol=1) + if "Shear" in transform_id and input_type == "Tensor": + mae = (actual.float() - expected.float()).abs().mean() + assert mae < (12 if interpolation is transforms.InterpolationMode.NEAREST else 5) + else: + assert_close(actual, expected, rtol=0, atol=1) @param_value_parametrization( num_ops=[1, 2, 3], From 95cd2b32d88c7c1cc17dbcdd220f1dfa6e8790f2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Sep 2023 15:04:48 +0200 Subject: [PATCH 06/14] reorder --- test/test_transforms_v2_refactored.py | 90 +++++++++++++-------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c13a2fc7039..38bbb668356 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2904,51 +2904,6 @@ def test__get_params(self, sigma): class TestAutoAugmentTransforms: - def check_transform(self, transform, input): - v1_params = transform._extract_params_for_v1_transform() - v1_transform = transform._v1_transform_cls(**transform._extract_params_for_v1_transform()) - - with freeze_rng_state(): - # By default every test starts from the same random seed. This leads to minimal coverage of the sampling - # that happens inside forward(). To avoid calling the transform multiple times to achieve higher coverage, - # we build a reproducible random seed from the parametrization. - torch.manual_seed(hash(pickle.dumps(v1_params))) - - # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 - # and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks - # here and run smoke tests below. - check_transform(transform, input, check_v1_compatibility=False) - - if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image): - return - - v1_transform(input) - - if isinstance(input, PIL.Image.Image): - return - - _script(v1_transform)(input) - - @param_value_parametrization( - policy=list(transforms.AutoAugmentPolicy), - interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR], - fill=EXHAUSTIVE_TYPE_FILLS, - ) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) - @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) - @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform_auto_augment(self, param, value, make_input, dtype, device): - if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): - pytest.skip( - "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " - "will degenerate to that anyway." 
- ) - - if param == "fill": - value = adapt_fill(value, dtype=dtype) - - self.check_transform(transforms.AutoAugment(**{param: value}), make_input()) - def _reference_shear_translate(self, image, *, transform_id, magnitude, interpolation, fill): if isinstance(image, PIL.Image.Image): input = image @@ -3012,6 +2967,51 @@ def test_correctness_shear_translate(self, transform_id, magnitude, interpolatio else: assert_close(actual, expected, rtol=0, atol=1) + def check_transform(self, transform, input): + v1_params = transform._extract_params_for_v1_transform() + v1_transform = transform._v1_transform_cls(**transform._extract_params_for_v1_transform()) + + with freeze_rng_state(): + # By default every test starts from the same random seed. This leads to minimal coverage of the sampling + # that happens inside forward(). To avoid calling the transform multiple times to achieve higher coverage, + # we build a reproducible random seed from the parametrization. + torch.manual_seed(hash(pickle.dumps(v1_params))) + + # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 + # and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks + # here and run smoke tests below. + check_transform(transform, input, check_v1_compatibility=False) + + if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image): + return + + v1_transform(input) + + if isinstance(input, PIL.Image.Image): + return + + _script(v1_transform)(input) + + @param_value_parametrization( + policy=list(transforms.AutoAugmentPolicy), + interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR], + fill=EXHAUSTIVE_TYPE_FILLS, + ) + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_transform_auto_augment(self, param, value, make_input, dtype, device): + if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): + pytest.skip( + "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " + "will degenerate to that anyway." 
+ ) + + if param == "fill": + value = adapt_fill(value, dtype=dtype) + + self.check_transform(transforms.AutoAugment(**{param: value}), make_input()) + @param_value_parametrization( num_ops=[1, 2, 3], magnitude=[1, 10, 30], From 2e3ba76565d38d71a8ef8bf20937a88c2dd0bd6e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Sep 2023 15:15:11 +0200 Subject: [PATCH 07/14] refactor forward tests --- test/test_transforms_v2_refactored.py | 124 ++++---------------------- 1 file changed, 18 insertions(+), 106 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 38bbb668356..87c54008dfb 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -232,7 +232,7 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol): """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static ``get_params`` method that is the v1 equivalent, the output is close to v1, is scriptable, and the scripted version can be called without error.""" - if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image): + if not (type(input) is torch.Tensor or isinstance(input, PIL.Image.Image)): return v1_transform_cls = transform._v1_transform_cls @@ -2967,121 +2967,33 @@ def test_correctness_shear_translate(self, transform_id, magnitude, interpolatio else: assert_close(actual, expected, rtol=0, atol=1) - def check_transform(self, transform, input): - v1_params = transform._extract_params_for_v1_transform() - v1_transform = transform._v1_transform_cls(**transform._extract_params_for_v1_transform()) - - with freeze_rng_state(): - # By default every test starts from the same random seed. This leads to minimal coverage of the sampling - # that happens inside forward(). To avoid calling the transform multiple times to achieve higher coverage, - # we build a reproducible random seed from the parametrization. - torch.manual_seed(hash(pickle.dumps(v1_params))) - - # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 - # and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks - # here and run smoke tests below. - check_transform(transform, input, check_v1_compatibility=False) - - if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image): - return - - v1_transform(input) - - if isinstance(input, PIL.Image.Image): - return - - _script(v1_transform)(input) - - @param_value_parametrization( - policy=list(transforms.AutoAugmentPolicy), - interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR], - fill=EXHAUSTIVE_TYPE_FILLS, - ) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) - @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) - @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform_auto_augment(self, param, value, make_input, dtype, device): - if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): - pytest.skip( - "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " - "will degenerate to that anyway." 
- ) - - if param == "fill": - value = adapt_fill(value, dtype=dtype) - - self.check_transform(transforms.AutoAugment(**{param: value}), make_input()) - - @param_value_parametrization( - num_ops=[1, 2, 3], - magnitude=[1, 10, 30], - interpolation=[ - transforms.InterpolationMode.NEAREST, - transforms.InterpolationMode.BILINEAR, - ], - fill=EXHAUSTIVE_TYPE_FILLS, - ) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) - @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) - @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform_rand_augment(self, param, value, make_input, dtype, device): - if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): - pytest.skip( - "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " - "will degenerate to that anyway." - ) - - if param == "fill": - value = adapt_fill(value, dtype=dtype) - - self.check_transform(transforms.RandAugment(**{param: value}), make_input()) - - @param_value_parametrization( - num_magnitude_bins=[2, 30, 50], - interpolation=[ - transforms.InterpolationMode.NEAREST, - transforms.InterpolationMode.BILINEAR, - ], - fill=EXHAUSTIVE_TYPE_FILLS, + @pytest.mark.parametrize( + "transform", + [transforms.AutoAugment(), transforms.RandAugment(), transforms.TrivialAugmentWide(), transforms.AugMix()], ) @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform_trivial_augment_wide(self, param, value, make_input, dtype, device): + def test_forward(self, transform, make_input, dtype, device): if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): pytest.skip( "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " "will degenerate to that anyway." ) + input = make_input(dtype=dtype, device=device) - if param == "fill": - value = adapt_fill(value, dtype=dtype) - - self.check_transform(transforms.TrivialAugmentWide(**{param: value}), make_input()) - - @param_value_parametrization( - severity=[1, 5, 10], - mixture_width=[1, 2, 3], - chain_depth=[-1, 1, 3], - alpha=[0.5, 1.0, 2.0], - all_ops=[True, False], - interpolation=[ - transforms.InterpolationMode.NEAREST, - transforms.InterpolationMode.BILINEAR, - ], - fill=EXHAUSTIVE_TYPE_FILLS, - ) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) - @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) - @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform_aug_mix(self, param, value, make_input, dtype, device): - if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): - pytest.skip( - "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " - "will degenerate to that anyway." + with freeze_rng_state(): + # By default every test starts from the same random seed. This leads to minimal coverage of the sampling + # that happens inside forward(). To avoid calling the transform multiple times to achieve higher coverage, + # we build a reproducible random seed from the input type, dtype, and device. 
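As an aside (not part of the diff): the seeded block in the new `test_forward` runs inside `freeze_rng_state()`, so the `torch.manual_seed(...)` call that continues just below cannot leak into unrelated tests. A minimal stand-in for such a guard, assuming only `torch` and ignoring the CUDA RNG state handling of the real helper, could look like this:

```python
import contextlib

import torch


@contextlib.contextmanager
def frozen_rng_state():
    # Snapshot the global CPU RNG state on entry and restore it on exit so that
    # per-test seeding cannot affect tests that run afterwards.
    state = torch.get_rng_state()
    try:
        yield
    finally:
        torch.set_rng_state(state)


before = torch.rand(1)
with frozen_rng_state():
    torch.manual_seed(0)   # stand-in for the parametrization-derived seed
    _ = torch.rand(3)      # stand-in for the sampling done inside forward()
after = torch.rand(1)      # continues the original stream as if nothing happened
```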
+ torch.manual_seed( + hash((type(input), getattr(input, "dtype", torch.uint8), getattr(input, "device", "cpu"))) ) - if param == "fill": - value = adapt_fill(value, dtype=dtype) + # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 + # and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks + # here and only check if we can script the v2 transform and subsequently call the result. + check_transform(transform, input, check_v1_compatibility=False) - self.check_transform(transforms.AugMix(**{param: value}), make_input()) + if type(input) is torch.Tensor and dtype is torch.uint8: + _script(transform)(input) From 2707869034008b8e0278cb28716d1a99542a6637 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Sep 2023 15:15:49 +0200 Subject: [PATCH 08/14] remove all old tests --- test/test_transforms_v2_consistency.py | 178 ------------------------- 1 file changed, 178 deletions(-) diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index b09dc30ef8d..9badd8dbe1b 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -705,184 +705,6 @@ def test_to_tensor(self): assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy)) -class TestAATransforms: - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - PIL.Image.NEAREST, - ], - ) - def test_randaug(self, inpt, interpolation, mocker): - t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1) - t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1) - - le = len(t._AUGMENTATION_SPACE) - keys = list(t._AUGMENTATION_SPACE.keys()) - randint_values = [] - for i in range(le): - # Stable API, op_index random call - randint_values.append(i) - # Stable API, if signed there is another random call - if t._AUGMENTATION_SPACE[keys[i]][1]: - randint_values.append(0) - # New API, _get_random_item - randint_values.append(i) - randint_values = iter(randint_values) - - mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values))) - mocker.patch("torch.rand", return_value=1.0) - - for i in range(le): - expected_output = t_ref(inpt) - output = t(inpt) - - assert_close(expected_output, output, atol=1, rtol=0.1) - - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - PIL.Image.NEAREST, - ], - ) - def test_trivial_aug(self, inpt, interpolation, mocker): - t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation) - t = v2_transforms.TrivialAugmentWide(interpolation=interpolation) - - le = len(t._AUGMENTATION_SPACE) - keys = list(t._AUGMENTATION_SPACE.keys()) - randint_values = [] - for i in range(le): - # Stable API, op_index random call - randint_values.append(i) - key = keys[i] - # Stable API, random magnitude - aug_op = 
t._AUGMENTATION_SPACE[key] - magnitudes = aug_op[0](2, 0, 0) - if magnitudes is not None: - randint_values.append(5) - # Stable API, if signed there is another random call - if aug_op[1]: - randint_values.append(0) - # New API, _get_random_item - randint_values.append(i) - # New API, random magnitude - if magnitudes is not None: - randint_values.append(5) - - randint_values = iter(randint_values) - - mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values))) - mocker.patch("torch.rand", return_value=1.0) - - for _ in range(le): - expected_output = t_ref(inpt) - output = t(inpt) - - assert_close(expected_output, output, atol=1, rtol=0.1) - - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - PIL.Image.NEAREST, - ], - ) - def test_augmix(self, inpt, interpolation, mocker): - t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) - t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1) - t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) - t._sample_dirichlet = lambda t: t.softmax(dim=-1) - - le = len(t._AUGMENTATION_SPACE) - keys = list(t._AUGMENTATION_SPACE.keys()) - randint_values = [] - for i in range(le): - # Stable API, op_index random call - randint_values.append(i) - key = keys[i] - # Stable API, random magnitude - aug_op = t._AUGMENTATION_SPACE[key] - magnitudes = aug_op[0](2, 0, 0) - if magnitudes is not None: - randint_values.append(5) - # Stable API, if signed there is another random call - if aug_op[1]: - randint_values.append(0) - # New API, _get_random_item - randint_values.append(i) - # New API, random magnitude - if magnitudes is not None: - randint_values.append(5) - - randint_values = iter(randint_values) - - mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values))) - mocker.patch("torch.rand", return_value=1.0) - - expected_output = t_ref(inpt) - output = t(inpt) - - assert_equal(expected_output, output) - - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - PIL.Image.NEAREST, - ], - ) - def test_aa(self, inpt, interpolation): - aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet") - t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation) - t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation) - - torch.manual_seed(12) - expected_output = t_ref(inpt) - - torch.manual_seed(12) - output = t(inpt) - - assert_equal(expected_output, output) - - def import_transforms_from_references(reference): HERE = Path(__file__).parent PROJECT_ROOT = HERE.parent From 6608a7e05d86f59e4550b429d5fb580b0fa87f47 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Sep 2023 15:19:24 +0200 Subject: [PATCH 09/14] add error tests --- test/test_transforms_v2_refactored.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git 
a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 87c54008dfb..54858be2f37 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2997,3 +2997,12 @@ def test_forward(self, transform, make_input, dtype, device): if type(input) is torch.Tensor and dtype is torch.uint8: _script(transform)(input) + + def test_auto_augment_policy_error(self): + with pytest.raises(ValueError, match="provided policy"): + transforms.AutoAugment(policy=None) + + @pytest.mark.parametrize("severity", [0, 11]) + def test_aug_mix_severity_error(self, severity): + with pytest.raises(ValueError, match="severity must be between"): + transforms.AugMix(severity=severity) From d26eb6773f66451db5d900cf4c30da914a71dcac Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Sep 2023 15:24:22 +0200 Subject: [PATCH 10/14] fix check_transform --- test/test_transforms_v2_refactored.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 45086b3b69f..85153cbfb40 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -250,6 +250,9 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol): with freeze_rng_state(): output_v1 = v1_transform(input) + if all(isinstance(o, PIL.Image.Image) for o in [output_v2, output_v1]): + output_v2, output_v1 = [F.to_image(o) for o in [output_v2, output_v1]] + assert_close(output_v2, output_v1, rtol=rtol, atol=atol) if isinstance(input, PIL.Image.Image): @@ -2772,7 +2775,10 @@ def test_functional_signature(self, kernel, input_type): ) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform(self, make_input, device): - check_transform(transforms.RandomErasing(p=1), make_input(device=device)) + input = make_input(device=device) + check_transform( + transforms.RandomErasing(p=1), input, check_v1_compatibility=not isinstance(input, PIL.Image.Image) + ) def _reference_erase_image(self, image, *, i, j, h, w, v): mask = torch.zeros_like(image, dtype=torch.bool) From 027bb7a33143cbd8d3e20baecb82597daa09a5ea Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Sep 2023 15:36:37 +0200 Subject: [PATCH 11/14] Update test/test_transforms_v2_refactored.py Co-authored-by: Nicolas Hug --- test/test_transforms_v2_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 85153cbfb40..6468ef25a57 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2990,7 +2990,7 @@ def test_forward(self, transform, make_input, dtype, device): # that happens inside forward(). To avoid calling the transform multiple times to achieve higher coverage, # we build a reproducible random seed from the input type, dtype, and device. torch.manual_seed( - hash((type(input), getattr(input, "dtype", torch.uint8), getattr(input, "device", "cpu"))) + hash((make_input, dtype, device)) ) # For v2, we changed the random sampling of the AA transforms. 
This makes it impossible to compare the v1 From 574cdc0ee224d40ebc50267d41a2fb772387f6d5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Sep 2023 15:37:04 +0200 Subject: [PATCH 12/14] lint --- test/test_transforms_v2_refactored.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 6468ef25a57..557571525f4 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2989,9 +2989,7 @@ def test_forward(self, transform, make_input, dtype, device): # By default every test starts from the same random seed. This leads to minimal coverage of the sampling # that happens inside forward(). To avoid calling the transform multiple times to achieve higher coverage, # we build a reproducible random seed from the input type, dtype, and device. - torch.manual_seed( - hash((make_input, dtype, device)) - ) + torch.manual_seed(hash((make_input, dtype, device))) # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1 # and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks From 26984bfb7395394775c6bcf90eb0b497688d45df Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Sep 2023 16:25:49 +0200 Subject: [PATCH 13/14] fix test name --- test/test_transforms_v2_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 557571525f4..45a21449e71 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2977,7 +2977,7 @@ def test_correctness_shear_translate(self, transform_id, magnitude, interpolatio @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_forward(self, transform, make_input, dtype, device): + def test_transform_smoke(self, transform, make_input, dtype, device): if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): pytest.skip( "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " From a4de6f081542584a16c2791e67d7e68c67afc664 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Sep 2023 16:33:15 +0200 Subject: [PATCH 14/14] address comments --- test/test_transforms_v2_refactored.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 45a21449e71..e978f57f257 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -250,10 +250,7 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol): with freeze_rng_state(): output_v1 = v1_transform(input) - if all(isinstance(o, PIL.Image.Image) for o in [output_v2, output_v1]): - output_v2, output_v1 = [F.to_image(o) for o in [output_v2, output_v1]] - - assert_close(output_v2, output_v1, rtol=rtol, atol=atol) + assert_close(F.to_image(output_v2), F.to_image(output_v1), rtol=rtol, atol=atol) if isinstance(input, PIL.Image.Image): return @@ -2907,6 +2904,11 @@ def test__get_params(self, sigma): class TestAutoAugmentTransforms: + # These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling. 
+ # It's typically very hard to test the effect on some parameters without heavy mocking logic. + # This class adds correctness tests for the kernels that are specific to those transforms. The rest of kernels, e.g. + # rotate, are tested in their respective classes. The rest of the tests here are mostly smoke tests. + def _reference_shear_translate(self, image, *, transform_id, magnitude, interpolation, fill): if isinstance(image, PIL.Image.Image): input = image