1
1
import math
2
2
import numbers
3
- from typing import Any , Callable , cast , Dict , List , Optional , Sequence , Tuple , TypeVar , Union
3
+ from typing import Any , Callable , cast , Dict , List , Optional , Sequence , Tuple , Type , TypeVar , Union
4
4
5
5
import PIL .Image
6
6
import torch
7
7
8
+ from torch .utils ._pytree import tree_flatten , tree_unflatten
8
9
from torchvision .prototype import features
9
10
from torchvision .prototype .transforms import functional as F , Transform
10
11
from torchvision .transforms .autoaugment import AutoAugmentPolicy
11
12
from torchvision .transforms .functional import InterpolationMode , pil_to_tensor , to_pil_image
12
13
13
- from ._utils import is_simple_tensor , query_chw
14
+ from ._utils import _isinstance , get_chw , is_simple_tensor
14
15
15
16
K = TypeVar ("K" )
16
17
V = TypeVar ("V" )
@@ -35,9 +36,31 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
35
36
key = keys [int (torch .randint (len (keys ), ()))]
36
37
return key , dct [key ]
37
38
38
- def _get_params (self , sample : Any ) -> Dict [str , Any ]:
39
- _ , height , width = query_chw (sample )
40
- return dict (height = height , width = width )
39
def _extract_image(
    self,
    sample: Any,
    unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
    """Find the single image inside ``sample``.

    ``sample`` is flattened with ``tree_flatten`` and each leaf is checked in order.

    Args:
        sample: The (possibly nested) input structure to search.
        unsupported_types: Leaf types that are rejected outright.

    Returns:
        A ``(index, image)`` tuple, where ``index`` is the position of the image in the
        flattened sample and can be handed back to ``_put_into_sample``.

    Raises:
        TypeError: If a leaf is an instance of ``unsupported_types``, if no image is found,
            or if more than one image is found.
    """
    sample_flat, _ = tree_flatten(sample)

    images = []
    # Named ``index`` rather than ``id`` so the builtin ``id`` is not shadowed.
    for index, inpt in enumerate(sample_flat):
        if _isinstance(inpt, (features.Image, PIL.Image.Image, is_simple_tensor)):
            images.append((index, inpt))
        elif isinstance(inpt, unsupported_types):
            raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

    if not images:
        raise TypeError("Found no image in the sample.")
    if len(images) > 1:
        raise TypeError(
            f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
        )
    return images[0]
59
+
60
+ def _put_into_sample (self , sample : Any , id : int , item : Any ) -> Any :
61
+ sample_flat , spec = tree_flatten (sample )
62
+ sample_flat [id ] = item
63
+ return tree_unflatten (sample_flat , spec )
41
64
42
65
def _apply_image_transform (
43
66
self ,
@@ -242,34 +265,33 @@ def _get_policies(
242
265
else :
243
266
raise ValueError (f"The provided policy { policy } is not recognized." )
244
267
245
- def _get_params (self , sample : Any ) -> Dict [str , Any ]:
246
- params = super (AutoAugment , self )._get_params (sample )
247
- params ["policy" ] = self ._policies [int (torch .randint (len (self ._policies ), ()))]
248
- return params
268
def forward(self, *inputs: Any) -> Any:
    """Apply one randomly selected policy to the single image in the inputs.

    Args:
        *inputs: Either one (possibly nested) sample, or several positional inputs that
            are treated together as one tuple sample.

    Returns:
        The sample with its image replaced by the augmented image; all other entries are
        passed through ``_put_into_sample`` unchanged.

    Raises:
        TypeError: Propagated from ``_extract_image`` when the sample does not contain
            exactly one supported image.
    """
    sample = inputs if len(inputs) > 1 else inputs[0]

    image_index, image = self._extract_image(sample)
    # Only the spatial size is needed; the channel count is discarded.
    _, height, width = get_chw(image)

    policy = self._policies[int(torch.randint(len(self._policies), ()))]

    for transform_id, probability, magnitude_idx in policy:
        # Each step of the policy is applied only with its own probability.
        if not torch.rand(()) <= probability:
            continue

        magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

        magnitudes = magnitudes_fn(10, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[magnitude_idx])
            # Signed ops flip direction half of the time.
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        image = self._apply_image_transform(
            image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
        )

    return self._put_into_sample(sample, image_index, image)
273
295
274
296
275
297
class RandAugment (_AutoAugmentBase ):
@@ -315,26 +337,28 @@ def __init__(
315
337
self .magnitude = magnitude
316
338
self .num_magnitude_bins = num_magnitude_bins
317
339
318
- def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
319
- if not (isinstance (inpt , (features .Image , PIL .Image .Image )) or is_simple_tensor (inpt )):
320
- return inpt
340
def forward(self, *inputs: Any) -> Any:
    """Apply ``self.num_ops`` randomly drawn augmentations to the single image in the inputs.

    Args:
        *inputs: Either one (possibly nested) sample, or several positional inputs that
            are treated together as one tuple sample.

    Returns:
        The sample with its image replaced by the augmented image.

    Raises:
        TypeError: Propagated from ``_extract_image`` when the sample does not contain
            exactly one supported image.
    """
    sample = inputs if len(inputs) > 1 else inputs[0]

    image_index, image = self._extract_image(sample)
    # Only the spatial size is needed; the channel count is discarded.
    _, height, width = get_chw(image)

    for _ in range(self.num_ops):
        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            # Signed ops flip direction half of the time.
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        image = self._apply_image_transform(
            image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
        )

    return self._put_into_sample(sample, image_index, image)
338
362
339
363
340
364
class TrivialAugmentWide (_AutoAugmentBase ):
@@ -370,23 +394,26 @@ def __init__(
370
394
super ().__init__ (interpolation = interpolation , fill = fill )
371
395
self .num_magnitude_bins = num_magnitude_bins
372
396
373
- def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
374
- if not (isinstance (inpt , (features .Image , PIL .Image .Image )) or is_simple_tensor (inpt )):
375
- return inpt
397
def forward(self, *inputs: Any) -> Any:
    """Apply a single randomly drawn augmentation to the single image in the inputs.

    Args:
        *inputs: Either one (possibly nested) sample, or several positional inputs that
            are treated together as one tuple sample.

    Returns:
        The sample with its image replaced by the augmented image.

    Raises:
        TypeError: Propagated from ``_extract_image`` when the sample does not contain
            exactly one supported image.
    """
    sample = inputs if len(inputs) > 1 else inputs[0]

    image_index, image = self._extract_image(sample)
    # Only the spatial size is needed; the channel count is discarded.
    _, height, width = get_chw(image)

    transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

    magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
    if magnitudes is not None:
        magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
        # Signed ops flip direction half of the time.
        if signed and torch.rand(()) <= 0.5:
            magnitude *= -1
    else:
        magnitude = 0.0

    image = self._apply_image_transform(
        image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
    )
    return self._put_into_sample(sample, image_index, image)
390
417
391
418
392
419
class AugMix (_AutoAugmentBase ):
@@ -438,13 +465,15 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
438
465
# Must be on a separate method so that we can overwrite it in tests.
439
466
return torch ._sample_dirichlet (params )
440
467
441
- def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
442
- if isinstance (inpt , features .Image ) or is_simple_tensor (inpt ):
443
- image = inpt
444
- elif isinstance (inpt , PIL .Image .Image ):
445
- image = pil_to_tensor (inpt )
446
- else :
447
- return inpt
468
+ def forward (self , * inputs : Any ) -> Any :
469
+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
470
+ id , orig_image = self ._extract_image (sample )
471
+ num_channels , height , width = get_chw (orig_image )
472
+
473
+ if isinstance (orig_image , torch .Tensor ):
474
+ image = orig_image
475
+ else : # isinstance(inpt, PIL.Image.Image):
476
+ image = pil_to_tensor (orig_image )
448
477
449
478
augmentation_space = self ._AUGMENTATION_SPACE if self .all_ops else self ._PARTIAL_AUGMENTATION_SPACE
450
479
@@ -470,7 +499,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
470
499
for _ in range (depth ):
471
500
transform_id , (magnitudes_fn , signed ) = self ._get_random_item (augmentation_space )
472
501
473
- magnitudes = magnitudes_fn (self ._PARAMETER_MAX , params [ " height" ], params [ " width" ] )
502
+ magnitudes = magnitudes_fn (self ._PARAMETER_MAX , height , width )
474
503
if magnitudes is not None :
475
504
magnitude = float (magnitudes [int (torch .randint (self .severity , ()))])
476
505
if signed and torch .rand (()) <= 0.5 :
@@ -484,9 +513,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
484
513
mix .add_ (combined_weights [:, i ].view (batch_dims ) * aug )
485
514
mix = mix .view (orig_dims ).to (dtype = image .dtype )
486
515
487
- if isinstance (inpt , features .Image ):
488
- mix = features .Image .new_like (inpt , mix )
489
- elif isinstance (inpt , PIL .Image .Image ):
516
+ if isinstance (orig_image , features .Image ):
517
+ mix = features .Image .new_like (orig_image , mix )
518
+ elif isinstance (orig_image , PIL .Image .Image ):
490
519
mix = to_pil_image (mix )
491
520
492
- return mix
521
+ return self . _put_into_sample ( sample , id , mix )
0 commit comments