3
3
4
4
import PIL .Image
5
5
import torch
6
+
7
+ from torch .utils ._pytree import tree_flatten , tree_unflatten
6
8
from torchvision .prototype import features
7
9
from torchvision .prototype .transforms import functional as F , Transform
8
- from torchvision .prototype .utils ._internal import query_recursively
9
10
from torchvision .transforms .autoaugment import AutoAugmentPolicy
10
11
from torchvision .transforms .functional import InterpolationMode , pil_to_tensor , to_pil_image
11
12
12
- from ._utils import get_image_dimensions
13
+ from ._utils import get_image_dimensions , is_simple_tensor
13
14
14
15
K = TypeVar ("K" )
15
16
V = TypeVar ("V" )
16
17
17
18
18
- def _put_into_sample (sample : Any , id : Tuple [Any , ...], item : Any ) -> Any :
19
- if not id :
20
- return item
21
-
22
- parent = sample
23
- for key in id [:- 1 ]:
24
- parent = parent [key ]
25
-
26
- parent [id [- 1 ]] = item
27
- return sample
19
+ def _put_into_sample (sample : Any , id : int , item : Any ) -> Any :
20
+ sample_flat , spec = tree_flatten (sample )
21
+ sample_flat [id ] = item
22
+ return tree_unflatten (sample_flat , spec )
28
23
29
24
30
25
class _AutoAugmentBase (Transform ):
@@ -47,18 +42,15 @@ def _extract_image(
47
42
self ,
48
43
sample : Any ,
49
44
unsupported_types : Tuple [Type , ...] = (features .BoundingBox , features .SegmentationMask ),
50
- ) -> Tuple [Tuple [ Any , ...] , Union [PIL .Image .Image , torch .Tensor , features .Image ]]:
51
- def fn (
52
- id : Tuple [ Any , ...], inpt : Any
53
- ) -> Optional [ Tuple [ Tuple [ Any , ...], Union [ PIL . Image . Image , torch . Tensor , features . Image ]]] :
54
- if type (inpt ) in { torch . Tensor , features .Image } or isinstance ( inpt , PIL .Image .Image ):
55
- return id , inpt
45
+ ) -> Tuple [int , Union [PIL .Image .Image , torch .Tensor , features .Image ]]:
46
+ sample_flat , _ = tree_flatten ( sample )
47
+ images = []
48
+ for id , inpt in enumerate ( sample_flat ) :
49
+ if isinstance (inpt , ( features .Image , PIL .Image .Image )) or is_simple_tensor ( inpt ):
50
+ images . append (( id , inpt ))
56
51
elif isinstance (inpt , unsupported_types ):
57
52
raise TypeError (f"Inputs of type { type (inpt ).__name__ } are not supported by { type (self ).__name__ } ()" )
58
- else :
59
- return None
60
53
61
- images = list (query_recursively (fn , sample ))
62
54
if not images :
63
55
raise TypeError ("Found no image in the sample." )
64
56
if len (images ) > 1 :
0 commit comments