1
1
import math
2
- from typing import Any , Callable , cast , Dict , List , Optional , Sequence , Tuple , Type , TypeVar , Union
2
+ import numbers
3
+ from typing import Any , Callable , cast , Dict , List , Optional , Sequence , Tuple , TypeVar , Union
3
4
4
5
import PIL .Image
5
6
import torch
6
7
7
- from torch .utils ._pytree import tree_flatten , tree_unflatten
8
8
from torchvision .prototype import features
9
9
from torchvision .prototype .transforms import functional as F , Transform
10
10
from torchvision .transforms .autoaugment import AutoAugmentPolicy
11
11
from torchvision .transforms .functional import InterpolationMode , pil_to_tensor , to_pil_image
12
12
13
- from ._utils import get_chw , is_simple_tensor
13
+ from ._utils import is_simple_tensor , query_chw
14
14
15
15
K = TypeVar ("K" )
16
16
V = TypeVar ("V" )
17
17
18
18
19
- def _put_into_sample (sample : Any , id : int , item : Any ) -> Any :
20
- sample_flat , spec = tree_flatten (sample )
21
- sample_flat [id ] = item
22
- return tree_unflatten (sample_flat , spec )
23
-
24
-
25
19
class _AutoAugmentBase (Transform ):
26
20
def __init__ (
27
21
self ,
@@ -31,48 +25,19 @@ def __init__(
31
25
) -> None :
32
26
super ().__init__ ()
33
27
self .interpolation = interpolation
28
+
29
+ if not isinstance (fill , (numbers .Number , tuple , list )):
30
+ raise TypeError ("Got inappropriate fill arg" )
34
31
self .fill = fill
35
32
36
33
def _get_random_item (self , dct : Dict [K , V ]) -> Tuple [K , V ]:
37
34
keys = tuple (dct .keys ())
38
35
key = keys [int (torch .randint (len (keys ), ()))]
39
36
return key , dct [key ]
40
37
41
- def _extract_image (
42
- self ,
43
- sample : Any ,
44
- unsupported_types : Tuple [Type , ...] = (features .BoundingBox , features .SegmentationMask ),
45
- ) -> Tuple [int , Union [PIL .Image .Image , torch .Tensor , features .Image ]]:
46
- sample_flat , _ = tree_flatten (sample )
47
- images = []
48
- for id , inpt in enumerate (sample_flat ):
49
- if isinstance (inpt , (features .Image , PIL .Image .Image )) or is_simple_tensor (inpt ):
50
- images .append ((id , inpt ))
51
- elif isinstance (inpt , unsupported_types ):
52
- raise TypeError (f"Inputs of type { type (inpt ).__name__ } are not supported by { type (self ).__name__ } ()" )
53
-
54
- if not images :
55
- raise TypeError ("Found no image in the sample." )
56
- if len (images ) > 1 :
57
- raise TypeError (
58
- f"Auto augment transformations are only properly defined for a single image, but found { len (images )} ."
59
- )
60
- return images [0 ]
61
-
62
- def _parse_fill (
63
- self , image : Union [PIL .Image .Image , torch .Tensor , features .Image ], num_channels : int
64
- ) -> Union [int , float , Sequence [int ], Sequence [float ]]:
65
- fill = self .fill
66
-
67
- if isinstance (image , PIL .Image .Image ) or fill is None :
68
- return fill
69
-
70
- if isinstance (fill , (int , float )):
71
- fill = [float (fill )] * num_channels
72
- else :
73
- fill = [float (f ) for f in fill ]
74
-
75
- return fill
38
def _get_params(self, sample: Any) -> Dict[str, Any]:
    """Collect the spatial size of ``sample`` as transform parameters.

    ``query_chw`` is unpacked into three values; the first one is
    discarded and the other two are exposed under ``"height"`` and
    ``"width"``.

    Args:
        sample: The sample handed to ``query_chw``.

    Returns:
        A dict with the keys ``"height"`` and ``"width"``.
    """
    _channels, img_height, img_width = query_chw(sample)
    return {"height": img_height, "width": img_width}
76
41
77
42
def _apply_image_transform (
78
43
self ,
@@ -277,34 +242,34 @@ def _get_policies(
277
242
else :
278
243
raise ValueError (f"The provided policy { policy } is not recognized." )
279
244
280
- def forward (self , * inputs : Any ) -> Any :
281
- sample = inputs if len (inputs ) > 1 else inputs [0 ]
245
def _get_params(self, sample: Any) -> Dict[str, Any]:
    """Extend the inherited parameters with one randomly chosen policy.

    Args:
        sample: The sample forwarded to the parent ``_get_params``.

    Returns:
        The parent's parameter dict with an extra ``"policy"`` entry holding
        one element of ``self._policies``, selected via ``torch.randint``.
    """
    # Zero-argument super(), consistent with the other super() calls in this file.
    params = super()._get_params(sample)
    policy_index = int(torch.randint(len(self._policies), ()))
    params["policy"] = self._policies[policy_index]
    return params
282
249
283
- id , image = self . _extract_image ( sample )
284
- num_channels , height , width = get_chw ( image )
285
- fill = self . _parse_fill ( image , num_channels )
250
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    """Apply the sampled policy to a single image-like input.

    Inputs that are neither a ``features.Image``, a ``PIL.Image.Image`` nor
    accepted by ``is_simple_tensor`` are handed back unchanged.

    Args:
        inpt: One element of the sample.
        params: Parameters produced by ``_get_params``; the entries
            ``"policy"``, ``"height"`` and ``"width"`` are read.

    Returns:
        The input after every selected sub-transform of the policy has been
        applied, or ``inpt`` itself when it is not image-like.
    """
    is_image = isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)
    if not is_image:
        return inpt

    # The keys must match the ones built in _get_params exactly
    # ("height" / "width", without any padding whitespace).
    height, width = params["height"], params["width"]

    for transform_id, probability, magnitude_idx in params["policy"]:
        # Each sub-transform only fires when the random draw falls within its probability.
        if not torch.rand(()) <= probability:
            continue

        magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

        # NOTE(review): 10 is presumably the number of magnitude bins - confirm
        # against the definitions in _AUGMENTATION_SPACE.
        magnitudes = magnitudes_fn(10, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[magnitude_idx])
            # Signed ops get their direction flipped on roughly half of the draws.
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        inpt = self._apply_image_transform(
            inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
        )

    return inpt
308
273
309
274
310
275
class RandAugment (_AutoAugmentBase ):
@@ -350,29 +315,26 @@ def __init__(
350
315
self .magnitude = magnitude
351
316
self .num_magnitude_bins = num_magnitude_bins
352
317
353
- def forward (self , * inputs : Any ) -> Any :
354
- sample = inputs if len (inputs ) > 1 else inputs [0 ]
355
-
356
- id , image = self ._extract_image (sample )
357
- num_channels , height , width = get_chw (image )
358
- fill = self ._parse_fill (image , num_channels )
318
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    """Apply ``self.num_ops`` randomly chosen operations to an image-like input.

    Inputs that are neither a ``features.Image``, a ``PIL.Image.Image`` nor
    accepted by ``is_simple_tensor`` are handed back unchanged.

    Args:
        inpt: One element of the sample.
        params: Parameters produced by ``_get_params``; the entries
            ``"height"`` and ``"width"`` are read.

    Returns:
        The input after all operations have been applied, or ``inpt`` itself
        when it is not image-like.
    """
    is_image = isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)
    if not is_image:
        return inpt

    # The keys must match the ones built in _get_params exactly
    # ("height" / "width", without any padding whitespace).
    height, width = params["height"], params["width"]

    for _ in range(self.num_ops):
        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            # Index with the magnitude configured in __init__ rather than a
            # random bin; otherwise self.magnitude would never be consulted.
            magnitude = float(magnitudes[self.magnitude])
            # Signed ops get their direction flipped on roughly half of the draws.
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        inpt = self._apply_image_transform(
            inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
        )

    return inpt
376
338
377
339
378
340
class TrivialAugmentWide (_AutoAugmentBase ):
@@ -408,25 +370,23 @@ def __init__(
408
370
super ().__init__ (interpolation = interpolation , fill = fill )
409
371
self .num_magnitude_bins = num_magnitude_bins
410
372
411
- def forward (self , * inputs : Any ) -> Any :
412
- sample = inputs if len (inputs ) > 1 else inputs [0 ]
413
-
414
- id , image = self ._extract_image (sample )
415
- num_channels , height , width = get_chw (image )
416
- fill = self ._parse_fill (image , num_channels )
373
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    """Apply one randomly chosen operation with a random magnitude.

    Inputs that are neither a ``features.Image``, a ``PIL.Image.Image`` nor
    accepted by ``is_simple_tensor`` are handed back unchanged.

    Args:
        inpt: One element of the sample.
        params: Parameters produced by ``_get_params``; the entries
            ``"height"`` and ``"width"`` are read.

    Returns:
        The result of ``_apply_image_transform`` for the chosen operation,
        or ``inpt`` itself when it is not image-like.
    """
    is_image = isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)
    if not is_image:
        return inpt

    transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

    # The keys must match the ones built in _get_params exactly
    # ("height" / "width", without any padding whitespace).
    magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
    if magnitudes is not None:
        # The magnitude bin is drawn at random from all available bins.
        bin_index = int(torch.randint(self.num_magnitude_bins, ()))
        magnitude = float(magnitudes[bin_index])
        # Signed ops get their direction flipped on roughly half of the draws.
        if signed and torch.rand(()) <= 0.5:
            magnitude *= -1
    else:
        magnitude = 0.0

    return self._apply_image_transform(
        inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
    )
430
390
431
391
432
392
class AugMix (_AutoAugmentBase ):
@@ -478,16 +438,13 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
478
438
# Must be on a separate method so that we can overwrite it in tests.
479
439
return torch ._sample_dirichlet (params )
480
440
481
- def forward (self , * inputs : Any ) -> Any :
482
- sample = inputs if len (inputs ) > 1 else inputs [0 ]
483
- id , orig_image = self ._extract_image (sample )
484
- num_channels , height , width = get_chw (orig_image )
485
- fill = self ._parse_fill (orig_image , num_channels )
486
-
487
- if isinstance (orig_image , torch .Tensor ):
488
- image = orig_image
489
- else : # isinstance(inpt, PIL.Image.Image):
490
- image = pil_to_tensor (orig_image )
441
+ def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
442
+ if isinstance (inpt , features .Image ) or is_simple_tensor (inpt ):
443
+ image = inpt
444
+ elif isinstance (inpt , PIL .Image .Image ):
445
+ image = pil_to_tensor (inpt )
446
+ else :
447
+ return inpt
491
448
492
449
augmentation_space = self ._AUGMENTATION_SPACE if self .all_ops else self ._PARTIAL_AUGMENTATION_SPACE
493
450
@@ -513,7 +470,7 @@ def forward(self, *inputs: Any) -> Any:
513
470
for _ in range (depth ):
514
471
transform_id , (magnitudes_fn , signed ) = self ._get_random_item (augmentation_space )
515
472
516
- magnitudes = magnitudes_fn (self ._PARAMETER_MAX , height , width )
473
+ magnitudes = magnitudes_fn (self ._PARAMETER_MAX , params [ " height" ], params [ " width" ] )
517
474
if magnitudes is not None :
518
475
magnitude = float (magnitudes [int (torch .randint (self .severity , ()))])
519
476
if signed and torch .rand (()) <= 0.5 :
@@ -522,14 +479,14 @@ def forward(self, *inputs: Any) -> Any:
522
479
magnitude = 0.0
523
480
524
481
aug = self ._apply_image_transform (
525
- aug , transform_id , magnitude , interpolation = self .interpolation , fill = fill
482
+ aug , transform_id , magnitude , interpolation = self .interpolation , fill = self . fill
526
483
)
527
484
mix .add_ (combined_weights [:, i ].view (batch_dims ) * aug )
528
485
mix = mix .view (orig_dims ).to (dtype = image .dtype )
529
486
530
- if isinstance (orig_image , features .Image ):
531
- mix = features .Image .new_like (orig_image , mix )
532
- elif isinstance (orig_image , PIL .Image .Image ):
487
+ if isinstance (inpt , features .Image ):
488
+ mix = features .Image .new_like (inpt , mix )
489
+ elif isinstance (inpt , PIL .Image .Image ):
533
490
mix = to_pil_image (mix )
534
491
535
- return _put_into_sample ( sample , id , mix )
492
+ return mix
0 commit comments