Skip to content

Commit 0e7de28

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] expand has_any and has_all to also accept check callables (#6447)
Summary: * expand has_any and has_all to also accept check callables * add test and fix has_all * add support for simple tensor images to CutMix, MixUp and RandomIoUCrop * remove TODO * remove pythonic syntax sugar * simplify * use concreate examples in test rather than abstract ones * simplify further Reviewed By: datumbox Differential Revision: D39013675 fbshipit-source-id: 6cd68d471b7cf94192284cdf7948c87ed570e6af
1 parent a2b8ed2 commit 0e7de28

File tree

4 files changed

+103
-9
lines changed

4 files changed

+103
-9
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import PIL.Image
2+
import pytest
3+
4+
import torch
5+
6+
from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask
7+
8+
from torchvision.prototype import features
9+
from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor
10+
from torchvision.prototype.transforms.functional import to_image_pil
11+
12+
13+
IMAGE = make_image(color_space=features.ColorSpace.RGB)
14+
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
15+
SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
16+
17+
18+
@pytest.mark.parametrize(
19+
("sample", "types", "expected"),
20+
[
21+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
22+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
23+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
24+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
25+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
26+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
27+
((SEGMENTATION_MASK,), (features.Image, features.BoundingBox), False),
28+
((BOUNDING_BOX,), (features.Image, features.SegmentationMask), False),
29+
((IMAGE,), (features.BoundingBox, features.SegmentationMask), False),
30+
(
31+
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
32+
(features.Image, features.BoundingBox, features.SegmentationMask),
33+
True,
34+
),
35+
((), (features.Image, features.BoundingBox, features.SegmentationMask), False),
36+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True),
37+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
38+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
39+
((IMAGE,), (features.Image, PIL.Image.Image, is_simple_tensor), True),
40+
((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
41+
((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
42+
],
43+
)
44+
def test_has_any(sample, types, expected):
45+
assert has_any(sample, *types) is expected
46+
47+
48+
@pytest.mark.parametrize(
49+
("sample", "types", "expected"),
50+
[
51+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
52+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
53+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
54+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
55+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
56+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
57+
(
58+
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
59+
(features.Image, features.BoundingBox, features.SegmentationMask),
60+
True,
61+
),
62+
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), False),
63+
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), False),
64+
((IMAGE, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), False),
65+
(
66+
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
67+
(features.Image, features.BoundingBox, features.SegmentationMask),
68+
True,
69+
),
70+
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
71+
((IMAGE, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
72+
((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.SegmentationMask), False),
73+
(
74+
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
75+
(lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.SegmentationMask)),),
76+
True,
77+
),
78+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
79+
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
80+
],
81+
)
82+
def test_has_all(sample, types, expected):
83+
assert has_all(sample, *types) is expected

torchvision/prototype/transforms/_augment.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torchvision.prototype.transforms import functional as F
1010

1111
from ._transform import _RandomApplyTransform
12-
from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image
12+
from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
1313

1414

1515
class RandomErasing(_RandomApplyTransform):
@@ -105,7 +105,9 @@ def __init__(self, *, alpha: float, p: float = 0.5) -> None:
105105

106106
def forward(self, *inpts: Any) -> Any:
107107
sample = inpts if len(inpts) > 1 else inpts[0]
108-
if not has_all(sample, features.Image, features.OneHotLabel):
108+
if not (
109+
has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)
110+
):
109111
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
110112
if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
111113
raise TypeError(

torchvision/prototype/transforms/_geometry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -719,10 +719,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
719719

720720
def forward(self, *inputs: Any) -> Any:
721721
sample = inputs if len(inputs) > 1 else inputs[0]
722-
# TODO: Allow image to be a torch.Tensor
723722
if not (
724723
has_all(sample, features.BoundingBox)
725-
and has_any(sample, PIL.Image.Image, features.Image)
724+
and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
726725
and has_any(sample, features.Label, features.OneHotLabel)
727726
):
728727
raise TypeError(

torchvision/prototype/transforms/_utils.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Tuple, Type, Union
1+
from typing import Any, Callable, Tuple, Type, Union
22

33
import PIL.Image
44
import torch
@@ -39,14 +39,24 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
3939
return channels, height, width
4040

4141

42-
def has_any(sample: Any, *types: Type) -> bool:
42+
def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
4343
flat_sample, _ = tree_flatten(sample)
44-
return any(issubclass(type(obj), types) for obj in flat_sample)
44+
for type_or_check in types_or_checks:
45+
for obj in flat_sample:
46+
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
47+
return True
48+
return False
4549

4650

47-
def has_all(sample: Any, *types: Type) -> bool:
51+
def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
4852
flat_sample, _ = tree_flatten(sample)
49-
return not bool(set(types) - set([type(obj) for obj in flat_sample]))
53+
for type_or_check in types_or_checks:
54+
for obj in flat_sample:
55+
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
56+
break
57+
else:
58+
return False
59+
return True
5060

5161

5262
def is_simple_tensor(inpt: Any) -> bool:

0 commit comments

Comments
 (0)