
Make prototype F JIT-scriptable #6584


Merged: 25 commits, Sep 20, 2022

Commits:
136da06  Improve existing low kernel test. (datumbox, Sep 14, 2022)
e5dfa68  Add new midlevel jit-scriptability test (failing). (datumbox, Sep 14, 2022)
6624ab5  Merge branch 'main' into jit/prototype_transforms (datumbox, Sep 16, 2022)
9369c6e  Remove duplicate aliases from kernel tests. (datumbox, Sep 16, 2022)
087e916  Fixing colour kernels. (datumbox, Sep 16, 2022)
f844a5c  Fixing deprecated kernels. (datumbox, Sep 16, 2022)
0b961ef  fix mypy (datumbox, Sep 16, 2022)
f75a0df  Silence mypy instead of fixing to avoid performance penalty (datumbox, Sep 16, 2022)
c286621  Fixing augment kernels. (datumbox, Sep 16, 2022)
1c2822c  Fixing augment meta. (datumbox, Sep 16, 2022)
4a3df83  Remove is_tracing calls. (datumbox, Sep 16, 2022)
663d8c7  Add fake ImageType and DType (datumbox, Sep 16, 2022)
0d38972  Fixing type conversion kernels. (datumbox, Sep 16, 2022)
60000fa  Fixing misc kernels. (datumbox, Sep 16, 2022)
58e5707  partial fix geometry (datumbox, Sep 16, 2022)
1e301b1  Merge branch 'main' into jit/prototype_transforms (datumbox, Sep 16, 2022)
2f37578  Merge branch 'main' into jit/prototype_transforms (datumbox, Sep 17, 2022)
1395149  Remove mutable default from `_pad_with_vector_fill()` + all other unn… (datumbox, Sep 17, 2022)
96f318d  Merge branch 'main' of github.com:pytorch/vision into jit/prototype_t… (vfdev-5, Sep 19, 2022)
87f3567  Fix geometry ops (vfdev-5, Sep 19, 2022)
df53707  Fixing tests (vfdev-5, Sep 19, 2022)
e0185fa  Merge branch 'main' of github.com:pytorch/vision into jit/prototype_t… (vfdev-5, Sep 19, 2022)
3d0e1c9  Merge branch 'main' of github.com:pytorch/vision into jit/prototype_t… (vfdev-5, Sep 19, 2022)
d7e516c  Removed xfail for jit tests on midlevel ops (vfdev-5, Sep 19, 2022)
7c309eb  Merge branch 'main' into jit/prototype_transforms (datumbox, Sep 20, 2022)
23 changes: 15 additions & 8 deletions test/prototype_transforms_dispatcher_infos.py
@@ -113,14 +113,21 @@ def sample_inputs(self, *types):
features.Mask: F.pad_mask,
},
),
DispatcherInfo(
F.perspective,
kernels={
features.Image: F.perspective_image_tensor,
features.BoundingBox: F.perspective_bounding_box,
features.Mask: F.perspective_mask,
},
),
# FIXME:
# RuntimeError: perspective() is missing value for argument 'startpoints'.
# Declaration: perspective(Tensor inpt, int[][] startpoints, int[][] endpoints,
# Enum<__torch__.torchvision.transforms.functional.InterpolationMode> interpolation=Enum<InterpolationMode.BILINEAR>,
# Union(float[], float, int, NoneType) fill=None) -> Tensor
#
# This is probably due to the fact that F.perspective does not have the same signature as F.perspective_image_tensor
Collaborator: Yes, if the signatures actually diverge, then this is the issue. I'm OK with disabling this here for now and looking at it later.

Still, is this something we want? Shouldn't the public kernels be in sync with the dispatcher?

Contributor Author: I believe @vfdev-5 made this change. As discussed previously, there are other places where we need to align the signatures, and this can happen in a follow-up PR to avoid making this one too long.

Personally, I think it's worth aligning the signatures unless there is a good reason not to (perhaps the interpolation default value is one exception?). I'm open to discussing this, and I think we should agree on the policy soon.
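
For context, a minimal hypothetical repro of this class of failure (the stand-in signatures and placeholder bodies below are illustrative, not the real implementations): once the dispatcher is scripted, TorchScript enforces its declared signature, so a call that doesn't supply the dispatcher's required arguments fails exactly as quoted in the FIXME.

```python
from typing import List

import torch


def perspective_image_tensor(inpt: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
    return inpt  # placeholder: the real kernel warps the image


def perspective(inpt: torch.Tensor, startpoints: List[List[int]], endpoints: List[List[int]]) -> torch.Tensor:
    return inpt  # placeholder: the real dispatcher forwards to a kernel


scripted = torch.jit.script(perspective)
# TorchScript enforces the dispatcher's declared signature, so calling it
# without the required arguments raises:
#   RuntimeError: perspective() is missing value for argument 'startpoints'.
# scripted(torch.rand(3, 8, 8))
```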

# DispatcherInfo(
# F.perspective,
# kernels={
# features.Image: F.perspective_image_tensor,
# features.BoundingBox: F.perspective_bounding_box,
# features.Mask: F.perspective_mask,
# },
# ),
DispatcherInfo(
F.center_crop,
kernels={
23 changes: 19 additions & 4 deletions test/test_prototype_transforms.py
@@ -376,6 +376,9 @@ def test__transform(self, padding, fill, padding_mode, mocker):
inpt = mocker.MagicMock(spec=features.Image)
_ = transform(inpt)

fill = transforms.functional._geometry._convert_fill_arg(fill)
if isinstance(padding, tuple):
padding = list(padding)
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)

@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
@@ -389,14 +392,17 @@ def test__transform_image_mask(self, fill, mocker):
_ = transform(inpt)

if isinstance(fill, int):
fill = transforms.functional._geometry._convert_fill_arg(fill)
calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
]
else:
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, padding=1, fill=fill[type(image)], padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill[type(mask)], padding_mode="constant"),
mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
]
fn.assert_has_calls(calls)

@@ -447,6 +453,7 @@ def test__transform(self, fill, side_range, mocker):
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)

fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill)

@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
@@ -465,14 +472,17 @@ def test__transform_image_mask(self, fill, mocker):
params = transform._get_params(inpt)

if isinstance(fill, int):
fill = transforms.functional._geometry._convert_fill_arg(fill)
calls = [
mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=fill),
]
else:
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, **params, fill=fill[type(image)]),
mocker.call(mask, **params, fill=fill[type(mask)]),
mocker.call(image, **params, fill=fill_img),
mocker.call(mask, **params, fill=fill_mask),
]
fn.assert_has_calls(calls)

@@ -533,6 +543,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
torch.manual_seed(12)
params = transform._get_params(inpt)

fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)

@pytest.mark.parametrize("angle", [34, -87])
@@ -670,6 +681,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
torch.manual_seed(12)
params = transform._get_params(inpt)

fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)


@@ -917,6 +929,7 @@ def test__transform(self, distortion_scale, mocker):
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)

fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)


@@ -986,6 +999,7 @@ def test__transform(self, alpha, sigma, mocker):
transform._get_params = mocker.MagicMock()
_ = transform(inpt)
params = transform._get_params(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)


@@ -1609,6 +1623,7 @@ def test__transform(self, mocker, needs):
if not needs_crop:
assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel
fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel)
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else:
mock_pad.assert_not_called()
1 change: 0 additions & 1 deletion test/test_prototype_transforms_dispatchers.py
@@ -9,7 +9,6 @@


class TestCommon:
@pytest.mark.xfail(reason="dispatchers are currently not scriptable")
Collaborator: 🎉

@pytest.mark.parametrize(
("info", "args_kwargs"),
[
77 changes: 62 additions & 15 deletions test/test_prototype_transforms_functional.py
@@ -407,27 +407,74 @@ def erase_image_tensor():
yield ArgsKwargs(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7))


_KERNEL_TYPES = {"_image_tensor", "_image_pil", "_mask", "_bounding_box", "_label"}


def _distinct_callables(callable_names):
# Ensure we deduplicate callables (due to aliases) without losing the names on the new API
remove = set()
distinct = set()
for name in callable_names:
item = F.__dict__[name]
if item not in distinct:
distinct.add(item)
else:
remove.add(name)
callable_names -= remove

# create tuple and sort by name
return sorted([(name, F.__dict__[name]) for name in callable_names], key=lambda t: t[0])


def _get_distinct_kernels():
kernel_names = {
name
for name, f in F.__dict__.items()
if callable(f) and not name.startswith("_") and any(name.endswith(k) for k in _KERNEL_TYPES)
}
return _distinct_callables(kernel_names)


def _get_distinct_midlevels():
midlevel_names = {
name
for name, f in F.__dict__.items()
if callable(f) and not name.startswith("_") and not any(name.endswith(k) for k in _KERNEL_TYPES)
}
return _distinct_callables(midlevel_names)


@pytest.mark.parametrize(
"kernel",
[
pytest.param(kernel, id=name)
for name, kernel in F.__dict__.items()
if not name.startswith("_")
and callable(kernel)
and any(feature_type in name for feature_type in {"image", "mask", "bounding_box", "label"})
and "pil" not in name
and name
for name, kernel in _get_distinct_kernels()
if not name.endswith("_image_pil") and name not in {"to_image_tensor"}
],
)
def test_scriptable_kernel(kernel):
jit.script(kernel) # TODO: pass data through it


@pytest.mark.parametrize(
"midlevel",
[
pytest.param(midlevel, id=name)
for name, midlevel in _get_distinct_midlevels()
if name
not in {
"to_image_tensor",
"get_num_channels",
"get_spatial_size",
"get_image_num_channels",
"get_image_size",
"InterpolationMode",
"decode_image_with_pil",
"decode_video_with_av",
"pil_to_tensor",
"to_grayscale",
"to_pil_image",
"to_tensor",
}
],
)
def test_scriptable(kernel):
jit.script(kernel)
def test_scriptable_midlevel(midlevel):
jit.script(midlevel) # TODO: pass data through it
Comment on lines +476 to +477

Collaborator: @vfdev-5 Isn't that superseded by the tests I've added?

Contributor Author: Maybe? I noticed that, at one point early in the development, one of the kernels had a return type of Any and this test didn't complain. This is why, to be safe, we usually take the strategy of:

  1. JIT-scripting.
  2. Passing data through the kernel and comparing it with the non-JIT version.
  3. Serializing/deserializing the method and confirming it still returns the right value.

See _check_jit_scriptable() from test_models.py for more info.

Collaborator: @datumbox in test/test_prototype_transforms_dispatchers.py the test should script and execute the mid-level op, so I think we can remove this test_scriptable_midlevel.

Collaborator: I'll clean up.
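
For reference, a minimal sketch of the three-step strategy described above; this is an illustrative helper, not the actual _check_jit_scriptable() from test_models.py:

```python
import io

import torch


def check_jit_roundtrip(fn, *args, **kwargs):
    # 1. JIT-scripting must succeed.
    scripted = torch.jit.script(fn)

    # 2. Pass data through and compare with the eager (non-JIT) version.
    expected = fn(*args, **kwargs)
    torch.testing.assert_close(scripted(*args, **kwargs), expected)

    # 3. Serialize/deserialize and confirm it still returns the right value.
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    reloaded = torch.jit.load(buffer)
    torch.testing.assert_close(reloaded(*args, **kwargs), expected)
```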



# Test below is intended to test mid-level op vs low-level ops it calls
@@ -439,8 +486,8 @@ def test_scriptable(kernel):
[
pytest.param(func, id=name)
for name, func in F.__dict__.items()
if not name.startswith("_")
and callable(func)
if not name.startswith("_") and callable(func)
# TODO: remove aliases
and all(feature_type not in name for feature_type in {"image", "mask", "bounding_box", "label", "pil"})
and name
not in {
4 changes: 2 additions & 2 deletions torchvision/prototype/features/__init__.py
@@ -1,6 +1,6 @@
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._feature import _Feature, is_simple_tensor
from ._image import ColorSpace, Image
from ._feature import _Feature, DType, is_simple_tensor
from ._image import ColorSpace, Image, ImageType
from ._label import Label, OneHotLabel
from ._mask import Mask
5 changes: 5 additions & 0 deletions torchvision/prototype/features/_feature.py
@@ -10,6 +10,11 @@
F = TypeVar("F", bound="_Feature")


# Due to torch.jit.script limitation we keep DType as torch.Tensor
# instead of Union[torch.Tensor, PIL.Image.Image, features._Feature]
DType = torch.Tensor


def is_simple_tensor(inpt: Any) -> bool:
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature)
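
The aliasing works because TorchScript resolves annotations at scripting time, while a Union that includes PIL.Image.Image would fail, since PIL types are not scriptable. A minimal sketch of the pattern (the kernel below is illustrative, not part of this PR):

```python
import torch
from torchvision.prototype import features


def horizontal_flip_sketch(inpt: features.DType) -> features.DType:
    # DType is torch.Tensor under the hood, so this scripts cleanly while
    # still accepting tensor subclasses such as features.Image at runtime.
    return inpt.flip(-1)


scripted = torch.jit.script(horizontal_flip_sketch)
```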

51 changes: 31 additions & 20 deletions torchvision/prototype/features/_image.py
@@ -12,6 +12,11 @@
from ._feature import _Feature


# Due to torch.jit.script limitation we keep ImageType as torch.Tensor
# instead of Union[torch.Tensor, PIL.Image.Image, features.Image]
ImageType = torch.Tensor


class ColorSpace(StrEnum):
OTHER = StrEnum.auto()
GRAY = StrEnum.auto()
@@ -32,6 +37,31 @@ def from_pil_mode(cls, mode: str) -> ColorSpace:
else:
return cls.OTHER

@staticmethod
def from_tensor_shape(shape: List[int]) -> ColorSpace:
Contributor Author: I think this makes more sense on the Enum than as a method on Image. In any case, it was moved because JIT didn't like it as a classmethod. We don't have to keep it in the enum; we just need to have it as a standalone method.

Collaborator: So we can't even have classmethods on objects that are not covered by JIT? 🤯

Collaborator: Plus, I intentionally put the term "guess" in the name, since the number of channels is not sufficient to pick the right colorspace. For example, CMYK also has 4 channels but would be classified as RGBA. However, this is not a problem now, since we don't support it yet and maybe never will.

Contributor Author: I think the issue is that the specific class is definitely not JIT-scriptable. JIT complains about the tensor_contents: Any = None value on __repr__. Perhaps we can remove this?

I'm happy to make any changes to the name, or to try a different approach with JIT. Let me wrap up the rest of the kernels to see where we are, and we can try a couple of options.

Collaborator: If it works as is, I wouldn't mess with this class any further. If it is needed, we can remove the : Any though.

Contributor Author: @pmeier I'm also really flexible. I don't have context on why tensor_contents was introduced and what it's supposed to be. Would you like to send a PR once this is merged, to see if you can make it work in the original location?
return _from_tensor_shape(shape)


def _from_tensor_shape(shape: List[int]) -> ColorSpace:
# Needed as a standalone method for JIT
ndim = len(shape)
if ndim < 2:
return ColorSpace.OTHER
elif ndim == 2:
return ColorSpace.GRAY

num_channels = shape[-3]
if num_channels == 1:
return ColorSpace.GRAY
elif num_channels == 2:
return ColorSpace.GRAY_ALPHA
elif num_channels == 3:
return ColorSpace.RGB
elif num_channels == 4:
return ColorSpace.RGB_ALPHA
else:
return ColorSpace.OTHER
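
A quick usage sketch of the relocated helper: the enum's staticmethod delegates to the scriptable free function above, so eager callers keep the old ergonomics.

```python
import torch
from torchvision.prototype.features import ColorSpace

# A CHW tensor with 3 channels is classified as RGB.
space = ColorSpace.from_tensor_shape(list(torch.rand(3, 32, 32).shape))
assert space == ColorSpace.RGB
```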


class Image(_Feature):
color_space: ColorSpace
@@ -53,7 +83,7 @@ def __new__(
image = super().__new__(cls, data, requires_grad=requires_grad)

if color_space is None:
color_space = cls.guess_color_space(image)
color_space = ColorSpace.from_tensor_shape(image.shape) # type: ignore[arg-type]
if color_space == ColorSpace.OTHER:
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
@@ -83,25 +113,6 @@ def image_size(self) -> Tuple[int, int]:
def num_channels(self) -> int:
return self.shape[-3]

@staticmethod
def guess_color_space(data: torch.Tensor) -> ColorSpace:
if data.ndim < 2:
return ColorSpace.OTHER
elif data.ndim == 2:
return ColorSpace.GRAY

num_channels = data.shape[-3]
if num_channels == 1:
return ColorSpace.GRAY
elif num_channels == 2:
return ColorSpace.GRAY_ALPHA
elif num_channels == 3:
return ColorSpace.RGB
elif num_channels == 4:
return ColorSpace.RGB_ALPHA
else:
return ColorSpace.OTHER

def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
5 changes: 2 additions & 3 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -72,11 +72,10 @@ def _apply_image_transform(

# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we have to put fill as None if fill == 0
fill_: Optional[Union[int, float, Sequence[int], Sequence[float]]]
# This is due to BC with stable API which has fill = None by default
fill_ = F._geometry._convert_fill_arg(fill)
Contributor Author: We might want to consider making _convert_fill_arg() public in the future, as it is regularly used in Transforms as a utility method.
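
(For reference, a minimal sketch of what such a fill-conversion helper might look like, inferred from its call sites in this PR; this is not the actual implementation.)

```python
from typing import List, Optional, Sequence, Union


def _convert_fill_arg_sketch(
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]
) -> Optional[List[float]]:
    # Normalize the user-facing fill argument to the single concrete type
    # (Optional[List[float]]) that the JIT-scriptable kernels expect.
    if fill is None:
        return None
    if isinstance(fill, (int, float)):
        return [float(fill)]
    return [float(v) for v in fill]
```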

if isinstance(fill, int) and fill == 0:
fill_ = None
else:
fill_ = fill

if transform_id == "Identity":
return image