
Commit 678a4b7

jdsgomes and vfdev-5 authored and committed
[fbsync] Make prototype F JIT-scriptable (#6584)
Summary:
* Improve existing low-level kernel test.
* Add new mid-level jit-scriptability test (failing).
* Remove duplicate aliases from kernel tests.
* Fix colour kernels.
* Fix deprecated kernels.
* Fix mypy.
* Silence mypy instead of fixing to avoid a performance penalty.
* Fix augment kernels.
* Fix augment meta.
* Remove is_tracing calls.
* Add fake ImageType and DType.
* Fix type conversion kernels.
* Fix misc kernels.
* Partially fix geometry.
* Remove mutable default from `_pad_with_vector_fill()` and all other unnecessary defaults.
* Fix geometry ops.
* Fix tests.
* Remove xfail for jit tests on mid-level ops.

Reviewed By: NicolasHug

Differential Revision: D39765297

fbshipit-source-id: 50ec9dc9d9e2f9c8dab6ab01337e01643dc0ab64

Co-authored-by: vfdev-5 <[email protected]>
1 parent 331a773 commit 678a4b7

18 files changed: +363 −264 lines
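For orientation, here is a minimal sketch of what "JIT-scriptable" means for this functional API: both the per-type kernels and the mid-level dispatchers that route to them must compile under torch.jit.script, which in practice means sticking to TorchScript-friendly annotations such as plain torch.Tensor. The names and bodies below are illustrative toys, not the actual torchvision ops.

```python
import torch

# Toy "kernel" and "dispatcher" mirroring the prototype layering; hypothetical names/bodies.
def hflip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    # low-level kernel: operates on a plain tensor
    return image.flip(-1)

def hflip(inpt: torch.Tensor) -> torch.Tensor:
    # mid-level dispatcher: the real API also handles PIL images and Feature subclasses,
    # but keeping the annotation as torch.Tensor is what keeps it scriptable
    return hflip_image_tensor(inpt)

scripted = torch.jit.script(hflip)  # the new tests essentially assert this does not raise
x = torch.arange(6).reshape(2, 3)
assert torch.equal(scripted(x), x.flip(-1))
```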

test/prototype_transforms_dispatcher_infos.py

Lines changed: 15 additions & 8 deletions
```diff
@@ -113,14 +113,21 @@ def sample_inputs(self, *types):
             features.Mask: F.pad_mask,
         },
     ),
-    DispatcherInfo(
-        F.perspective,
-        kernels={
-            features.Image: F.perspective_image_tensor,
-            features.BoundingBox: F.perspective_bounding_box,
-            features.Mask: F.perspective_mask,
-        },
-    ),
+    # FIXME:
+    # RuntimeError: perspective() is missing value for argument 'startpoints'.
+    # Declaration: perspective(Tensor inpt, int[][] startpoints, int[][] endpoints,
+    #     Enum<__torch__.torchvision.transforms.functional.InterpolationMode> interpolation=Enum<InterpolationMode.BILINEAR>,
+    #     Union(float[], float, int, NoneType) fill=None) -> Tensor
+    #
+    # This is probably due to the fact that F.perspective does not have the same signature as F.perspective_image_tensor
+    # DispatcherInfo(
+    #     F.perspective,
+    #     kernels={
+    #         features.Image: F.perspective_image_tensor,
+    #         features.BoundingBox: F.perspective_bounding_box,
+    #         features.Mask: F.perspective_mask,
+    #     },
+    # ),
     DispatcherInfo(
         F.center_crop,
         kernels={
```
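A hedged illustration of the FIXME above: the dispatcher-info tests generate one set of sample arguments and replay it against both the dispatcher and its registered kernels, so the two must agree on their signatures. The signatures below are made up, loosely modeled on the error message, purely to show the failure mode.

```python
import torch

def perspective_image_tensor(image: torch.Tensor, perspective_coeffs: list) -> torch.Tensor:
    return image  # placeholder body, kernel-style arguments

def perspective(inpt: torch.Tensor, startpoints: list, endpoints: list) -> torch.Tensor:
    return inpt  # placeholder body, dispatcher-style arguments

sample_kwargs = {"perspective_coeffs": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]}
perspective_image_tensor(torch.rand(3, 4, 4), **sample_kwargs)  # works
try:
    perspective(torch.rand(3, 4, 4), **sample_kwargs)  # mismatched signature
except TypeError as exc:
    print("dispatcher rejected kernel-style args:", exc)
```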

test/test_prototype_transforms.py

Lines changed: 19 additions & 4 deletions
```diff
@@ -376,6 +376,9 @@ def test__transform(self, padding, fill, padding_mode, mocker):
         inpt = mocker.MagicMock(spec=features.Image)
         _ = transform(inpt)
 
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
+        if isinstance(padding, tuple):
+            padding = list(padding)
         fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
 
     @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
@@ -389,14 +392,17 @@ def test__transform_image_mask(self, fill, mocker):
         _ = transform(inpt)
 
         if isinstance(fill, int):
+            fill = transforms.functional._geometry._convert_fill_arg(fill)
             calls = [
                 mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
                 mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
             ]
         else:
+            fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
+            fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
             calls = [
-                mocker.call(image, padding=1, fill=fill[type(image)], padding_mode="constant"),
-                mocker.call(mask, padding=1, fill=fill[type(mask)], padding_mode="constant"),
+                mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
+                mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
             ]
         fn.assert_has_calls(calls)
 
@@ -447,6 +453,7 @@ def test__transform(self, fill, side_range, mocker):
         torch.rand(1)  # random apply changes random state
         params = transform._get_params(inpt)
 
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill)
 
     @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
@@ -465,14 +472,17 @@ def test__transform_image_mask(self, fill, mocker):
         params = transform._get_params(inpt)
 
         if isinstance(fill, int):
+            fill = transforms.functional._geometry._convert_fill_arg(fill)
             calls = [
                 mocker.call(image, **params, fill=fill),
                 mocker.call(mask, **params, fill=fill),
             ]
         else:
+            fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
+            fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
             calls = [
-                mocker.call(image, **params, fill=fill[type(image)]),
-                mocker.call(mask, **params, fill=fill[type(mask)]),
+                mocker.call(image, **params, fill=fill_img),
+                mocker.call(mask, **params, fill=fill_mask),
             ]
         fn.assert_has_calls(calls)
 
@@ -533,6 +543,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
         torch.manual_seed(12)
         params = transform._get_params(inpt)
 
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
 
     @pytest.mark.parametrize("angle", [34, -87])
@@ -670,6 +681,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
         torch.manual_seed(12)
         params = transform._get_params(inpt)
 
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
 
 
@@ -917,6 +929,7 @@ def test__transform(self, distortion_scale, mocker):
         torch.rand(1)  # random apply changes random state
         params = transform._get_params(inpt)
 
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
 
 
@@ -986,6 +999,7 @@ def test__transform(self, alpha, sigma, mocker):
         transform._get_params = mocker.MagicMock()
         _ = transform(inpt)
         params = transform._get_params(inpt)
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
 
 
@@ -1609,6 +1623,7 @@ def test__transform(self, mocker, needs):
         if not needs_crop:
             assert args[0] is inpt_sentinel
             assert args[1] is padding_sentinel
+            fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel)
             assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
         else:
             mock_pad.assert_not_called()
```
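The recurring change in this file: the transforms now normalize fill through the private helper transforms.functional._geometry._convert_fill_arg before delegating to the functional op, so each test must apply the same conversion to the expected value before asserting on the mocked call. A minimal sketch of the pattern, with _to_float_list standing in for the real helper (whose exact behavior is an assumption here):

```python
import torch
from unittest import mock

def _to_float_list(fill):
    # stand-in for _convert_fill_arg; the real conversion may differ
    if isinstance(fill, (int, float)):
        return [float(fill)]
    return [float(v) for v in fill]

def pad_transform(image, fill, pad_fn):
    # the transform converts fill once and then calls the functional op
    return pad_fn(image, padding=1, fill=_to_float_list(fill), padding_mode="constant")

pad_fn = mock.MagicMock()
image = torch.rand(3, 4, 4)
pad_transform(image, fill=12, pad_fn=pad_fn)

# the assertion must use the converted value, exactly as the updated tests do
pad_fn.assert_called_once_with(image, padding=1, fill=_to_float_list(12), padding_mode="constant")
```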

test/test_prototype_transforms_dispatchers.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -9,7 +9,6 @@
 
 
 class TestCommon:
-    @pytest.mark.xfail(reason="dispatchers are currently not scriptable")
     @pytest.mark.parametrize(
         ("info", "args_kwargs"),
         [
```

test/test_prototype_transforms_functional.py

Lines changed: 62 additions & 15 deletions
```diff
@@ -407,27 +407,74 @@ def erase_image_tensor():
         yield ArgsKwargs(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7))
 
 
+_KERNEL_TYPES = {"_image_tensor", "_image_pil", "_mask", "_bounding_box", "_label"}
+
+
+def _distinct_callables(callable_names):
+    # Ensure we deduplicate callables (due to aliases) without losing the names on the new API
+    remove = set()
+    distinct = set()
+    for name in callable_names:
+        item = F.__dict__[name]
+        if item not in distinct:
+            distinct.add(item)
+        else:
+            remove.add(name)
+    callable_names -= remove
+
+    # create tuple and sort by name
+    return sorted([(name, F.__dict__[name]) for name in callable_names], key=lambda t: t[0])
+
+
+def _get_distinct_kernels():
+    kernel_names = {
+        name
+        for name, f in F.__dict__.items()
+        if callable(f) and not name.startswith("_") and any(name.endswith(k) for k in _KERNEL_TYPES)
+    }
+    return _distinct_callables(kernel_names)
+
+
+def _get_distinct_midlevels():
+    midlevel_names = {
+        name
+        for name, f in F.__dict__.items()
+        if callable(f) and not name.startswith("_") and not any(name.endswith(k) for k in _KERNEL_TYPES)
+    }
+    return _distinct_callables(midlevel_names)
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
         pytest.param(kernel, id=name)
-        for name, kernel in F.__dict__.items()
-        if not name.startswith("_")
-        and callable(kernel)
-        and any(feature_type in name for feature_type in {"image", "mask", "bounding_box", "label"})
-        and "pil" not in name
-        and name
+        for name, kernel in _get_distinct_kernels()
+        if not name.endswith("_image_pil") and name not in {"to_image_tensor"}
+    ],
+)
+def test_scriptable_kernel(kernel):
+    jit.script(kernel)  # TODO: pass data through it
+
+
+@pytest.mark.parametrize(
+    "midlevel",
+    [
+        pytest.param(midlevel, id=name)
+        for name, midlevel in _get_distinct_midlevels()
+        if name
         not in {
-            "to_image_tensor",
-            "get_num_channels",
-            "get_spatial_size",
-            "get_image_num_channels",
-            "get_image_size",
+            "InterpolationMode",
+            "decode_image_with_pil",
+            "decode_video_with_av",
+            "pil_to_tensor",
+            "to_grayscale",
+            "to_pil_image",
+            "to_tensor",
         }
     ],
 )
-def test_scriptable(kernel):
-    jit.script(kernel)
+def test_scriptable_midlevel(midlevel):
+    jit.script(midlevel)  # TODO: pass data through it
 
 
 # Test below is intended to test mid-level op vs low-level ops it calls
@@ -439,8 +486,8 @@ def test_scriptable(kernel):
     [
         pytest.param(func, id=name)
         for name, func in F.__dict__.items()
-        if not name.startswith("_")
-        and callable(func)
+        if not name.startswith("_") and callable(func)
+        # TODO: remove aliases
         and all(feature_type not in name for feature_type in {"image", "mask", "bounding_box", "label", "pil"})
         and name
         not in {
```
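The helper _distinct_callables above exists because the functional namespace exposes aliases (several names bound to the same function object), which would otherwise produce duplicate parametrized test cases. A tiny, self-contained demonstration of the deduplication idea (deterministic iteration is added here for clarity; the helper itself only sorts at the end):

```python
def resize(x):
    return x

scale = resize  # alias: a second name bound to the same function object

namespace = {"resize": resize, "scale": scale}
names, seen, drop = set(namespace), set(), set()
for name in sorted(names):
    fn = namespace[name]
    if fn in seen:
        drop.add(name)   # duplicate object: keep only the first name encountered
    else:
        seen.add(fn)
print(sorted(names - drop))  # ['resize']
```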
torchvision/prototype/features/__init__.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,6 +1,6 @@
 from ._bounding_box import BoundingBox, BoundingBoxFormat
 from ._encoded import EncodedData, EncodedImage, EncodedVideo
-from ._feature import _Feature, is_simple_tensor
-from ._image import ColorSpace, Image
+from ._feature import _Feature, DType, is_simple_tensor
+from ._image import ColorSpace, Image, ImageType
 from ._label import Label, OneHotLabel
 from ._mask import Mask
```

torchvision/prototype/features/_feature.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -10,6 +10,11 @@
 F = TypeVar("F", bound="_Feature")
 
 
+# Due to torch.jit.script limitation we keep DType as torch.Tensor
+# instead of Union[torch.Tensor, PIL.Image.Image, features._Feature]
+DType = torch.Tensor
+
+
 def is_simple_tensor(inpt: Any) -> bool:
     return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature)
 
```
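A small sketch of the limitation the new comment refers to (behavior and exact error text may vary across torch versions): a Union annotation that includes a non-TorchScript type such as PIL.Image.Image cannot be compiled, whereas the plain torch.Tensor alias can. The function names here are hypothetical.

```python
import torch
from typing import Union
import PIL.Image

DType = torch.Tensor  # the alias used above keeps annotations TorchScript-friendly

def identity_union(inpt: Union[torch.Tensor, PIL.Image.Image]) -> Union[torch.Tensor, PIL.Image.Image]:
    return inpt

def identity_tensor(inpt: DType) -> DType:
    return inpt

try:
    torch.jit.script(identity_union)
except Exception as exc:  # PIL types are not TorchScript types, so scripting fails
    print("Union signature not scriptable:", type(exc).__name__)

torch.jit.script(identity_tensor)  # compiles fine
```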
torchvision/prototype/features/_image.py

Lines changed: 31 additions & 20 deletions
```diff
@@ -12,6 +12,11 @@
 from ._feature import _Feature
 
 
+# Due to torch.jit.script limitation we keep ImageType as torch.Tensor
+# instead of Union[torch.Tensor, PIL.Image.Image, features.Image]
+ImageType = torch.Tensor
+
+
 class ColorSpace(StrEnum):
     OTHER = StrEnum.auto()
     GRAY = StrEnum.auto()
@@ -32,6 +37,31 @@ def from_pil_mode(cls, mode: str) -> ColorSpace:
         else:
             return cls.OTHER
 
+    @staticmethod
+    def from_tensor_shape(shape: List[int]) -> ColorSpace:
+        return _from_tensor_shape(shape)
+
+
+def _from_tensor_shape(shape: List[int]) -> ColorSpace:
+    # Needed as a standalone method for JIT
+    ndim = len(shape)
+    if ndim < 2:
+        return ColorSpace.OTHER
+    elif ndim == 2:
+        return ColorSpace.GRAY
+
+    num_channels = shape[-3]
+    if num_channels == 1:
+        return ColorSpace.GRAY
+    elif num_channels == 2:
+        return ColorSpace.GRAY_ALPHA
+    elif num_channels == 3:
+        return ColorSpace.RGB
+    elif num_channels == 4:
+        return ColorSpace.RGB_ALPHA
+    else:
+        return ColorSpace.OTHER
+
 
 class Image(_Feature):
     color_space: ColorSpace
@@ -53,7 +83,7 @@ def __new__(
         image = super().__new__(cls, data, requires_grad=requires_grad)
 
         if color_space is None:
-            color_space = cls.guess_color_space(image)
+            color_space = ColorSpace.from_tensor_shape(image.shape)  # type: ignore[arg-type]
             if color_space == ColorSpace.OTHER:
                 warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
         elif isinstance(color_space, str):
@@ -83,25 +113,6 @@ def image_size(self) -> Tuple[int, int]:
     def num_channels(self) -> int:
         return self.shape[-3]
 
-    @staticmethod
-    def guess_color_space(data: torch.Tensor) -> ColorSpace:
-        if data.ndim < 2:
-            return ColorSpace.OTHER
-        elif data.ndim == 2:
-            return ColorSpace.GRAY
-
-        num_channels = data.shape[-3]
-        if num_channels == 1:
-            return ColorSpace.GRAY
-        elif num_channels == 2:
-            return ColorSpace.GRAY_ALPHA
-        elif num_channels == 3:
-            return ColorSpace.RGB
-        elif num_channels == 4:
-            return ColorSpace.RGB_ALPHA
-        else:
-            return ColorSpace.OTHER
-
     def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
         if isinstance(color_space, str):
             color_space = ColorSpace.from_str(color_space.upper())
```
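The old Image.guess_color_space static method is replaced by ColorSpace.from_tensor_shape, which delegates to a module-level _from_tensor_shape helper (the diff notes it is "Needed as a standalone method for JIT"). Assuming the prototype features namespace at this commit, the new entry point behaves roughly like this:

```python
import torch
from torchvision.prototype import features  # prototype API at this commit; subject to change

img = torch.rand(3, 32, 32)
print(features.ColorSpace.from_tensor_shape(list(img.shape)))  # ColorSpace.RGB
print(features.ColorSpace.from_tensor_shape([32, 32]))         # 2D shape -> ColorSpace.GRAY
print(features.ColorSpace.from_tensor_shape([5, 32, 32]))      # unknown channel count -> ColorSpace.OTHER
```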

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -72,11 +72,10 @@ def _apply_image_transform(
 
         # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
         # So, we have to put fill as None if fill == 0
-        fill_: Optional[Union[int, float, Sequence[int], Sequence[float]]]
+        # This is due to BC with stable API which has fill = None by default
+        fill_ = F._geometry._convert_fill_arg(fill)
         if isinstance(fill, int) and fill == 0:
             fill_ = None
-        else:
-            fill_ = fill
 
         if transform_id == "Identity":
             return image
```
