Skip to content

Commit b828671

Browse files
authored
allow sequence fill for v2 AA scripted (#7919)
1 parent 96950a5 commit b828671

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

test/test_transforms_v2_consistency.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -755,10 +755,11 @@ def test_randaug(self, inpt, interpolation, mocker):
755755
v2_transforms.InterpolationMode.BILINEAR,
756756
],
757757
)
758-
def test_randaug_jit(self, interpolation):
758+
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
759+
def test_randaug_jit(self, interpolation, fill):
759760
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
760-
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
761-
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
761+
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
762+
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
762763

763764
tt_ref = torch.jit.script(t_ref)
764765
tt = torch.jit.script(t)
@@ -830,10 +831,11 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
830831
v2_transforms.InterpolationMode.BILINEAR,
831832
],
832833
)
833-
def test_trivial_aug_jit(self, interpolation):
834+
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
835+
def test_trivial_aug_jit(self, interpolation, fill):
834836
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
835-
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
836-
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
837+
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
838+
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
837839

838840
tt_ref = torch.jit.script(t_ref)
839841
tt = torch.jit.script(t)
@@ -906,11 +908,12 @@ def test_augmix(self, inpt, interpolation, mocker):
906908
v2_transforms.InterpolationMode.BILINEAR,
907909
],
908910
)
909-
def test_augmix_jit(self, interpolation):
911+
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
912+
def test_augmix_jit(self, interpolation, fill):
910913
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
911914

912-
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
913-
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
915+
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
916+
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
914917

915918
tt_ref = torch.jit.script(t_ref)
916919
tt = torch.jit.script(t)

torchvision/transforms/v2/_auto_augment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def __init__(
3333
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
3434
params = super()._extract_params_for_v1_transform()
3535

36-
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
37-
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
36+
if isinstance(params["fill"], dict):
37+
raise ValueError(f"{type(self).__name__}() can not be scripted for when `fill` is a dictionary.")
3838

3939
return params
4040

0 commit comments

Comments
 (0)