
Commit cf76a5a

YosuaMichael authored and datumbox committed
[fbsync] Cleanup of prototype transforms (#6492)
Summary:
* fix passthrough on transforms and add dispatchers for five and ten crop
* Revert "cleanup prototype auto augment transforms (#6463)". This reverts commit d8025b9.
* use legacy kernels in deprecated Grayscale and RandomGrayscale transforms
* fix default type for Lambda transform
* fix default type for ToDtype transform
* move simple_tensor to features module
* [skip ci]
* Revert "move simple_tensor to features module". This reverts commit 7043b6e.
* cleanup
* reinstate valid AA changes
* address review
* Fix linter

Reviewed By: NicolasHug

Differential Revision: D39131014

fbshipit-source-id: 0237a0e2a8256cf7ec5f5bc3b529e471c465ea04

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 69c2a08 commit cf76a5a

6 files changed, +114 -73 lines changed


torchvision/prototype/transforms/_auto_augment.py

Lines changed: 71 additions & 42 deletions
@@ -1,16 +1,17 @@
 import math
 import numbers
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 import PIL.Image
 import torch
 
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
 from torchvision.transforms.autoaugment import AutoAugmentPolicy
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
 
-from ._utils import is_simple_tensor, query_chw
+from ._utils import _isinstance, get_chw, is_simple_tensor
 
 K = TypeVar("K")
 V = TypeVar("V")
@@ -35,9 +36,31 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]
 
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, height, width = query_chw(sample)
-        return dict(height=height, width=width)
+    def _extract_image(
+        self,
+        sample: Any,
+        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
+    ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
+        sample_flat, _ = tree_flatten(sample)
+        images = []
+        for id, inpt in enumerate(sample_flat):
+            if _isinstance(inpt, (features.Image, PIL.Image.Image, is_simple_tensor)):
+                images.append((id, inpt))
+            elif isinstance(inpt, unsupported_types):
+                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
+
+        if not images:
+            raise TypeError("Found no image in the sample.")
+        if len(images) > 1:
+            raise TypeError(
+                f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
+            )
+        return images[0]
+
+    def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
+        sample_flat, spec = tree_flatten(sample)
+        sample_flat[id] = item
+        return tree_unflatten(sample_flat, spec)
 
     def _apply_image_transform(
         self,
@@ -242,34 +265,33 @@ def _get_policies(
         else:
             raise ValueError(f"The provided policy {policy} is not recognized.")
 
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        params = super(AutoAugment, self)._get_params(sample)
-        params["policy"] = self._policies[int(torch.randint(len(self._policies), ()))]
-        return params
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+        policy = self._policies[int(torch.randint(len(self._policies), ()))]
 
-        for transform_id, probability, magnitude_idx in params["policy"]:
+        for transform_id, probability, magnitude_idx in policy:
            if not torch.rand(()) <= probability:
                continue
 
            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
 
-            magnitudes = magnitudes_fn(10, params["height"], params["width"])
+            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0
 
-            inpt = self._apply_image_transform(
-                inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image = self._apply_image_transform(
+                image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
            )
 
-        return inpt
+        return self._put_into_sample(sample, id, image)
 
 
 class RandAugment(_AutoAugmentBase):
@@ -315,26 +337,28 @@ def __init__(
         self.magnitude = magnitude
         self.num_magnitude_bins = num_magnitude_bins
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)
 
         for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
 
-            magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
+            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0
 
-            inpt = self._apply_image_transform(
-                inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image = self._apply_image_transform(
+                image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
            )
 
-        return inpt
+        return self._put_into_sample(sample, id, image)
 
 
 class TrivialAugmentWide(_AutoAugmentBase):
@@ -370,23 +394,26 @@ def __init__(
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)
 
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
 
-        magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
+        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
         if magnitudes is not None:
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
         else:
            magnitude = 0.0
 
-        return self._apply_image_transform(
-            inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+        image = self._apply_image_transform(
+            image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
         )
+        return self._put_into_sample(sample, id, image)
 
 
 class AugMix(_AutoAugmentBase):
@@ -438,13 +465,15 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
         # Must be on a separate method so that we can overwrite it in tests.
         return torch._sample_dirichlet(params)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
-            image = inpt
-        elif isinstance(inpt, PIL.Image.Image):
-            image = pil_to_tensor(inpt)
-        else:
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        id, orig_image = self._extract_image(sample)
+        num_channels, height, width = get_chw(orig_image)
+
+        if isinstance(orig_image, torch.Tensor):
+            image = orig_image
+        else:  # isinstance(inpt, PIL.Image.Image):
+            image = pil_to_tensor(orig_image)
 
         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
 
@@ -470,7 +499,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
 
-                magnitudes = magnitudes_fn(self._PARAMETER_MAX, params["height"], params["width"])
+                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
@@ -484,9 +513,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             mix.add_(combined_weights[:, i].view(batch_dims) * aug)
         mix = mix.view(orig_dims).to(dtype=image.dtype)
 
-        if isinstance(inpt, features.Image):
-            mix = features.Image.new_like(inpt, mix)
-        elif isinstance(inpt, PIL.Image.Image):
+        if isinstance(orig_image, features.Image):
+            mix = features.Image.new_like(orig_image, mix)
+        elif isinstance(orig_image, PIL.Image.Image):
            mix = to_pil_image(mix)
 
-        return mix
+        return self._put_into_sample(sample, id, mix)
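
For illustration, a minimal usage sketch of the reworked control flow (hedged; it assumes the transforms are used from the torchvision.prototype.transforms namespace shown in the file path, with a plain container sample): the auto augment transforms now locate the single image anywhere in the sample, augment it, and splice it back, passing every other entry through untouched.

import torch
from torchvision.prototype import transforms

transform = transforms.AutoAugment()
sample = {"image": torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8), "label": 4}
augmented = transform(sample)
# the image entry is replaced by its augmented version; non-image entries pass through
assert augmented["label"] == 4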

torchvision/prototype/transforms/_deprecated.py

Lines changed: 12 additions & 9 deletions
@@ -4,15 +4,14 @@
 import numpy as np
 import PIL.Image
 import torch
-import torchvision.prototype.transforms.functional as F
+
 from torchvision.prototype import features
-from torchvision.prototype.features import ColorSpace
 from torchvision.prototype.transforms import Transform
 from torchvision.transforms import functional as _F
 from typing_extensions import Literal
 
 from ._transform import _RandomApplyTransform
-from ._utils import is_simple_tensor
+from ._utils import is_simple_tensor, query_chw
 
 
 class ToTensor(Transform):
@@ -59,6 +58,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image:
 
 
 class Grayscale(Transform):
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         deprecation_msg = (
             f"The transform `Grayscale(num_output_channels={num_output_channels})` "
@@ -81,13 +82,12 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         self.num_output_channels = num_output_channels
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
-        if self.num_output_channels == 3:
-            output = F.convert_color_space(inpt, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
-        return output
+        return _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
 
 
 class RandomGrayscale(_RandomApplyTransform):
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, p: float = 0.1) -> None:
         warnings.warn(
             "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
@@ -103,6 +103,9 @@ def __init__(self, p: float = 0.1) -> None:
 
         super().__init__(p=p)
 
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        num_input_channels, _, _ = query_chw(sample)
+        return dict(num_input_channels=num_input_channels)
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
-        return F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
+        return _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
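
A quick illustration of the new behaviour (a hedged sketch, assuming the deprecated transforms are exercised directly from the prototype namespace): both transforms now delegate to the stable rgb_to_grayscale kernel, and RandomGrayscale derives the output channel count from the input via query_chw, so a 3-channel input keeps 3 channels.

import torch
from torchvision.prototype import transforms

img = torch.rand(3, 32, 32)
gray = transforms.Grayscale(num_output_channels=3)(img)   # still emits the deprecation warning
maybe_gray = transforms.RandomGrayscale(p=1.0)(img)       # always applied here; channel count preserved
assert maybe_gray.shape == img.shape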

torchvision/prototype/transforms/_geometry.py

Lines changed: 6 additions & 20 deletions
@@ -156,22 +156,14 @@ class FiveCrop(Transform):
         torch.Size([5])
     """
 
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        # TODO: returning a list is technically BC breaking since FiveCrop returned a tuple before. We switched to a
-        # list here to align it with TenCrop.
-        if isinstance(inpt, features.Image):
-            output = F.five_crop_image_tensor(inpt, self.size)
-            return tuple(features.Image.new_like(inpt, o) for o in output)
-        elif is_simple_tensor(inpt):
-            return F.five_crop_image_tensor(inpt, self.size)
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.five_crop_image_pil(inpt, self.size)
-        else:
-            return inpt
+        return F.five_crop(inpt, self.size)
 
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
@@ -185,21 +177,15 @@ class TenCrop(Transform):
     See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
     """
 
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image):
-            output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
-            return [features.Image.new_like(inpt, o) for o in output]
-        elif is_simple_tensor(inpt):
-            return F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip)
-        else:
-            return inpt
+        return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
 
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
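
A short usage sketch of the simplified transforms (hedged; assumes direct use on a plain image tensor): the per-type branching now lives in the functional dispatchers, and the transforms only declare which types they handle via _transformed_types.

import torch
from torchvision.prototype import transforms

img = torch.rand(3, 256, 256)
five = transforms.FiveCrop(224)(img)                      # tuple of five 3x224x224 crops
ten = transforms.TenCrop(224, vertical_flip=False)(img)   # list of ten crops, flipped variants included
assert len(five) == 5 and len(ten) == 10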

torchvision/prototype/transforms/_misc.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ class Lambda(Transform):
     def __init__(self, fn: Callable[[Any], Any], *types: Type):
         super().__init__()
         self.fn = fn
-        self.types = types
+        self.types = types or (object,)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if type(inpt) in self.types:
@@ -137,7 +137,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 class ToDtype(Lambda):
     def __init__(self, dtype: torch.dtype, *types: Type) -> None:
         self.dtype = dtype
-        super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types)
+        super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))
 
     def extra_repr(self) -> str:
         return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -65,6 +65,7 @@
     elastic_image_tensor,
     elastic_segmentation_mask,
     elastic_transform,
+    five_crop,
     five_crop_image_pil,
     five_crop_image_tensor,
     horizontal_flip,
@@ -97,6 +98,7 @@
     rotate_image_pil,
     rotate_image_tensor,
     rotate_segmentation_mask,
+    ten_crop,
     ten_crop_image_pil,
     ten_crop_image_tensor,
     vertical_flip,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 21 additions & 0 deletions
@@ -1078,6 +1078,17 @@ def five_crop_image_pil(
     return tl, tr, bl, br, center
 
 
+def five_crop(inpt: DType, size: List[int]) -> Tuple[DType, DType, DType, DType, DType]:
+    # TODO: consider breaking BC here to return List[DType] to align this op with `ten_crop`
+    if isinstance(inpt, torch.Tensor):
+        output = five_crop_image_tensor(inpt, size)
+        if isinstance(inpt, features.Image):
+            output = tuple(features.Image.new_like(inpt, item) for item in output)  # type: ignore[assignment]
+        return output
+    else:  # isinstance(inpt, PIL.Image.Image):
+        return five_crop_image_pil(inpt, size)
+
+
 def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
     tl, tr, bl, br, center = five_crop_image_tensor(img, size)
 
@@ -1102,3 +1113,13 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo
     tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)
 
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+
+
+def ten_crop(inpt: DType, size: List[int], *, vertical_flip: bool = False) -> List[DType]:
+    if isinstance(inpt, torch.Tensor):
+        output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
+        if isinstance(inpt, features.Image):
+            output = [features.Image.new_like(inpt, item) for item in output]
+        return output
+    else:  # isinstance(inpt, PIL.Image.Image):
+        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
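
For reference, a minimal sketch of calling the new dispatchers directly (hedged; assumes a plain tensor input, while features.Image and PIL inputs are routed to the matching kernel as shown above):

import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 256, 256)
tl, tr, bl, br, center = F.five_crop(img, [224, 224])      # 5-tuple of 3x224x224 crops
crops = F.ten_crop(img, [224, 224], vertical_flip=True)    # list of 10 crops; the last 5 come from the flipped image
assert len(crops) == 10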
