Make prototype F JIT-scriptable #6584
Changes from all commits
@@ -9,7 +9,6 @@
 class TestCommon:
-    @pytest.mark.xfail(reason="dispatchers are currently not scriptable")
     @pytest.mark.parametrize(
         ("info", "args_kwargs"),
         [

Review comment on the removed xfail marker: 🎉
@@ -407,27 +407,74 @@ def erase_image_tensor():
     yield ArgsKwargs(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7))


+_KERNEL_TYPES = {"_image_tensor", "_image_pil", "_mask", "_bounding_box", "_label"}
+
+
+def _distinct_callables(callable_names):
+    # Ensure we deduplicate callables (due to aliases) without losing the names on the new API
+    remove = set()
+    distinct = set()
+    for name in callable_names:
+        item = F.__dict__[name]
+        if item not in distinct:
+            distinct.add(item)
+        else:
+            remove.add(name)
+    callable_names -= remove
+
+    # create tuple and sort by name
+    return sorted([(name, F.__dict__[name]) for name in callable_names], key=lambda t: t[0])
+
+
+def _get_distinct_kernels():
+    kernel_names = {
+        name
+        for name, f in F.__dict__.items()
+        if callable(f) and not name.startswith("_") and any(name.endswith(k) for k in _KERNEL_TYPES)
+    }
+    return _distinct_callables(kernel_names)
+
+
+def _get_distinct_midlevels():
+    midlevel_names = {
+        name
+        for name, f in F.__dict__.items()
+        if callable(f) and not name.startswith("_") and not any(name.endswith(k) for k in _KERNEL_TYPES)
+    }
+    return _distinct_callables(midlevel_names)
+
+
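Note on the helpers above: the functional namespace exposes some callables under more than one name (aliases), so parametrizing straight over F.__dict__ would script the same underlying function several times. The toy namespace below is a made-up stand-in for F.__dict__, added here only to illustrate the dedup-by-identity idea used in _distinct_callables; it is not code from the PR.

# Toy sketch of the dedup-by-identity idea; "namespace", "resize" and "scale"
# are hypothetical stand-ins for F.__dict__ and an aliased kernel.
def resize(x):
    return x

namespace = {"resize": resize, "scale": resize}  # "scale" aliases "resize"

seen = set()
distinct_names = []
for name in sorted(namespace):
    fn = namespace[name]
    if id(fn) not in seen:
        seen.add(id(fn))
        distinct_names.append(name)

print(distinct_names)  # ['resize'] -- one entry per underlying callable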
 @pytest.mark.parametrize(
     "kernel",
     [
         pytest.param(kernel, id=name)
-        for name, kernel in F.__dict__.items()
-        if not name.startswith("_")
-        and callable(kernel)
-        and any(feature_type in name for feature_type in {"image", "mask", "bounding_box", "label"})
-        and "pil" not in name
-        and name
+        for name, kernel in _get_distinct_kernels()
+        if not name.endswith("_image_pil") and name not in {"to_image_tensor"}
     ],
 )
+def test_scriptable_kernel(kernel):
+    jit.script(kernel)  # TODO: pass data through it
+
+
+@pytest.mark.parametrize(
+    "midlevel",
+    [
+        pytest.param(midlevel, id=name)
+        for name, midlevel in _get_distinct_midlevels()
+        if name
+        not in {
+            "to_image_tensor",
+            "get_num_channels",
+            "get_spatial_size",
+            "get_image_num_channels",
+            "get_image_size",
+            "InterpolationMode",
+            "decode_image_with_pil",
+            "decode_video_with_av",
+            "pil_to_tensor",
+            "to_grayscale",
+            "to_pil_image",
+            "to_tensor",
+        }
+    ],
+)
-def test_scriptable(kernel):
-    jit.script(kernel)
+def test_scriptable_midlevel(midlevel):
+    jit.script(midlevel)  # TODO: pass data through it
Review comments on lines +476 to +477:

- @vfdev-5 Isn't that superseded by the tests I've added?
- Maybe? I noticed that at one point early in the development, one of the kernels had a return type … See …
- @datumbox in test/test_prototype_transforms_dispatchers.py the test should script and execute the midlevel op, so I think we can remove this.
- I'll clean up.
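For context on the "pass data through it" TODO and the thread above: scripting and then executing a mid-level op would look roughly like the sketch below. It assumes the prototype functional namespace as it existed around this PR, and horizontal_flip is used only as a plausible example dispatcher; treat this as a sketch, not code from the PR.

# Sketch only: script a mid-level dispatcher and run a tensor through it.
import torch
from torchvision.prototype.transforms import functional as F  # prototype API of that era

scripted = torch.jit.script(F.horizontal_flip)   # must compile without errors
image = torch.rand(3, 32, 32)
torch.testing.assert_close(scripted(image), F.horizontal_flip(image))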
+# Test below is intended to test mid-level op vs low-level ops it calls

@@ -439,8 +486,8 @@ def test_scriptable(kernel):
     [
         pytest.param(func, id=name)
         for name, func in F.__dict__.items()
-        if not name.startswith("_")
-        and callable(func)
+        if not name.startswith("_") and callable(func)
+        # TODO: remove aliases
         and all(feature_type not in name for feature_type in {"image", "mask", "bounding_box", "label", "pil"})
         and name
         not in {
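The hunk above only touches the parametrization of the mid-level vs. low-level consistency test; its body is not part of this diff. As a rough illustration of what such a consistency check can look like (an assumption for illustration, not the actual test in this PR), one can compare parameter names with inspect.signature:

# Illustrative sketch, not the test body from this PR: check that a dispatcher
# and one of its kernels agree on the parameters that follow the input argument.
import inspect

def params_after_input(fn):
    return [p.name for p in inspect.signature(fn).parameters.values()][1:]

def check_signature_consistency(dispatcher, kernel):
    assert params_after_input(dispatcher) == params_after_input(kernel)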
@@ -1,6 +1,6 @@
 from ._bounding_box import BoundingBox, BoundingBoxFormat
 from ._encoded import EncodedData, EncodedImage, EncodedVideo
-from ._feature import _Feature, is_simple_tensor
-from ._image import ColorSpace, Image
+from ._feature import _Feature, DType, is_simple_tensor
+from ._image import ColorSpace, Image, ImageType
 from ._label import Label, OneHotLabel
 from ._mask import Mask
@@ -12,6 +12,11 @@
 from ._feature import _Feature


+# Due to torch.jit.script limitation we keep ImageType as torch.Tensor
+# instead of Union[torch.Tensor, PIL.Image.Image, features.Image]
+ImageType = torch.Tensor
+
+
 class ColorSpace(StrEnum):
     OTHER = StrEnum.auto()
     GRAY = StrEnum.auto()
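The new comment captures the TorchScript constraint driving this alias: torch.jit.script resolves annotations at compile time and cannot handle arbitrary Python classes such as PIL.Image.Image, so a Union over tensor, PIL image, and features.Image would not script. A minimal standalone sketch of the pattern follows (not the torchvision source; brightness_stub is a made-up function):

# Minimal sketch of the ImageType-as-Tensor pattern; brightness_stub is hypothetical.
import torch

ImageType = torch.Tensor  # a plain alias: TorchScript simply sees torch.Tensor

def brightness_stub(image: ImageType, factor: float) -> ImageType:
    # Annotating `image` with Union[torch.Tensor, PIL.Image.Image] instead would
    # make torch.jit.script fail, since TorchScript cannot compile PIL types.
    return image * factor

scripted = torch.jit.script(brightness_stub)
print(scripted(torch.rand(3, 4, 4), 0.5).shape)  # torch.Size([3, 4, 4])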
@@ -32,6 +37,31 @@ def from_pil_mode(cls, mode: str) -> ColorSpace:
         else:
             return cls.OTHER

+    @staticmethod
+    def from_tensor_shape(shape: List[int]) -> ColorSpace:
+        return _from_tensor_shape(shape)
+
+
+def _from_tensor_shape(shape: List[int]) -> ColorSpace:
+    # Needed as a standalone method for JIT
+    ndim = len(shape)
+    if ndim < 2:
+        return ColorSpace.OTHER
+    elif ndim == 2:
+        return ColorSpace.GRAY
+
+    num_channels = shape[-3]
+    if num_channels == 1:
+        return ColorSpace.GRAY
+    elif num_channels == 2:
+        return ColorSpace.GRAY_ALPHA
+    elif num_channels == 3:
+        return ColorSpace.RGB
+    elif num_channels == 4:
+        return ColorSpace.RGB_ALPHA
+    else:
+        return ColorSpace.OTHER
+
+
 class Image(_Feature):
     color_space: ColorSpace

Review comments on ColorSpace.from_tensor_shape:

- I think this makes more sense on the Enum than as a method on Image. In any case, it was moved because JIT didn't like it as a classmethod. We don't have to keep it in the enum; we just need to have it as a standalone method.
- So we can't even have classmethods on objects that are not covered by JIT? 🤯
- Plus, I intentionally put the term "guess" in the name, since the number of channels is not sufficient to pick the right colorspace. For example, CMYK also has 4 channels, but would be classified as RGBA. However, this is not a problem now since we don't support it yet and maybe never will.
- I think the issue is that the specific class is definitely not JIT-scriptable. JIT complains about the … I'm happy to make any changes to the name, or to try a different approach with JIT. Let me wrap up the rest of the kernels to see where we are, and we can try a couple of options.
- If it works as is, I wouldn't mess with this class any further. If it is needed, we can remove the …
- @pmeier I'm also really flexible. I don't have context on why the …
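As a quick usage sketch of the shape-based guess added above (against the prototype API at the time of this PR, which may have changed since): only the number of dimensions and the channel dimension are inspected, which is exactly why CMYK would be misreported as RGBA, as noted in the review thread above.

# Usage sketch of ColorSpace.from_tensor_shape; prototype API as of this PR.
from torchvision.prototype import features

print(features.ColorSpace.from_tensor_shape([32, 32]))     # GRAY (2-d shape)
print(features.ColorSpace.from_tensor_shape([1, 16, 16]))  # GRAY
print(features.ColorSpace.from_tensor_shape([3, 16, 16]))  # RGB
print(features.ColorSpace.from_tensor_shape([4, 16, 16]))  # RGB_ALPHA (CMYK would be misclassified)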
@@ -53,7 +83,7 @@ def __new__(
         image = super().__new__(cls, data, requires_grad=requires_grad)

         if color_space is None:
-            color_space = cls.guess_color_space(image)
+            color_space = ColorSpace.from_tensor_shape(image.shape)  # type: ignore[arg-type]
             if color_space == ColorSpace.OTHER:
                 warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
         elif isinstance(color_space, str):
@@ -83,25 +113,6 @@ def image_size(self) -> Tuple[int, int]:
     def num_channels(self) -> int:
         return self.shape[-3]

-    @staticmethod
-    def guess_color_space(data: torch.Tensor) -> ColorSpace:
-        if data.ndim < 2:
-            return ColorSpace.OTHER
-        elif data.ndim == 2:
-            return ColorSpace.GRAY
-
-        num_channels = data.shape[-3]
-        if num_channels == 1:
-            return ColorSpace.GRAY
-        elif num_channels == 2:
-            return ColorSpace.GRAY_ALPHA
-        elif num_channels == 3:
-            return ColorSpace.RGB
-        elif num_channels == 4:
-            return ColorSpace.RGB_ALPHA
-        else:
-            return ColorSpace.OTHER
-
     def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
         if isinstance(color_space, str):
             color_space = ColorSpace.from_str(color_space.upper())
@@ -72,11 +72,10 @@ def _apply_image_transform(

         # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
-        # So, we have to put fill as None if fill == 0
-        fill_: Optional[Union[int, float, Sequence[int], Sequence[float]]]
-        if isinstance(fill, int) and fill == 0:
-            fill_ = None
-        else:
-            fill_ = fill
+        # This is due to BC with stable API which has fill = None by default
+        fill_ = F._geometry._convert_fill_arg(fill)

         if transform_id == "Identity":
             return image

Review comment on the new fill handling (truncated in the page source): We might want to consider making …
Further review discussion:

- Yes, if the signatures actually diverge, then this is the issue? I'm OK disabling this here for now and looking at it later. Still, is this something we want? Shouldn't the public kernels be in sync with the dispatcher?
- I believe @vfdev-5 made this change. As discussed previously, there are other places where we need to align the signatures, and this can happen in a follow-up PR to avoid making this one too long. Personally I think it's worth aligning the signatures unless there is a good reason not to (perhaps the interpolation default value is one exception?). I'm open to discussing this, and I think we should agree on a policy soon.