Skip to content

Commit 2a0eea8

Browse files
pmeiervfdev-5
andauthored
replace query_recursively with pytree implementation (#6434)
* replace query_recursively with pytree implementation * simplify Co-authored-by: vfdev <[email protected]>
1 parent 961d97b commit 2a0eea8

File tree

2 files changed

+14
-40
lines changed

2 files changed

+14
-40
lines changed

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,23 @@
33

44
import PIL.Image
55
import torch
6+
7+
from torch.utils._pytree import tree_flatten, tree_unflatten
68
from torchvision.prototype import features
79
from torchvision.prototype.transforms import functional as F, Transform
8-
from torchvision.prototype.utils._internal import query_recursively
910
from torchvision.transforms.autoaugment import AutoAugmentPolicy
1011
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
1112

12-
from ._utils import get_image_dimensions
13+
from ._utils import get_image_dimensions, is_simple_tensor
1314

1415
K = TypeVar("K")
1516
V = TypeVar("V")
1617

1718

18-
def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any:
19-
if not id:
20-
return item
21-
22-
parent = sample
23-
for key in id[:-1]:
24-
parent = parent[key]
25-
26-
parent[id[-1]] = item
27-
return sample
19+
def _put_into_sample(sample: Any, id: int, item: Any) -> Any:
20+
sample_flat, spec = tree_flatten(sample)
21+
sample_flat[id] = item
22+
return tree_unflatten(sample_flat, spec)
2823

2924

3025
class _AutoAugmentBase(Transform):
@@ -47,18 +42,15 @@ def _extract_image(
4742
self,
4843
sample: Any,
4944
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
50-
) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
51-
def fn(
52-
id: Tuple[Any, ...], inpt: Any
53-
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
54-
if type(inpt) in {torch.Tensor, features.Image} or isinstance(inpt, PIL.Image.Image):
55-
return id, inpt
45+
) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
46+
sample_flat, _ = tree_flatten(sample)
47+
images = []
48+
for id, inpt in enumerate(sample_flat):
49+
if isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt):
50+
images.append((id, inpt))
5651
elif isinstance(inpt, unsupported_types):
5752
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
58-
else:
59-
return None
6053

61-
images = list(query_recursively(fn, sample))
6254
if not images:
6355
raise TypeError("Found no image in the sample.")
6456
if len(images) > 1:

torchvision/prototype/utils/_internal.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import io
44
import mmap
55
import platform
6-
from typing import Any, BinaryIO, Callable, Collection, Iterator, Optional, Sequence, Tuple, TypeVar, Union
6+
from typing import BinaryIO, Callable, Collection, Sequence, TypeVar, Union
77

88
import numpy as np
99
import torch
@@ -14,7 +14,6 @@
1414
"add_suggestion",
1515
"fromfile",
1616
"ReadOnlyTensorBuffer",
17-
"query_recursively",
1817
]
1918

2019

@@ -125,20 +124,3 @@ def read(self, size: int = -1) -> bytes:
125124
cursor = self.tell()
126125
offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
127126
return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
128-
129-
130-
def query_recursively(
131-
fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
132-
) -> Iterator[D]:
133-
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
134-
# "a" == "a"[0][0]...
135-
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
136-
for idx, item in enumerate(obj):
137-
yield from query_recursively(fn, item, id=(*id, idx))
138-
elif isinstance(obj, collections.abc.Mapping):
139-
for key, item in obj.items():
140-
yield from query_recursively(fn, item, id=(*id, key))
141-
else:
142-
result = fn(id, obj)
143-
if result is not None:
144-
yield result

0 commit comments

Comments
 (0)