Skip to content

Commit 443016d

Browse files
NicolasHug authored and facebook-github-bot committed
[fbsync] cleanup prototype auto augment transforms (#6463)
Summary: * cleanup prototype auto augment transforms * remove custom fill parsing from auto augment Reviewed By: datumbox Differential Revision: D39013684 fbshipit-source-id: d0fd5329d9a672024dc0659f1153e8833e35622c
1 parent cf1ba0e commit 443016d

File tree

1 file changed

+48
-91
lines changed

1 file changed

+48
-91
lines changed

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 48 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,21 @@
11
import math
2-
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
2+
import numbers
3+
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
34

45
import PIL.Image
56
import torch
67

7-
from torch.utils._pytree import tree_flatten, tree_unflatten
88
from torchvision.prototype import features
99
from torchvision.prototype.transforms import functional as F, Transform
1010
from torchvision.transforms.autoaugment import AutoAugmentPolicy
1111
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
1212

13-
from ._utils import get_chw, is_simple_tensor
13+
from ._utils import is_simple_tensor, query_chw
1414

1515
K = TypeVar("K")
1616
V = TypeVar("V")
1717

1818

19-
def _put_into_sample(sample: Any, id: int, item: Any) -> Any:
20-
sample_flat, spec = tree_flatten(sample)
21-
sample_flat[id] = item
22-
return tree_unflatten(sample_flat, spec)
23-
24-
2519
class _AutoAugmentBase(Transform):
2620
def __init__(
2721
self,
@@ -31,48 +25,19 @@ def __init__(
3125
) -> None:
3226
super().__init__()
3327
self.interpolation = interpolation
28+
29+
if not isinstance(fill, (numbers.Number, tuple, list)):
30+
raise TypeError("Got inappropriate fill arg")
3431
self.fill = fill
3532

3633
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
3734
keys = tuple(dct.keys())
3835
key = keys[int(torch.randint(len(keys), ()))]
3936
return key, dct[key]
4037

41-
def _extract_image(
42-
self,
43-
sample: Any,
44-
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
45-
) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
46-
sample_flat, _ = tree_flatten(sample)
47-
images = []
48-
for id, inpt in enumerate(sample_flat):
49-
if isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt):
50-
images.append((id, inpt))
51-
elif isinstance(inpt, unsupported_types):
52-
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
53-
54-
if not images:
55-
raise TypeError("Found no image in the sample.")
56-
if len(images) > 1:
57-
raise TypeError(
58-
f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
59-
)
60-
return images[0]
61-
62-
def _parse_fill(
63-
self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int
64-
) -> Union[int, float, Sequence[int], Sequence[float]]:
65-
fill = self.fill
66-
67-
if isinstance(image, PIL.Image.Image) or fill is None:
68-
return fill
69-
70-
if isinstance(fill, (int, float)):
71-
fill = [float(fill)] * num_channels
72-
else:
73-
fill = [float(f) for f in fill]
74-
75-
return fill
38+
def _get_params(self, sample: Any) -> Dict[str, Any]:
39+
_, height, width = query_chw(sample)
40+
return dict(height=height, width=width)
7641

7742
def _apply_image_transform(
7843
self,
@@ -277,34 +242,34 @@ def _get_policies(
277242
else:
278243
raise ValueError(f"The provided policy {policy} is not recognized.")
279244

280-
def forward(self, *inputs: Any) -> Any:
281-
sample = inputs if len(inputs) > 1 else inputs[0]
245+
def _get_params(self, sample: Any) -> Dict[str, Any]:
246+
params = super(AutoAugment, self)._get_params(sample)
247+
params["policy"] = self._policies[int(torch.randint(len(self._policies), ()))]
248+
return params
282249

283-
id, image = self._extract_image(sample)
284-
num_channels, height, width = get_chw(image)
285-
fill = self._parse_fill(image, num_channels)
250+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
251+
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
252+
return inpt
286253

287-
policy = self._policies[int(torch.randint(len(self._policies), ()))]
288-
289-
for transform_id, probability, magnitude_idx in policy:
254+
for transform_id, probability, magnitude_idx in params["policy"]:
290255
if not torch.rand(()) <= probability:
291256
continue
292257

293258
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
294259

295-
magnitudes = magnitudes_fn(10, height, width)
260+
magnitudes = magnitudes_fn(10, params["height"], params["width"])
296261
if magnitudes is not None:
297262
magnitude = float(magnitudes[magnitude_idx])
298263
if signed and torch.rand(()) <= 0.5:
299264
magnitude *= -1
300265
else:
301266
magnitude = 0.0
302267

303-
image = self._apply_image_transform(
304-
image, transform_id, magnitude, interpolation=self.interpolation, fill=fill
268+
inpt = self._apply_image_transform(
269+
inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
305270
)
306271

307-
return _put_into_sample(sample, id, image)
272+
return inpt
308273

309274

310275
class RandAugment(_AutoAugmentBase):
@@ -350,29 +315,26 @@ def __init__(
350315
self.magnitude = magnitude
351316
self.num_magnitude_bins = num_magnitude_bins
352317

353-
def forward(self, *inputs: Any) -> Any:
354-
sample = inputs if len(inputs) > 1 else inputs[0]
355-
356-
id, image = self._extract_image(sample)
357-
num_channels, height, width = get_chw(image)
358-
fill = self._parse_fill(image, num_channels)
318+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
319+
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
320+
return inpt
359321

360322
for _ in range(self.num_ops):
361323
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
362324

363-
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
325+
magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
364326
if magnitudes is not None:
365327
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
366328
if signed and torch.rand(()) <= 0.5:
367329
magnitude *= -1
368330
else:
369331
magnitude = 0.0
370332

371-
image = self._apply_image_transform(
372-
image, transform_id, magnitude, interpolation=self.interpolation, fill=fill
333+
inpt = self._apply_image_transform(
334+
inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
373335
)
374336

375-
return _put_into_sample(sample, id, image)
337+
return inpt
376338

377339

378340
class TrivialAugmentWide(_AutoAugmentBase):
@@ -408,25 +370,23 @@ def __init__(
408370
super().__init__(interpolation=interpolation, fill=fill)
409371
self.num_magnitude_bins = num_magnitude_bins
410372

411-
def forward(self, *inputs: Any) -> Any:
412-
sample = inputs if len(inputs) > 1 else inputs[0]
413-
414-
id, image = self._extract_image(sample)
415-
num_channels, height, width = get_chw(image)
416-
fill = self._parse_fill(image, num_channels)
373+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
374+
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
375+
return inpt
417376

418377
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
419378

420-
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
379+
magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
421380
if magnitudes is not None:
422381
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
423382
if signed and torch.rand(()) <= 0.5:
424383
magnitude *= -1
425384
else:
426385
magnitude = 0.0
427386

428-
image = self._apply_image_transform(image, transform_id, magnitude, interpolation=self.interpolation, fill=fill)
429-
return _put_into_sample(sample, id, image)
387+
return self._apply_image_transform(
388+
inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
389+
)
430390

431391

432392
class AugMix(_AutoAugmentBase):
@@ -478,16 +438,13 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
478438
# Must be on a separate method so that we can overwrite it in tests.
479439
return torch._sample_dirichlet(params)
480440

481-
def forward(self, *inputs: Any) -> Any:
482-
sample = inputs if len(inputs) > 1 else inputs[0]
483-
id, orig_image = self._extract_image(sample)
484-
num_channels, height, width = get_chw(orig_image)
485-
fill = self._parse_fill(orig_image, num_channels)
486-
487-
if isinstance(orig_image, torch.Tensor):
488-
image = orig_image
489-
else: # isinstance(inpt, PIL.Image.Image):
490-
image = pil_to_tensor(orig_image)
441+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
442+
if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
443+
image = inpt
444+
elif isinstance(inpt, PIL.Image.Image):
445+
image = pil_to_tensor(inpt)
446+
else:
447+
return inpt
491448

492449
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
493450

@@ -513,7 +470,7 @@ def forward(self, *inputs: Any) -> Any:
513470
for _ in range(depth):
514471
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
515472

516-
magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
473+
magnitudes = magnitudes_fn(self._PARAMETER_MAX, params["height"], params["width"])
517474
if magnitudes is not None:
518475
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
519476
if signed and torch.rand(()) <= 0.5:
@@ -522,14 +479,14 @@ def forward(self, *inputs: Any) -> Any:
522479
magnitude = 0.0
523480

524481
aug = self._apply_image_transform(
525-
aug, transform_id, magnitude, interpolation=self.interpolation, fill=fill
482+
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
526483
)
527484
mix.add_(combined_weights[:, i].view(batch_dims) * aug)
528485
mix = mix.view(orig_dims).to(dtype=image.dtype)
529486

530-
if isinstance(orig_image, features.Image):
531-
mix = features.Image.new_like(orig_image, mix)
532-
elif isinstance(orig_image, PIL.Image.Image):
487+
if isinstance(inpt, features.Image):
488+
mix = features.Image.new_like(inpt, mix)
489+
elif isinstance(inpt, PIL.Image.Image):
533490
mix = to_pil_image(mix)
534491

535-
return _put_into_sample(sample, id, mix)
492+
return mix

0 commit comments

Comments (0)