Skip to content

Commit fcfd1b2

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] port AA tests (#7927)
Summary: Co-authored-by: Nicolas Hug <[email protected]> Reviewed By: matteobettini Differential Revision: D49600791 fbshipit-source-id: abf058e28a949717be7ad343e5417fca098d4078
1 parent f584964 commit fcfd1b2

File tree

2 files changed

+114
-278
lines changed

2 files changed

+114
-278
lines changed

test/test_transforms_v2_consistency.py

Lines changed: 0 additions & 275 deletions
Original file line numberDiff line numberDiff line change
@@ -705,281 +705,6 @@ def test_to_tensor(self):
705705
assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))
706706

707707

708-
class TestAATransforms:
709-
@pytest.mark.parametrize(
710-
"inpt",
711-
[
712-
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
713-
PIL.Image.new("RGB", (256, 256), 123),
714-
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
715-
],
716-
)
717-
@pytest.mark.parametrize(
718-
"interpolation",
719-
[
720-
v2_transforms.InterpolationMode.NEAREST,
721-
v2_transforms.InterpolationMode.BILINEAR,
722-
PIL.Image.NEAREST,
723-
],
724-
)
725-
def test_randaug(self, inpt, interpolation, mocker):
726-
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
727-
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
728-
729-
le = len(t._AUGMENTATION_SPACE)
730-
keys = list(t._AUGMENTATION_SPACE.keys())
731-
randint_values = []
732-
for i in range(le):
733-
# Stable API, op_index random call
734-
randint_values.append(i)
735-
# Stable API, if signed there is another random call
736-
if t._AUGMENTATION_SPACE[keys[i]][1]:
737-
randint_values.append(0)
738-
# New API, _get_random_item
739-
randint_values.append(i)
740-
randint_values = iter(randint_values)
741-
742-
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
743-
mocker.patch("torch.rand", return_value=1.0)
744-
745-
for i in range(le):
746-
expected_output = t_ref(inpt)
747-
output = t(inpt)
748-
749-
assert_close(expected_output, output, atol=1, rtol=0.1)
750-
751-
@pytest.mark.parametrize(
752-
"interpolation",
753-
[
754-
v2_transforms.InterpolationMode.NEAREST,
755-
v2_transforms.InterpolationMode.BILINEAR,
756-
],
757-
)
758-
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
759-
def test_randaug_jit(self, interpolation, fill):
760-
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
761-
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
762-
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
763-
764-
tt_ref = torch.jit.script(t_ref)
765-
tt = torch.jit.script(t)
766-
767-
torch.manual_seed(12)
768-
expected_output = tt_ref(inpt)
769-
770-
torch.manual_seed(12)
771-
scripted_output = tt(inpt)
772-
773-
assert_equal(scripted_output, expected_output)
774-
775-
@pytest.mark.parametrize(
776-
"inpt",
777-
[
778-
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
779-
PIL.Image.new("RGB", (256, 256), 123),
780-
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
781-
],
782-
)
783-
@pytest.mark.parametrize(
784-
"interpolation",
785-
[
786-
v2_transforms.InterpolationMode.NEAREST,
787-
v2_transforms.InterpolationMode.BILINEAR,
788-
PIL.Image.NEAREST,
789-
],
790-
)
791-
def test_trivial_aug(self, inpt, interpolation, mocker):
792-
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
793-
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
794-
795-
le = len(t._AUGMENTATION_SPACE)
796-
keys = list(t._AUGMENTATION_SPACE.keys())
797-
randint_values = []
798-
for i in range(le):
799-
# Stable API, op_index random call
800-
randint_values.append(i)
801-
key = keys[i]
802-
# Stable API, random magnitude
803-
aug_op = t._AUGMENTATION_SPACE[key]
804-
magnitudes = aug_op[0](2, 0, 0)
805-
if magnitudes is not None:
806-
randint_values.append(5)
807-
# Stable API, if signed there is another random call
808-
if aug_op[1]:
809-
randint_values.append(0)
810-
# New API, _get_random_item
811-
randint_values.append(i)
812-
# New API, random magnitude
813-
if magnitudes is not None:
814-
randint_values.append(5)
815-
816-
randint_values = iter(randint_values)
817-
818-
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
819-
mocker.patch("torch.rand", return_value=1.0)
820-
821-
for _ in range(le):
822-
expected_output = t_ref(inpt)
823-
output = t(inpt)
824-
825-
assert_close(expected_output, output, atol=1, rtol=0.1)
826-
827-
@pytest.mark.parametrize(
828-
"interpolation",
829-
[
830-
v2_transforms.InterpolationMode.NEAREST,
831-
v2_transforms.InterpolationMode.BILINEAR,
832-
],
833-
)
834-
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
835-
def test_trivial_aug_jit(self, interpolation, fill):
836-
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
837-
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
838-
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
839-
840-
tt_ref = torch.jit.script(t_ref)
841-
tt = torch.jit.script(t)
842-
843-
torch.manual_seed(12)
844-
expected_output = tt_ref(inpt)
845-
846-
torch.manual_seed(12)
847-
scripted_output = tt(inpt)
848-
849-
assert_equal(scripted_output, expected_output)
850-
851-
@pytest.mark.parametrize(
852-
"inpt",
853-
[
854-
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
855-
PIL.Image.new("RGB", (256, 256), 123),
856-
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
857-
],
858-
)
859-
@pytest.mark.parametrize(
860-
"interpolation",
861-
[
862-
v2_transforms.InterpolationMode.NEAREST,
863-
v2_transforms.InterpolationMode.BILINEAR,
864-
PIL.Image.NEAREST,
865-
],
866-
)
867-
def test_augmix(self, inpt, interpolation, mocker):
868-
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
869-
t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
870-
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
871-
t._sample_dirichlet = lambda t: t.softmax(dim=-1)
872-
873-
le = len(t._AUGMENTATION_SPACE)
874-
keys = list(t._AUGMENTATION_SPACE.keys())
875-
randint_values = []
876-
for i in range(le):
877-
# Stable API, op_index random call
878-
randint_values.append(i)
879-
key = keys[i]
880-
# Stable API, random magnitude
881-
aug_op = t._AUGMENTATION_SPACE[key]
882-
magnitudes = aug_op[0](2, 0, 0)
883-
if magnitudes is not None:
884-
randint_values.append(5)
885-
# Stable API, if signed there is another random call
886-
if aug_op[1]:
887-
randint_values.append(0)
888-
# New API, _get_random_item
889-
randint_values.append(i)
890-
# New API, random magnitude
891-
if magnitudes is not None:
892-
randint_values.append(5)
893-
894-
randint_values = iter(randint_values)
895-
896-
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
897-
mocker.patch("torch.rand", return_value=1.0)
898-
899-
expected_output = t_ref(inpt)
900-
output = t(inpt)
901-
902-
assert_equal(expected_output, output)
903-
904-
@pytest.mark.parametrize(
905-
"interpolation",
906-
[
907-
v2_transforms.InterpolationMode.NEAREST,
908-
v2_transforms.InterpolationMode.BILINEAR,
909-
],
910-
)
911-
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
912-
def test_augmix_jit(self, interpolation, fill):
913-
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
914-
915-
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
916-
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
917-
918-
tt_ref = torch.jit.script(t_ref)
919-
tt = torch.jit.script(t)
920-
921-
torch.manual_seed(12)
922-
expected_output = tt_ref(inpt)
923-
924-
torch.manual_seed(12)
925-
scripted_output = tt(inpt)
926-
927-
assert_equal(scripted_output, expected_output)
928-
929-
@pytest.mark.parametrize(
930-
"inpt",
931-
[
932-
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
933-
PIL.Image.new("RGB", (256, 256), 123),
934-
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
935-
],
936-
)
937-
@pytest.mark.parametrize(
938-
"interpolation",
939-
[
940-
v2_transforms.InterpolationMode.NEAREST,
941-
v2_transforms.InterpolationMode.BILINEAR,
942-
PIL.Image.NEAREST,
943-
],
944-
)
945-
def test_aa(self, inpt, interpolation):
946-
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
947-
t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
948-
t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)
949-
950-
torch.manual_seed(12)
951-
expected_output = t_ref(inpt)
952-
953-
torch.manual_seed(12)
954-
output = t(inpt)
955-
956-
assert_equal(expected_output, output)
957-
958-
@pytest.mark.parametrize(
959-
"interpolation",
960-
[
961-
v2_transforms.InterpolationMode.NEAREST,
962-
v2_transforms.InterpolationMode.BILINEAR,
963-
],
964-
)
965-
def test_aa_jit(self, interpolation):
966-
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
967-
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
968-
t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
969-
t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)
970-
971-
tt_ref = torch.jit.script(t_ref)
972-
tt = torch.jit.script(t)
973-
974-
torch.manual_seed(12)
975-
expected_output = tt_ref(inpt)
976-
977-
torch.manual_seed(12)
978-
scripted_output = tt(inpt)
979-
980-
assert_equal(scripted_output, expected_output)
981-
982-
983708
def import_transforms_from_references(reference):
984709
HERE = Path(__file__).parent
985710
PROJECT_ROOT = HERE.parent

0 commit comments

Comments
 (0)