diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 9b3482f3f0a..4d114239425 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -1021,18 +1021,37 @@ def test_augmix(self, inpt, interpolation, mocker): PIL.Image.NEAREST, ], ) - def test_aa(self, inpt, interpolation): - aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet") + @pytest.mark.parametrize( + "policy", ["imagenet", "cifar10", "svhn"] + ) + def test_aa(self, inpt, interpolation, policy, mocker): + aa_policy = legacy_transforms.AutoAugmentPolicy(policy) t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation) t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation) - torch.manual_seed(12) - expected_output = t_ref(inpt) + le = len(t._AUGMENTATION_SPACE) - torch.manual_seed(12) - output = t(inpt) + policy_ids = [] + for i in range(le): + policy_ids.append(i) + policy_ids.append(i) + policy_ids = iter(policy_ids) - assert_equal(expected_output, output) + def torch_randint_side_effect(*args, **kwargs): + arg0 = args[0] + assert isinstance(arg0, int) + if arg0 > 2: + v = next(policy_ids) + return torch.tensor(v) + return torch.zeros(*args[1:], **kwargs) + + mocker.patch("torch.randint", side_effect=torch_randint_side_effect) + mocker.patch("torch.rand", side_effect=lambda *args, **kwargs: torch.zeros(*args, **kwargs)) + + for _ in range(le): + expected_output = t_ref(inpt) + output = t(inpt) + assert_equal(expected_output, output, atol=2, rtol=0) # set atol=2 as tests are flaky def import_transforms_from_references(reference):