
Commit 8a188ad

NicolasHug authored and facebook-github-bot committed
[fbsync] Add tests for transform presets, and various fixes (#7223)
Reviewed By: vmoens

Differential Revision: D44416274

fbshipit-source-id: 87f1e0dd1b8bafc383cef15f31391d7c3c0ed6d3
1 parent 5b32fa6 commit 8a188ad

4 files changed, +179 -11 lines changed


test/test_prototype_transforms.py

Lines changed: 152 additions & 0 deletions

@@ -1,5 +1,6 @@
 import itertools
 import re
+from collections import defaultdict
 
 import numpy as np
 
@@ -1988,3 +1989,154 @@ def test__transform(self, inpt):
         assert type(output) is type(inpt)
         assert output.shape[-4] == num_samples
         assert output.dtype == inpt.dtype
+
+
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("label_type", (torch.Tensor, int))
+@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
+
+    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
+    if image_type is PIL.Image:
+        image = to_pil_image(image[0])
+    elif image_type is torch.Tensor:
+        image = image.as_subclass(torch.Tensor)
+        assert is_simple_tensor(image)
+
+    label = 1 if label_type is int else torch.tensor([1])
+
+    if dataset_return_type is dict:
+        sample = {
+            "image": image,
+            "label": label,
+        }
+    else:
+        sample = image, label
+
+    t = transforms.Compose(
+        [
+            transforms.RandomResizedCrop((224, 224)),
+            transforms.RandomHorizontalFlip(p=1),
+            transforms.RandAugment(),
+            transforms.TrivialAugmentWide(),
+            transforms.AugMix(),
+            transforms.AutoAugment(),
+            to_tensor(),
+            # TODO: ConvertImageDtype is a pass-through on PIL images, is that
+            # intended? This results in a failure if we convert to tensor after
+            # it, because the image would still be uint8 which makes Normalize
+            # fail.
+            transforms.ConvertImageDtype(torch.float),
+            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
+            transforms.RandomErasing(p=1),
+        ]
+    )
+
+    out = t(sample)
+
+    assert type(out) == type(sample)
+
+    if dataset_return_type is tuple:
+        out_image, out_label = out
+    else:
+        assert out.keys() == sample.keys()
+        out_image, out_label = out.values()
+
+    assert out_image.shape[-2:] == (224, 224)
+    assert out_label == label
+
+
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("label_type", (torch.Tensor, list))
+@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
+    if data_augmentation == "hflip":
+        t = [
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "lsj":
+        t = [
+            transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
+            # Note: replaced FixedSizeCrop with RandomCrop, because we're
+            # leaving FixedSizeCrop in prototype for now, and it expects Label
+            # classes which we won't release yet.
+            # transforms.FixedSizeCrop(
+            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
+            # ),
+            transforms.RandomCrop((1024, 1024), pad_if_needed=True),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "multiscale":
+        t = [
+            transforms.RandomShortestSize(
+                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
+            ),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "ssd":
+        t = [
+            transforms.RandomPhotometricDistort(p=1),
+            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
+            # TODO: put back IoUCrop once we remove its hard requirement for Labels
+            # transforms.RandomIoUCrop(),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    elif data_augmentation == "ssdlite":
+        t = [
+            # TODO: put back IoUCrop once we remove its hard requirement for Labels
+            # transforms.RandomIoUCrop(),
+            transforms.RandomHorizontalFlip(p=1),
+            to_tensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
+    t = transforms.Compose(t)
+
+    num_boxes = 5
+    H = W = 250
+
+    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
+    if image_type is PIL.Image:
+        image = to_pil_image(image[0])
+    elif image_type is torch.Tensor:
+        image = image.as_subclass(torch.Tensor)
+        assert is_simple_tensor(image)
+
+    label = torch.randint(0, 10, size=(num_boxes,))
+    if label_type is list:
+        label = label.tolist()
+
+    # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
+    boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
+    boxes[:, 2:] += boxes[:, :2]
+    boxes = boxes.clamp(min=0, max=min(H, W))
+    boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W))
+
+    masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
+
+    sample = {
+        "image": image,
+        "label": label,
+        "boxes": boxes,
+        "masks": masks,
+    }
+
+    out = t(sample)
+
+    if to_tensor is transforms.ToTensor and image_type is not datapoints.Image:
+        assert is_simple_tensor(out["image"])
+    else:
+        assert isinstance(out["image"], datapoints.Image)
+    assert isinstance(out["label"], type(sample["label"]))
+
+    out["label"] = torch.tensor(out["label"])
+    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 3 additions & 2 deletions

@@ -37,10 +37,11 @@ def _flatten_and_extract_image_or_video(
         unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
     ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
+        needs_transform_list = self._needs_transform_list(flat_inputs)
 
         image_or_videos = []
-        for idx, inpt in enumerate(flat_inputs):
-            if check_type(
+        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
+            if needs_transform and check_type(
                 inpt,
                 (
                     datapoints.Image,
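
The fix above pairs each flattened input with its pass-through flag before looking for the image or video, so inputs the heuristic excludes are never mistaken for the image. A self-contained sketch of that zip-and-filter step, with purely illustrative values rather than the actual torchvision internals:

# Stand-ins for a flattened sample: only the first entry should be transformed.
flat_inputs = ["image", "plain_tensor", "bounding_box"]
needs_transform_list = [True, False, False]

candidates = [
    (idx, inpt)
    for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list))
    if needs_transform
]
assert candidates == [(0, "image")]  # pass-through inputs are never picked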

torchvision/prototype/transforms/_color.py

Lines changed: 2 additions & 1 deletion

@@ -169,7 +169,8 @@ def _permute_channels(
         if isinstance(orig_inpt, PIL.Image.Image):
             inpt = F.pil_to_tensor(inpt)
 
-        output = inpt[..., permutation, :, :]
+        # TODO: Find a better fix than as_subclass???
+        output = inpt[..., permutation, :, :].as_subclass(type(inpt))
 
         if isinstance(orig_inpt, PIL.Image.Image):
            output = F.to_image_pil(output)
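
For background on the fix: Tensor.as_subclass re-wraps a tensor in the given subclass without copying its storage, which is how the permuted result is restored to the input's datapoint class when indexing hands back a plain tensor. A minimal sketch with a toy subclass (a hypothetical stand-in, not datapoints.Image itself):

import torch

class Channels(torch.Tensor):  # hypothetical stand-in for a datapoint class
    pass

inpt = torch.arange(2 * 4 * 4).reshape(2, 4, 4).as_subclass(Channels)
output = inpt[..., [1, 0], :, :]  # may come back as a plain torch.Tensor
output = output.as_subclass(type(inpt))  # restore the original class, no copy
assert type(output) is Channels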

torchvision/prototype/transforms/_transform.py

Lines changed: 22 additions & 8 deletions

@@ -36,8 +36,19 @@ def forward(self, *inputs: Any) -> Any:
 
         self._check_inputs(flat_inputs)
 
-        params = self._get_params(flat_inputs)
+        needs_transform_list = self._needs_transform_list(flat_inputs)
+        params = self._get_params(
+            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
+        )
 
+        flat_outputs = [
+            self._transform(inpt, params) if needs_transform else inpt
+            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
+        ]
+
+        return tree_unflatten(flat_outputs, spec)
+
+    def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
         # Below is a heuristic on how to deal with simple tensor inputs:
         # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
         #    (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
@@ -53,7 +64,8 @@ def forward(self, *inputs: Any) -> Any:
         # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
         # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
         # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
-        flat_outputs = []
+
+        needs_transform_list = []
         transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
         for inpt in flat_inputs:
             needs_transform = True
@@ -65,10 +77,8 @@ def forward(self, *inputs: Any) -> Any:
                     transform_simple_tensor = False
                 else:
                     needs_transform = False
-
-            flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt)
-
-        return tree_unflatten(flat_outputs, spec)
+            needs_transform_list.append(needs_transform)
+        return needs_transform_list
 
     def extra_repr(self) -> str:
         extra = []
@@ -159,10 +169,14 @@ def forward(self, *inputs: Any) -> Any:
         if torch.rand(1) >= self.p:
             return inputs
 
-        params = self._get_params(flat_inputs)
+        needs_transform_list = self._needs_transform_list(flat_inputs)
+        params = self._get_params(
+            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
+        )
 
         flat_outputs = [
-            self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
+            self._transform(inpt, params) if needs_transform else inpt
+            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]
 
        return tree_unflatten(flat_outputs, spec)
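
To see the heuristic from the comments above in action: when an explicit image is present in the sample, a plain tensor travelling alongside it is passed through untouched instead of being treated as an image. A hedged sketch against the prototype API at this commit (module paths changed in later torchvision releases):

import torch
from torchvision.prototype import datapoints, transforms

img = datapoints.Image(torch.randint(0, 256, size=(3, 8, 8), dtype=torch.uint8))
extra = torch.rand(5)  # a simple tensor, not a datapoint

out_img, out_extra = transforms.RandomHorizontalFlip(p=1)(img, extra)
assert isinstance(out_img, datapoints.Image)  # the explicit image is flipped
assert torch.equal(out_extra, extra)          # the simple tensor passes through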
