
Commit d226e16 (parent 72b5fd1)

Updates according to the review

5 files changed: +53 −58 lines

torchvision/prototype/transforms/_geometry.py

Lines changed: 33 additions & 29 deletions
@@ -233,6 +233,8 @@ def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
         raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
 
 
+# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
+# https://github.com/pytorch/vision/issues/6250
 def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
     if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
         raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
@@ -437,18 +439,18 @@ def __init__(
 
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
-        if padding is not None:
-            _check_padding_arg(padding)
-
-        if (padding is not None) or pad_if_needed:
-            _check_padding_mode_arg(padding_mode)
-            _check_fill_arg(fill)
-
         self.padding = padding
         self.pad_if_needed = pad_if_needed
         self.fill = fill
         self.padding_mode = padding_mode
 
+        self._pad_op = None
+        if self.padding is not None:
+            self._pad_op = Pad(self.padding, fill=self.fill, padding_mode=self.padding_mode)
+
+        if self.pad_if_needed:
+            self._pad_op = Pad(0, fill=self.fill, padding_mode=self.padding_mode)
+
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
         _, height, width = get_image_dimensions(image)
@@ -466,34 +468,36 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         left = torch.randint(0, width - output_width + 1, size=(1,)).item()
         return dict(top=top, left=left, height=output_height, width=output_width)
 
-    def _forward(self, flat_inputs: List[Any]) -> List[Any]:
-        if self.padding is not None:
-            flat_inputs = [F.pad(flat_input, self.padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
-
-        image = query_image(flat_inputs)
-        _, height, width = get_image_dimensions(image)
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.crop(inpt, **params)
 
-        # pad the width if needed
-        if self.pad_if_needed and width < self.size[1]:
-            padding = [self.size[1] - width, 0]
-            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
-        # pad the height if needed
-        if self.pad_if_needed and height < self.size[0]:
-            padding = [0, self.size[0] - height]
-            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
 
-        params = self._get_params(flat_inputs)
+        if self._pad_op is not None:
+            sample = self._pad_op(sample)
 
-        return [F.crop(flat_input, **params) for flat_input in flat_inputs]
+        image = query_image(sample)
+        _, height, width = get_image_dimensions(image)
 
-    def forward(self, *inputs: Any) -> Any:
-        from torch.utils._pytree import tree_flatten, tree_unflatten
+        if self.pad_if_needed:
+            # This check is to explicitly ensure that self._pad_op is defined
+            if self._pad_op is None:
+                raise RuntimeError(
+                    "Internal error, self._pad_op is None. "
+                    "Please, fill an issue about that on https://github.com/pytorch/vision/issues"
+                )
 
-        sample = inputs if len(inputs) > 1 else inputs[0]
+            # pad the width if needed
+            if width < self.size[1]:
+                self._pad_op.padding = [self.size[1] - width, 0]
+                sample = self._pad_op(sample)
+            # pad the height if needed
+            if height < self.size[0]:
+                self._pad_op.padding = [0, self.size[0] - height]
+                sample = self._pad_op(sample)
 
-        flat_inputs, spec = tree_flatten(sample)
-        out_flat_inputs = self._forward(flat_inputs)
-        return tree_unflatten(out_flat_inputs, spec)
+        return super().forward(sample)
 
 
 class RandomPerspective(_RandomApplyTransform):
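
Note on the RandomCrop change above: the per-input F.pad calls are replaced by a pre-built Pad op whose padding is adjusted when pad_if_needed applies, and the crop itself moves into _transform so the base class handles flattening. Below is a minimal sketch of only the pad-if-needed arithmetic, written against the stable torchvision.transforms.functional API; the image size and the fixed crop origin are made up for illustration and are not part of the commit.

import torch
from torchvision.transforms import functional as F

size = (32, 32)                 # target (height, width) of the crop
img = torch.rand(3, 20, 28)     # image smaller than the target in both dimensions

_, height, width = img.shape
if width < size[1]:             # pad left/right first, mirroring the diff
    img = F.pad(img, [size[1] - width, 0])
if height < size[0]:            # then pad top/bottom
    img = F.pad(img, [0, size[0] - height])

top, left = 0, 0                # RandomCrop._get_params would sample these at random
out = F.crop(img, top, left, size[0], size[1])
assert tuple(out.shape[-2:]) == size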

torchvision/prototype/transforms/_transform.py

Lines changed: 13 additions & 4 deletions
@@ -1,10 +1,11 @@
 import enum
-import functools
 from typing import Any, Dict
 
+import PIL.Image
 import torch
 from torch import nn
-from torchvision.prototype.utils._internal import apply_recursively
+from torch.utils._pytree import tree_flatten, tree_unflatten
+from torchvision.prototype.features import _Feature
 from torchvision.utils import _log_api_usage_once
 
 
@@ -16,12 +17,20 @@ def __init__(self) -> None:
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict()
 
-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         raise NotImplementedError
 
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
-        return apply_recursively(functools.partial(self._transform, params=self._get_params(sample)), sample)
+
+        params = self._get_params(sample)
+
+        flat_inputs, spec = tree_flatten(sample)
+        transformed_types = (torch.Tensor, _Feature, PIL.Image.Image)
+        flat_outputs = [
+            self._transform(inpt, params) if isinstance(inpt, transformed_types) else inpt for inpt in flat_inputs
+        ]
+        return tree_unflatten(flat_outputs, spec)
 
     def extra_repr(self) -> str:
         extra = []
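
The rewritten Transform.forward above relies on torch.utils._pytree to flatten an arbitrarily nested sample, transform only tensor/feature/PIL leaves, and restore the original structure. A minimal sketch of that flatten/transform/unflatten pattern, with a made-up sample and a dummy stand-in for the real transform:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

sample = {"image": torch.rand(3, 4, 4), "label": 7, "ids": ["a", "b"]}
flat, spec = tree_flatten(sample)

# Only tensor leaves are touched here; everything else passes through unchanged.
flat = [x.flip(-1) if isinstance(x, torch.Tensor) else x for x in flat]

out = tree_unflatten(flat, spec)
assert out["label"] == 7 and out["ids"] == ["a", "b"]
assert out["image"].shape == (3, 4, 4)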

torchvision/prototype/transforms/_utils.py

Lines changed: 5 additions & 24 deletions
@@ -1,10 +1,9 @@
-from typing import Any, Iterator, Optional, Tuple, Type, Union
+from typing import Any, Tuple, Type, Union
 
 import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten
 from torchvision.prototype import features
-from torchvision.prototype.utils._internal import query_recursively
 
 from .functional._meta import get_dimensions_image_pil, get_dimensions_image_tensor
 
@@ -18,22 +17,6 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Im
     raise TypeError("No image was found in the sample")
 
 
-# vfdev-5: let's use tree_flatten instead of query_recursively and internal fn to make the code simplier
-def query_image_(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
-    def fn(
-        id: Tuple[Any, ...], input: Any
-    ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
-        if type(input) == torch.Tensor or isinstance(input, (PIL.Image.Image, features.Image)):
-            return id, input
-
-        return None
-
-    try:
-        return next(query_recursively(fn, sample))[1]
-    except StopIteration:
-        raise TypeError("No image was found in the sample")
-
-
 def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
     if isinstance(image, features.Image):
         channels = image.num_channels
@@ -47,16 +30,14 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
     return channels, height, width
 
 
-def _extract_types(sample: Any) -> Iterator[Type]:
-    return query_recursively(lambda id, input: type(input), sample)
-
-
 def has_any(sample: Any, *types: Type) -> bool:
-    return any(issubclass(type, types) for type in _extract_types(sample))
+    flat_sample, _ = tree_flatten(sample)
+    return any(issubclass(type(obj), types) for obj in flat_sample)
 
 
 def has_all(sample: Any, *types: Type) -> bool:
-    return not bool(set(types) - set(_extract_types(sample)))
+    flat_sample, _ = tree_flatten(sample)
+    return not bool(set(types) - set([type(obj) for obj in flat_sample]))
 
 
 def is_simple_tensor(input: Any) -> bool:
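
One subtlety in the new helpers above: has_any matches via issubclass, so subclasses of the requested types count, while has_all compares exact leaf types through a set difference. A small self-contained sketch of that behavior; the helpers are re-declared here purely for illustration, and the sample is made up:

import torch
from torch.utils._pytree import tree_flatten

def has_any(sample, *types):
    flat_sample, _ = tree_flatten(sample)
    return any(issubclass(type(obj), types) for obj in flat_sample)

def has_all(sample, *types):
    flat_sample, _ = tree_flatten(sample)
    return not bool(set(types) - {type(obj) for obj in flat_sample})

sample = {"image": torch.rand(3, 8, 8), "label": 3}
assert has_any(sample, torch.Tensor) and has_all(sample, torch.Tensor, int)
assert not has_all(sample, torch.Tensor, float)  # no float leaf in the sample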

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 1 addition & 0 deletions
@@ -814,6 +814,7 @@ def elastic_bounding_box(
     format: features.BoundingBoxFormat,
     displacement: torch.Tensor,
 ) -> torch.Tensor:
+    # TODO: add in docstring about approximation we are doing for grid inversion
     displacement = displacement.to(bounding_box.device)
 
     original_shape = bounding_box.shape

torchvision/transforms/functional_pil.py

Lines changed: 1 addition & 1 deletion
@@ -260,7 +260,7 @@ def _parse_fill(
 ) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
 
     # Process fill color for affine transforms
-    num_bands = len(img.getbands())
+    num_bands = get_image_num_channels(img)
     if fill is None:
         fill = 0
     if isinstance(fill, (int, float)) and num_bands > 1:
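
For the one-line change above: get_image_num_channels and len(img.getbands()) should agree for PIL inputs, so the swap only routes through the shared helper. A quick check, assuming the public functional API exposes get_image_num_channels as recent torchvision versions do:

import PIL.Image
from torchvision.transforms import functional as F

img = PIL.Image.new("RGB", (4, 4))
assert F.get_image_num_channels(img) == len(img.getbands()) == 3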
