Skip to content

Commit 7bc66fe

Browse files
datumbox authored and facebook-github-bot committed
[fbsync] add segmentation reference consistency tests (#6591)
Summary:
* add segmentation reference consistency tests
* fall back to smoke tests for resize
* add test for RandomCrop

Reviewed By: YosuaMichael
Differential Revision: D39885419
fbshipit-source-id: 12692442af350ca9b97e8764ed429d2d56e0682a
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 0096bd2 commit 7bc66fe

File tree

2 files changed

+177
-8
lines changed

2 files changed

+177
-8
lines changed

test/prototype_common_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def __init__(
6969

7070
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
7171
actual, expected = [
72-
to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
72+
to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
73+
for input in [actual, expected]
7374
]
7475
# This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
7576
# image to a tensor adds a singleton leading dimension.

test/test_prototype_transforms_consistency.py

Lines changed: 175 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import enum
22
import inspect
3+
import random
4+
from collections import defaultdict
35
from importlib.machinery import SourceFileLoader
46
from pathlib import Path
57

@@ -16,13 +18,15 @@
1618
make_image,
1719
make_images,
1820
make_label,
21+
make_segmentation_mask,
1922
)
2023
from torchvision import transforms as legacy_transforms
2124
from torchvision._utils import sequence_to_str
2225
from torchvision.prototype import features, transforms as prototype_transforms
26+
from torchvision.prototype.transforms import functional as F
27+
from torchvision.prototype.transforms._utils import query_chw
2328
from torchvision.prototype.transforms.functional import to_image_pil
2429

25-
2630
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
2731

2832

@@ -852,10 +856,12 @@ def test_aa(self, inpt, interpolation):
852856
assert_equal(expected_output, output)
853857

854858

855-
# Import reference detection transforms here for consistency checks
856-
# torchvision/references/detection/transforms.py
857-
ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
858-
det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
859+
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` as a module.

    The reference training scripts are not part of the installed package, so their
    ``transforms.py`` is loaded directly from the repository checkout relative to this
    test file.
    """
    transforms_path = Path(__file__).parent.parent / "references" / reference / "transforms.py"
    loader = SourceFileLoader(transforms_path.stem, transforms_path.as_posix())
    return loader.load_module()
862+
863+
864+
det_transforms = import_transforms_from_references("detection")
859865

860866

861867
class TestRefDetTransforms:
@@ -873,7 +879,7 @@ def make_datapoints(self, with_mask=True):
873879

874880
yield (pil_image, target)
875881

876-
tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
882+
tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB))
877883
target = {
878884
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
879885
"labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -883,7 +889,7 @@ def make_datapoints(self, with_mask=True):
883889

884890
yield (tensor_image, target)
885891

886-
feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
892+
feature_image = make_image(size=size, color_space=features.ColorSpace.RGB)
887893
target = {
888894
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
889895
"labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -927,3 +933,165 @@ def test_transform(self, t_ref, t, data_kwargs):
927933
expected_output = t_ref(*dp)
928934

929935
assert_equal(expected_output, output)
936+
937+
938+
seg_transforms = import_transforms_from_references("segmentation")
939+
940+
941+
# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
945+
class PadIfSmaller(prototype_transforms.Transform):
    """Pad the right/bottom of every input so that both spatial edges are at least ``size``.

    Inputs that are already large enough are passed through untouched. The ``fill`` value is
    resolved per input type through the same helper the prototype geometry transforms use, so a
    mapping like ``defaultdict(lambda: 0, {features.Mask: 255})`` can assign different fill
    values to images and masks.
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        self.fill = prototype_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        _, sample_height, sample_width = query_chw(sample)
        pad_right = max(self.size - sample_width, 0)
        pad_bottom = max(self.size - sample_height, 0)
        # Padding order is (left, top, right, bottom); only the right/bottom edges grow.
        return dict(
            padding=[0, 0, pad_right, pad_bottom],
            needs_padding=pad_right > 0 or pad_bottom > 0,
        )

    def _transform(self, inpt, params):
        if params["needs_padding"]:
            # Look up the fill value for this input's type and normalize it for the kernel.
            resolved_fill = F._geometry._convert_fill_arg(self.fill[type(inpt)])
            return F.pad(inpt, padding=params["padding"], fill=resolved_fill)
        return inpt
965+
966+
967+
class TestRefSegTransforms:
    """Consistency checks of the prototype transforms against the segmentation reference
    transforms (loaded above as ``seg_transforms``)."""

    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        # Yields pairs ``(dp, dp_ref)``: ``dp`` is the input for the prototype transform, ``dp_ref``
        # the equivalent input for the reference transform. The same image content is emitted once
        # per conversion in ``conv_fns`` (a PIL image if supported, a plain tensor, and the
        # unconverted feature image), while the reference pipeline always receives a PIL image
        # (or a plain tensor when PIL is unsupported) together with a PIL mask.
        size = (256, 640)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_image_pil)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(feature_image), feature_mask)
            dp_ref = (
                to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
                to_image_pil(feature_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        # The reference transforms draw randomness from `random` while the prototype ones use
        # `torch`, so both generators are seeded to make the two pipelines comparable.
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        # Run the prototype transform `t` and the reference transform `t_ref` on equivalent
        # inputs under the same seed and assert that the outputs match.
        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):

            self.set_seed()
            output = t(dp)

            self.set_seed()
            expected_output = t_ref(*dp_ref)

            assert_equal(output, expected_output)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                prototype_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                prototype_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                # The reference RandomCrop pads on its own; the prototype equivalent needs the
                # PadIfSmaller shim (defined above) with fill 255 for masks to match it.
                seg_transforms.RandomCrop(size=480),
                prototype_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
                        prototype_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

    def check_resize(self, mocker, t_ref, t):
        # Smoke test for resizing transforms: instead of comparing outputs, the underlying resize
        # kernels are mocked out and we assert both pipelines call them on the same inputs with
        # equivalent size arguments.
        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
        mock_ref = mocker.patch("torchvision.transforms.functional.resize")

        for dp, dp_ref in self.make_datapoints():
            mock.reset_mock()
            mock_ref.reset_mock()

            self.set_seed()
            t(dp)
            # Two calls are expected, one per element of the datapoint, each receiving the
            # original object as first positional argument.
            assert mock.call_count == 2
            assert all(
                actual is expected
                for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp)
            )

            self.set_seed()
            t_ref(*dp_ref)
            assert mock_ref.call_count == 2
            assert all(
                actual is expected
                for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref)
            )

            # The prototype kernel receives the size wrapped in a list, the reference one a scalar.
            for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
                assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]

    def test_random_resize_train(self, mocker):
        base_size = 520
        min_size = base_size // 2
        max_size = base_size * 2

        randint = torch.randint

        def patched_randint(a, b, *other_args, **kwargs):
            # Only reroute scalar draws of the form torch.randint(a, b, ()); every other call
            # keeps the original torch behavior. random.randint matches the sampling used by the
            # reference transform.
            if kwargs or len(other_args) > 1 or other_args[0] != ():
                return randint(a, b, *other_args, **kwargs)

            return random.randint(a, b)

        # We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported
        # normally
        t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
        mocker.patch(
            "torchvision.prototype.transforms._geometry.torch.randint",
            new=patched_randint,
        )

        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)

        self.check_resize(mocker, t_ref, t)

    def test_random_resize_eval(self, mocker):
        torch.manual_seed(0)
        base_size = 520

        # With min_size == max_size the reference RandomResize is deterministic, so it is checked
        # against a plain prototype Resize.
        t = prototype_transforms.Resize(size=base_size, antialias=True)

        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)

        self.check_resize(mocker, t_ref, t)

0 commit comments

Comments
 (0)