
Commit 9c4f738

Fixed issue with jitted AA transforms in v2 and added tests (#7839)
1 parent 37081ee commit 9c4f738

2 files changed: +108 -5 lines
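
This commit makes the v2 auto-augment transforms (AutoAugment, RandAugment, TrivialAugmentWide, AugMix) scriptable again and adds consistency tests that compare each scripted v2 transform against its scripted v1 counterpart under the same seed. A minimal sketch of the scenario the new tests exercise, assuming a torchvision build that includes this change (default ImageNet policy and default fill):

import torch
from torchvision.transforms import v2

# Script the v2 transform and run it on a batched uint8 image, as the new tests do.
transform = v2.AutoAugment()  # default policy; a scalar or None `fill` keeps it scriptable
scripted = torch.jit.script(transform)

img = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
out = scripted(img)
assert out.shape == img.shape and out.dtype == img.dtype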

test/test_transforms_v2_consistency.py

Lines changed: 94 additions & 0 deletions
@@ -927,6 +927,29 @@ def test_randaug(self, inpt, interpolation, mocker):
 
         assert_close(expected_output, output, atol=1, rtol=0.1)
 
+    @pytest.mark.parametrize(
+        "interpolation",
+        [
+            v2_transforms.InterpolationMode.NEAREST,
+            v2_transforms.InterpolationMode.BILINEAR,
+        ],
+    )
+    def test_randaug_jit(self, interpolation):
+        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
+        t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
+        t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
+
+        tt_ref = torch.jit.script(t_ref)
+        tt = torch.jit.script(t)
+
+        torch.manual_seed(12)
+        expected_output = tt_ref(inpt)
+
+        torch.manual_seed(12)
+        scripted_output = tt(inpt)
+
+        assert_equal(scripted_output, expected_output)
+
     @pytest.mark.parametrize(
         "inpt",
         [
@@ -979,6 +1002,29 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
 
         assert_close(expected_output, output, atol=1, rtol=0.1)
 
+    @pytest.mark.parametrize(
+        "interpolation",
+        [
+            v2_transforms.InterpolationMode.NEAREST,
+            v2_transforms.InterpolationMode.BILINEAR,
+        ],
+    )
+    def test_trivial_aug_jit(self, interpolation):
+        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
+        t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
+        t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
+
+        tt_ref = torch.jit.script(t_ref)
+        tt = torch.jit.script(t)
+
+        torch.manual_seed(12)
+        expected_output = tt_ref(inpt)
+
+        torch.manual_seed(12)
+        scripted_output = tt(inpt)
+
+        assert_equal(scripted_output, expected_output)
+
     @pytest.mark.parametrize(
         "inpt",
         [
@@ -1032,6 +1078,30 @@ def test_augmix(self, inpt, interpolation, mocker):
 
         assert_equal(expected_output, output)
 
+    @pytest.mark.parametrize(
+        "interpolation",
+        [
+            v2_transforms.InterpolationMode.NEAREST,
+            v2_transforms.InterpolationMode.BILINEAR,
+        ],
+    )
+    def test_augmix_jit(self, interpolation):
+        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
+
+        t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
+        t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
+
+        tt_ref = torch.jit.script(t_ref)
+        tt = torch.jit.script(t)
+
+        torch.manual_seed(12)
+        expected_output = tt_ref(inpt)
+
+        torch.manual_seed(12)
+        scripted_output = tt(inpt)
+
+        assert_equal(scripted_output, expected_output)
+
     @pytest.mark.parametrize(
         "inpt",
         [
@@ -1061,6 +1131,30 @@ def test_aa(self, inpt, interpolation):
 
         assert_equal(expected_output, output)
 
+    @pytest.mark.parametrize(
+        "interpolation",
+        [
+            v2_transforms.InterpolationMode.NEAREST,
+            v2_transforms.InterpolationMode.BILINEAR,
+        ],
+    )
+    def test_aa_jit(self, interpolation):
+        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
+        aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
+        t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
+        t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)
+
+        tt_ref = torch.jit.script(t_ref)
+        tt = torch.jit.script(t)
+
+        torch.manual_seed(12)
+        expected_output = tt_ref(inpt)
+
+        torch.manual_seed(12)
+        scripted_output = tt(inpt)
+
+        assert_equal(scripted_output, expected_output)
+
 
 def import_transforms_from_references(reference):
     HERE = Path(__file__).parent

torchvision/transforms/v2/_auto_augment.py

Lines changed: 14 additions & 5 deletions
@@ -28,7 +28,16 @@ def __init__(
     ) -> None:
         super().__init__()
         self.interpolation = _check_interpolation(interpolation)
-        self.fill = _setup_fill_arg(fill)
+        self.fill = fill
+        self._fill = _setup_fill_arg(fill)
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        params = super()._extract_params_for_v1_transform()
+
+        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
+            raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
+
+        return params
 
     def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
         keys = tuple(dct.keys())
@@ -335,7 +344,7 @@ def forward(self, *inputs: Any) -> Any:
                 magnitude = 0.0
 
             image_or_video = self._apply_image_or_video_transform(
-                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
             )
 
         return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
@@ -419,7 +428,7 @@ def forward(self, *inputs: Any) -> Any:
            else:
                magnitude = 0.0
            image_or_video = self._apply_image_or_video_transform(
-                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )
 
        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
@@ -491,7 +500,7 @@ def forward(self, *inputs: Any) -> Any:
            magnitude = 0.0
 
        image_or_video = self._apply_image_or_video_transform(
-            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
        )
        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
 
@@ -614,7 +623,7 @@ def forward(self, *inputs: Any) -> Any:
                    magnitude = 0.0
 
                aug = self._apply_image_or_video_transform(
-                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
                )
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
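
The pattern of the fix: the public `fill` attribute now stores the value exactly as the user passed it, while the processed mapping from `_setup_fill_arg` moves to the private `_fill`, which is what the forward passes use. `_extract_params_for_v1_transform` therefore sees the raw `fill` and can refuse scripting when it is not a scalar (or None), since the v1 fallback can only be scripted with a scalar fill. A hedged sketch of the resulting behavior (the error text is taken from the diff above; the non-scalar case is assumed to fail at torch.jit.script time, when the v1 parameters are extracted):

import torch
from torchvision.transforms import v2

# Default (None) or scalar `fill`: scripting works, as in the new tests.
torch.jit.script(v2.RandAugment())

# Non-scalar `fill`: still fine in eager mode, but the new check rejects scripting.
try:
    torch.jit.script(v2.RandAugment(fill=[0.0, 0.0, 0.0]))
except ValueError as err:
    print(err)  # RandAugment() can only be scripted for a scalar `fill`, but got [0.0, 0.0, 0.0].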
