diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index 1f96caa247f..1f47eb2117f 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -755,10 +755,11 @@ def test_randaug(self, inpt, interpolation, mocker):
             v2_transforms.InterpolationMode.BILINEAR,
         ],
     )
-    def test_randaug_jit(self, interpolation):
+    @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
+    def test_randaug_jit(self, interpolation, fill):
         inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
-        t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
-        t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
+        t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
+        t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
 
         tt_ref = torch.jit.script(t_ref)
         tt = torch.jit.script(t)
@@ -830,10 +831,11 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
             v2_transforms.InterpolationMode.BILINEAR,
         ],
     )
-    def test_trivial_aug_jit(self, interpolation):
+    @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
+    def test_trivial_aug_jit(self, interpolation, fill):
         inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
-        t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
-        t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
+        t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
+        t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
 
         tt_ref = torch.jit.script(t_ref)
         tt = torch.jit.script(t)
@@ -906,11 +908,12 @@ def test_augmix(self, inpt, interpolation, mocker):
             v2_transforms.InterpolationMode.BILINEAR,
         ],
     )
-    def test_augmix_jit(self, interpolation):
+    @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
+    def test_augmix_jit(self, interpolation, fill):
         inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
 
-        t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
-        t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
+        t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
+        t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
 
         tt_ref = torch.jit.script(t_ref)
         tt = torch.jit.script(t)
diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py
index 4fec62f1b11..8ddd5aacdc3 100644
--- a/torchvision/transforms/v2/_auto_augment.py
+++ b/torchvision/transforms/v2/_auto_augment.py
@@ -33,8 +33,8 @@ def __init__(
     def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
         params = super()._extract_params_for_v1_transform()
 
-        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
-            raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
+        if isinstance(params["fill"], dict):
+            raise ValueError(f"{type(self).__name__}() cannot be scripted when `fill` is a dictionary.")
 
         return params
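
For reference, a minimal sketch of the behavior this change enables (assuming a torchvision build that includes it): scripting a v2 auto-augment transform with a non-scalar `fill`, such as a list, now succeeds, mirroring the `fill` values exercised by the updated tests; only a dictionary-valued `fill` is rejected by `_extract_params_for_v1_transform` at scripting time.

    import torch
    from torchvision.transforms import v2 as v2_transforms

    # A list fill could previously not be scripted, since the old check only
    # accepted None or a scalar; with this change only dict fills are rejected.
    t = v2_transforms.RandAugment(num_ops=1, fill=[0.0, 0.0, 0.0])
    tt = torch.jit.script(t)

    inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
    out = tt(inpt)
    print(out.shape)  # torch.Size([1, 3, 256, 256])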