
Commit 2ee15c7

YosuaMichael authored and facebook-github-bot committed
[fbsync] [proto] Fixes Transform._transformed_types and torch.Tensor (#6487)
Summary:
* Fixes unexpected behaviour with Transform._transformed_types and torch.Tensor
* Makes the code consistent with the has_any / has_all implementation
* Fixes a failing flake8 check

Reviewed By: NicolasHug

Differential Revision: D39131010

fbshipit-source-id: 376289d1b0854acdf76f4495a1eabf940058551b
1 parent 062052b commit 2ee15c7
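
For context on the bug: a bare torch.Tensor entry in _transformed_types matched not only plain tensors but also every tensor-backed _Feature subclass (bounding boxes, segmentation masks, ...), so transforms that only handle plain image tensors could be handed features they cannot process. Replacing the type with the is_simple_tensor predicate restricts dispatch to plain tensors. A minimal sketch of the distinction; the body of is_simple_tensor is an assumption here, since this diff does not show it:

import torch


class _Feature(torch.Tensor):
    # Stand-in for torchvision.prototype.features._Feature.
    pass


class BoundingBox(_Feature):
    # Stand-in for a tensor-backed feature subclass.
    pass


def is_simple_tensor(obj):
    # Assumed semantics: a "simple" tensor is a torch.Tensor that is NOT a
    # _Feature. With the old bare `torch.Tensor` entry, BoundingBox instances
    # matched as well, since isinstance(bbox, torch.Tensor) is True.
    return isinstance(obj, torch.Tensor) and not isinstance(obj, _Feature)


bbox = torch.rand(1, 4).as_subclass(BoundingBox)
print(is_simple_tensor(torch.rand(3, 8, 8)))  # True  -> gets transformed
print(is_simple_tensor(bbox))                 # False -> passed through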

File tree

6 files changed: +30 -15 lines


test/test_prototype_transforms.py

Lines changed: 10 additions & 1 deletion
@@ -225,9 +225,18 @@ def test_random_resized_crop(self, transform, input):
             )
         ]
     )
-    def test_convertolor_space(self, transform, input):
+    def test_convert_color_space(self, transform, input):
         transform(input)
 
+    def test_convert_color_space_unsupported_types(self):
+        transform = transforms.ConvertColorSpace(
+            color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
+        )
+
+        for inpt in [make_bounding_box(format="XYXY"), make_segmentation_mask()]:
+            output = transform(inpt)
+            assert output is inpt
+
 
 @pytest.mark.parametrize("p", [0.0, 1.0])
 class TestRandomHorizontalFlip:

torchvision/prototype/transforms/_deprecated.py

Lines changed: 2 additions & 3 deletions
@@ -3,7 +3,6 @@
 
 import numpy as np
 import PIL.Image
-import torch
 import torchvision.prototype.transforms.functional as F
 from torchvision.prototype import features
 from torchvision.prototype.features import ColorSpace
@@ -18,7 +17,7 @@
 class ToTensor(Transform):
 
     # Updated transformed types for ToTensor
-    _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
 
     def __init__(self) -> None:
         warnings.warn(
@@ -52,7 +51,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 class ToPILImage(Transform):
 
     # Updated transformed types for ToPILImage
-    _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
 
     def __init__(self, mode: Optional[str] = None) -> None:
         warnings.warn(

torchvision/prototype/transforms/_meta.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 class ConvertColorSpace(Transform):
     # F.convert_color_space does NOT handle `_Feature`'s in general
-    _transformed_types = (torch.Tensor, features.Image, PIL.Image.Image)
+    _transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image)
 
     def __init__(
         self,

torchvision/prototype/transforms/_transform.py

Lines changed: 5 additions & 3 deletions
@@ -1,18 +1,19 @@
 import enum
-from typing import Any, Dict, Tuple, Type
+from typing import Any, Callable, Dict, Tuple, Type, Union
 
 import PIL.Image
 import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype.features import _Feature
+from torchvision.prototype.transforms._utils import _isinstance, is_simple_tensor
 from torchvision.utils import _log_api_usage_once
 
 
 class Transform(nn.Module):
 
     # Class attribute defining transformed types. Other types are passed-through without any transformation
-    _transformed_types: Tuple[Type, ...] = (torch.Tensor, _Feature, PIL.Image.Image)
+    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (is_simple_tensor, _Feature, PIL.Image.Image)
 
     def __init__(self) -> None:
         super().__init__()
@@ -31,7 +32,8 @@ def forward(self, *inputs: Any) -> Any:
 
         flat_inputs, spec = tree_flatten(sample)
         flat_outputs = [
-            self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs
+            self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
+            for inpt in flat_inputs
         ]
         return tree_unflatten(flat_outputs, spec)
 
torchvision/prototype/transforms/_type_conversion.py

Lines changed: 2 additions & 3 deletions
@@ -3,7 +3,6 @@
 import numpy as np
 import PIL.Image
 
-import torch
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
@@ -44,7 +43,7 @@ def extra_repr(self) -> str:
 class ToImageTensor(Transform):
 
     # Updated transformed types for ToImageTensor
-    _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
 
     def __init__(self, *, copy: bool = False) -> None:
         super().__init__()
@@ -61,7 +60,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 class ToImagePIL(Transform):
 
     # Updated transformed types for ToImagePIL
-    _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
 
     def __init__(self, *, mode: Optional[str] = None) -> None:
         super().__init__()

torchvision/prototype/transforms/_utils.py

Lines changed: 10 additions & 4 deletions
@@ -45,12 +45,18 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
     return chws.pop()
 
 
+def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
+    for type_or_check in types_or_checks:
+        if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
+            return True
+    return False
+
+
 def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
     flat_sample, _ = tree_flatten(sample)
-    for type_or_check in types_or_checks:
-        for obj in flat_sample:
-            if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
-                return True
+    for obj in flat_sample:
+        if _isinstance(obj, types_or_checks):
+            return True
     return False
 
 