From faea1d15b61004ae0ef5642677bccfa86821d4f4 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 17 Feb 2023 09:22:08 +0100 Subject: [PATCH 1/2] Improved consistency tests for AA with all policies --- test/test_prototype_transforms_consistency.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 9b3482f3f0a..50dc7cfe661 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -1021,18 +1021,37 @@ def test_augmix(self, inpt, interpolation, mocker): PIL.Image.NEAREST, ], ) - def test_aa(self, inpt, interpolation): - aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet") + @pytest.mark.parametrize( + "policy", ["imagenet", "cifar10", "svhn"] + ) + def test_aa(self, inpt, interpolation, policy, mocker): + aa_policy = legacy_transforms.AutoAugmentPolicy(policy) t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation) t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation) - torch.manual_seed(12) - expected_output = t_ref(inpt) + le = len(t._AUGMENTATION_SPACE) - torch.manual_seed(12) - output = t(inpt) + policy_ids = [] + for i in range(le): + policy_ids.append(i) + policy_ids.append(i) + policy_ids = iter(policy_ids) - assert_equal(expected_output, output) + def torch_randint_side_effect(*args, **kwargs): + arg0 = args[0] + assert isinstance(arg0, int) + if arg0 > 2: + v = next(policy_ids) + return torch.tensor(v) + return torch.zeros(*args[1:], **kwargs) + + mocker.patch("torch.randint", side_effect=torch_randint_side_effect) + mocker.patch("torch.rand", side_effect=lambda *args, **kwargs: torch.zeros(*args, **kwargs)) + + for _ in range(le): + expected_output = t_ref(inpt) + output = t(inpt) + assert_equal(expected_output, output, atol=1, rtol=0) def import_transforms_from_references(reference): From e522f1f25ba34fdf2b8d51ecedf863100a492315 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 17 Feb 2023 11:47:47 +0100 Subject: [PATCH 2/2] Update test_prototype_transforms_consistency.py --- test/test_prototype_transforms_consistency.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 50dc7cfe661..4d114239425 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -1051,7 +1051,7 @@ def torch_randint_side_effect(*args, **kwargs): for _ in range(le): expected_output = t_ref(inpt) output = t(inpt) - assert_equal(expected_output, output, atol=1, rtol=0) + assert_equal(expected_output, output, atol=2, rtol=0) # set atol=2 as tests are flaky def import_transforms_from_references(reference):