Skip to content

port AA tests #7927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Sep 4, 2023
275 changes: 0 additions & 275 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,281 +705,6 @@ def test_to_tensor(self):
assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))


class TestAATransforms:
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_randaug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)

le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
# Stable API, if signed there is another random call
if t._AUGMENTATION_SPACE[keys[i]][1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
randint_values = iter(randint_values)

mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)

for i in range(le):
expected_output = t_ref(inpt)
output = t(inpt)

assert_close(expected_output, output, atol=1, rtol=0.1)

@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_randaug_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)

torch.manual_seed(12)
expected_output = tt_ref(inpt)

torch.manual_seed(12)
scripted_output = tt(inpt)

assert_equal(scripted_output, expected_output)

@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_trivial_aug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)

le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)

randint_values = iter(randint_values)

mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)

for _ in range(le):
expected_output = t_ref(inpt)
output = t(inpt)

assert_close(expected_output, output, atol=1, rtol=0.1)

@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_trivial_aug_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)

torch.manual_seed(12)
expected_output = tt_ref(inpt)

torch.manual_seed(12)
scripted_output = tt(inpt)

assert_equal(scripted_output, expected_output)

@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_augmix(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t._sample_dirichlet = lambda t: t.softmax(dim=-1)

le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)

randint_values = iter(randint_values)

mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)

expected_output = t_ref(inpt)
output = t(inpt)

assert_equal(expected_output, output)

@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_augmix_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)

t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)

torch.manual_seed(12)
expected_output = tt_ref(inpt)

torch.manual_seed(12)
scripted_output = tt(inpt)

assert_equal(scripted_output, expected_output)

@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_aa(self, inpt, interpolation):
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)

torch.manual_seed(12)
expected_output = t_ref(inpt)

torch.manual_seed(12)
output = t(inpt)

assert_equal(expected_output, output)

@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
def test_aa_jit(self, interpolation):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)

torch.manual_seed(12)
expected_output = tt_ref(inpt)

torch.manual_seed(12)
scripted_output = tt(inpt)

assert_equal(scripted_output, expected_output)


def import_transforms_from_references(reference):
HERE = Path(__file__).parent
PROJECT_ROOT = HERE.parent
Expand Down
Loading