
Adding support of Video to remaining Transforms and Kernels #6724


Merged: 18 commits, merged on Oct 10, 2022
12 changes: 11 additions & 1 deletion torchvision/prototype/features/__init__.py
@@ -13,4 +13,14 @@
)
from ._label import Label, OneHotLabel
from ._mask import Mask
from ._video import ImageOrVideoType, ImageOrVideoTypeJIT, TensorImageOrVideoType, TensorImageOrVideoTypeJIT, Video
from ._video import (
ImageOrVideoType,
ImageOrVideoTypeJIT,
LegacyVideoType,
LegacyVideoTypeJIT,
TensorImageOrVideoType,
TensorImageOrVideoTypeJIT,
Video,
VideoType,
VideoTypeJIT,
)
1 change: 1 addition & 0 deletions torchvision/prototype/features/_video.py
@@ -238,6 +238,7 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N
TensorVideoType = Union[torch.Tensor, Video]
TensorVideoTypeJIT = torch.Tensor

# TODO: decide if we should do definitions for both Images and Videos or use unions in the methods
ImageOrVideoType = Union[ImageType, VideoType]
ImageOrVideoTypeJIT = Union[ImageTypeJIT, VideoTypeJIT]
TensorImageOrVideoType = Union[TensorImageType, TensorVideoType]
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/_augment.py
@@ -99,6 +99,7 @@ def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) ->
return inpt


# TODO: Add support for Video: https://github.com/pytorch/vision/issues/6731
class _BaseMixupCutmix(_RandomApplyTransform):
def __init__(self, alpha: float, p: float = 0.5) -> None:
super().__init__(p=p)
5 changes: 3 additions & 2 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -483,7 +483,8 @@ def forward(self, *inputs: Any) -> Any:
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

orig_dims = list(image_or_video.shape)
batch = image_or_video.view([1] * max(4 - image_or_video.ndim, 0) + orig_dims)
expected_dim = 5 if isinstance(orig_image_or_video, features.Video) else 4
batch = image_or_video.view([1] * max(expected_dim - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

# Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
@@ -520,7 +521,7 @@ def forward(self, *inputs: Any) -> Any:
mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)

if isinstance(orig_image_or_video, (features.Image, features.Video)):
mix = type(orig_image_or_video).wrap_like(orig_image_or_video, mix) # type: ignore[arg-type]
mix = orig_image_or_video.wrap_like(orig_image_or_video, mix) # type: ignore[arg-type]
elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_image_pil(mix)

2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_color.py
@@ -119,7 +119,7 @@ def _permute_channels(
output = inpt[..., permutation, :, :]

if isinstance(inpt, (features.Image, features.Video)):
output = type(inpt).wrap_like(inpt, output, color_space=features.ColorSpace.OTHER) # type: ignore[arg-type]
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.OTHER) # type: ignore[arg-type]

elif isinstance(inpt, PIL.Image.Image):
output = F.to_image_pil(output)
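
The switch from type(inpt).wrap_like(...) to inpt.wrap_like(...) here (and in the other files touched by this PR) works because wrap_like is a classmethod, so calling it on the instance resolves to the concrete feature subclass. A simplified, standalone sketch, not the actual torchvision classes:

```python
import torch


class _Feature(torch.Tensor):
    @classmethod
    def wrap_like(cls, other: "_Feature", tensor: torch.Tensor) -> "_Feature":
        # Re-wrap a plain tensor in the feature type of `other` (metadata omitted here).
        return tensor.as_subclass(cls)


class Image(_Feature):
    pass


class Video(_Feature):
    pass


inpt = torch.rand(8, 3, 32, 32).as_subclass(Video)            # stand-in for features.Video
plain = inpt.as_subclass(torch.Tensor)[..., [2, 0, 1], :, :]   # kernels return plain tensors
wrapped = inpt.wrap_like(inpt, plain)                          # resolves to Video.wrap_like
print(type(wrapped).__name__)                                  # Video
```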
16 changes: 8 additions & 8 deletions torchvision/prototype/transforms/_deprecated.py
@@ -29,7 +29,7 @@ def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str,


class Grayscale(Transform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)

def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
deprecation_msg = (
@@ -52,15 +52,15 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
super().__init__()
self.num_output_channels = num_output_channels

def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
if isinstance(inpt, features.Image):
output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)
if isinstance(inpt, (features.Image, features.Video)):
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type]
return output


class RandomGrayscale(_RandomApplyTransform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)

def __init__(self, p: float = 0.1) -> None:
warnings.warn(
@@ -81,8 +81,8 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
num_input_channels, _, _ = query_chw(sample)
return dict(num_input_channels=num_input_channels)

def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
if isinstance(inpt, features.Image):
output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)
if isinstance(inpt, (features.Image, features.Video)):
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type]
return output
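
A pure-tensor sketch of why the same grayscale path can cover Video inputs, assuming the usual ITU-R 601 luma weights: the conversion only mixes the channel dim, so a leading time dim passes through unchanged.

```python
import torch

video = torch.rand(8, 3, 64, 64)                            # (T, C, H, W)
luma = torch.tensor([0.2989, 0.587, 0.114]).view(3, 1, 1)   # ITU-R 601 weights
gray = (video * luma).sum(dim=-3, keepdim=True)             # mixes the channel dim only
print(gray.shape)                                           # torch.Size([8, 1, 64, 64])
```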
27 changes: 17 additions & 10 deletions torchvision/prototype/transforms/_geometry.py
@@ -155,12 +155,13 @@ class FiveCrop(Transform):
"""
Example:
>>> class BatchMultiCrop(transforms.Transform):
... def forward(self, sample: Tuple[Tuple[features.Image, ...], features.Label]):
... images, labels = sample
... batch_size = len(images)
... images = features.Image.wrap_like(images[0], torch.stack(images))
... def forward(self, sample: Tuple[Tuple[Union[features.Image, features.Video], ...], features.Label]):
... images_or_videos, labels = sample
... batch_size = len(images_or_videos)
... image_or_video = images_or_videos[0]
... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
... labels = features.Label.wrap_like(labels, labels.repeat(batch_size))
... return images, labels
... return images_or_videos, labels
...
>>> image = features.Image(torch.rand(3, 256, 256))
>>> label = features.Label(0)
@@ -172,15 +173,21 @@ class FiveCrop(Transform):
torch.Size([5])
"""

_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)

def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

def _transform(
self, inpt: features.ImageType, params: Dict[str, Any]
) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
self, inpt: features.ImageOrVideoType, params: Dict[str, Any]
) -> Tuple[
features.ImageOrVideoType,
features.ImageOrVideoType,
features.ImageOrVideoType,
features.ImageOrVideoType,
features.ImageOrVideoType,
]:
Comment on lines +184 to +190
Collaborator

Feel free to ignore if mypy is happy

This is not accurate. We don't have Tuple[features.ImageOrVideoType, ...] here, but rather Union[Tuple[features.ImageType, ...], Tuple[features.VideoType, ...]]. Meaning, the type will not vary inside the returned tuple: we either get a tuple of images or a tuple of videos.

Contributor Author

This is why IMO we should avoid features.ImageOrVideoType and instead define it as Union[Image, Video]. Since this will be fixed in a follow-up, there is no point in messing with mypy (which is happy) here. I'll implement this in a follow-up.

return F.five_crop(inpt, self.size)

def forward(self, *inputs: Any) -> Any:
@@ -194,14 +201,14 @@ class TenCrop(Transform):
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
"""

_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)

def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip

def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> List[features.ImageType]:
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> List[features.ImageOrVideoType]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)

def forward(self, *inputs: Any) -> Any:
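A standalone sketch of the typing point raised in the review thread above, using hypothetical names rather than the torchvision aliases: with a union element type, the annotation also admits a tuple that mixes images and videos, whereas a constrained TypeVar pins the whole returned tuple to the input's type.

```python
from typing import Tuple, TypeVar, Union


class Image: ...


class Video: ...


ImageOrVideo = Union[Image, Video]
T = TypeVar("T", Image, Video)


def five_crop_loose(inpt: ImageOrVideo) -> Tuple[ImageOrVideo, ...]:
    # A tuple mixing images and videos would also satisfy this annotation.
    return (inpt, inpt, inpt, inpt, inpt)


def five_crop_precise(inpt: T) -> Tuple[T, T, T, T, T]:
    # The element type is pinned to the input's type across the whole tuple.
    return (inpt, inpt, inpt, inpt, inpt)
```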
10 changes: 5 additions & 5 deletions torchvision/prototype/transforms/_meta.py
@@ -22,18 +22,18 @@ def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> feat


class ConvertImageDtype(Transform):
_transformed_types = (features.is_simple_tensor, features.Image)
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)

def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype

def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType:
def _transform(
self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]
) -> features.TensorImageOrVideoType:
output = F.convert_image_dtype(inpt, dtype=self.dtype)
return (
output
if features.is_simple_tensor(inpt)
else features.Image.wrap_like(inpt, output) # type: ignore[arg-type]
output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output) # type: ignore[attr-defined]
)


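Since convert_image_dtype is shape-agnostic, adding features.Video to _transformed_types is all the transform itself needs. A quick sketch with the stable functional API, assuming a (T, C, H, W) uint8 video:

```python
import torch
from torchvision.transforms.functional import convert_image_dtype

video_u8 = torch.randint(0, 256, (8, 3, 32, 32), dtype=torch.uint8)  # (T, C, H, W)
video_f32 = convert_image_dtype(video_u8, dtype=torch.float32)       # values scaled to [0, 1]
print(video_f32.dtype, float(video_f32.max()) <= 1.0)                # torch.float32 True
```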
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/_misc.py
@@ -140,6 +140,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.gaussian_blur(inpt, self.kernel_size, **params)


# TODO: Enhance as described at https://github.com/pytorch/vision/issues/6697
class ToDtype(Lambda):
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
self.dtype = dtype
2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -96,6 +96,7 @@
five_crop,
five_crop_image_pil,
five_crop_image_tensor,
five_crop_video,
hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file
horizontal_flip,
horizontal_flip_bounding_box,
@@ -136,6 +137,7 @@
ten_crop,
ten_crop_image_pil,
ten_crop_image_tensor,
ten_crop_video,
vertical_flip,
vertical_flip_bounding_box,
vertical_flip_image_pil,
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/functional/_augment.py
@@ -35,7 +35,7 @@ def erase(
if isinstance(inpt, torch.Tensor):
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output
else: # isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
11 changes: 7 additions & 4 deletions torchvision/prototype/transforms/functional/_deprecated.py
@@ -1,5 +1,5 @@
import warnings
from typing import Any, List
from typing import Any, List, Union

import PIL.Image
import torch
@@ -22,10 +22,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
return _F.to_grayscale(inpt, num_output_channels=num_output_channels)


def rgb_to_grayscale(inpt: features.LegacyImageTypeJIT, num_output_channels: int = 1) -> features.LegacyImageTypeJIT:
def rgb_to_grayscale(
inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1
) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]:
old_color_space = (
features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image))
if isinstance(inpt, torch.Tensor)
and (torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)))
else None
)

@@ -56,7 +59,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
return _F.to_tensor(inpt)


def get_image_size(inpt: features.ImageTypeJIT) -> List[int]:
def get_image_size(inpt: features.ImageOrVideoTypeJIT) -> List[int]:
warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
33 changes: 25 additions & 8 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -1376,16 +1376,27 @@ def five_crop_image_pil(
return tl, tr, bl, br, center


def five_crop_video(
video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return five_crop_image_tensor(video, size)


def five_crop(
inpt: features.ImageTypeJIT, size: List[int]
inpt: features.ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[
features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT
features.ImageOrVideoTypeJIT,
features.ImageOrVideoTypeJIT,
features.ImageOrVideoTypeJIT,
features.ImageOrVideoTypeJIT,
features.ImageOrVideoTypeJIT,
]:
# TODO: consider breaking BC here to return List[features.ImageTypeJIT] to align this op with `ten_crop`
# TODO: consider breaking BC here to return List[features.ImageOrVideoTypeJIT] to align this op with `ten_crop`
if isinstance(inpt, torch.Tensor):
output = five_crop_image_tensor(inpt, size)
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
output = tuple(features.Image.wrap_like(inpt, item) for item in output) # type: ignore[assignment]
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
tmp = tuple(inpt.wrap_like(inpt, item) for item in output) # type: ignore[arg-type]
output = tmp # type: ignore[assignment]
return output
else: # isinstance(inpt, PIL.Image.Image):
return five_crop_image_pil(inpt, size)
@@ -1418,11 +1429,17 @@ def ten_crop_image_pil(image: PIL.Image.Image, size: List[int], vertical_flip: b
return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]


def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[features.ImageTypeJIT]:
def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)


def ten_crop(
inpt: features.ImageOrVideoTypeJIT, size: List[int], vertical_flip: bool = False
) -> List[features.ImageOrVideoTypeJIT]:
if isinstance(inpt, torch.Tensor):
output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
output = [features.Image.wrap_like(inpt, item) for item in output]
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
output = [inpt.wrap_like(inpt, item) for item in output] # type: ignore[arg-type]
return output
else: # isinstance(inpt, PIL.Image.Image):
return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
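
The new five_crop_video and ten_crop_video kernels simply delegate to the image-tensor kernels. A minimal sketch of why that works, assuming a (T, C, H, W) layout: cropping indexes only the trailing (H, W) dims, so the leading dims ride along unchanged.

```python
import torch

video = torch.rand(8, 3, 256, 256)   # (T, C, H, W)
top_left = video[..., :224, :224]    # one of the five crops, taken on the trailing dims
print(top_left.shape)                # torch.Size([8, 3, 224, 224])
```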
6 changes: 5 additions & 1 deletion torchvision/prototype/transforms/functional/_meta.py
@@ -55,6 +55,10 @@ def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
return [height, width]


# TODO: Should we have get_spatial_size_video here? How about masks/bbox etc? What is the criterion for deciding when
Collaborator

get_spatial_size should apply to everything, right? That was the whole reason we extracted it out: bounding boxes and masks can provide this information, while num_channels is reserved for images and videos.

Contributor Author

I've already updated get_spatial_size to handle all inputs.

I think you are trying to answer a different question from the one I'm asking here. What I think we should discuss is whether there should be specific kernels for each type, unrelated to whether the dispatcher can handle everything. We already have kernels (like erase_video) that aren't necessarily used in the dispatcher. So here I'm asking what the convention for providing kernels for individual types should be.

Collaborator

Ah, sorry, yes I was confused. That is a good question and I don't have an answer for it yet. My gut says that we should stay consistent and provide the kernels just as we do for the other transforms.

Contributor Author

Same feeling here. I'll leave the TODO for the follow-up. I think we can answer this in the PR where we switch image_size to spatial_size.

# a kernel will be created?


def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return get_spatial_size_image_tensor(inpt)
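
Purely as a hypothetical illustration of the open question in the thread above (not part of this diff), a per-type kernel for videos could be as thin as reading the trailing spatial dims:

```python
import torch
from typing import List


def get_spatial_size_video(video: torch.Tensor) -> List[int]:
    # Videos are laid out as (..., T, C, H, W); the spatial size is the last two dims.
    return list(video.shape[-2:])


print(get_spatial_size_video(torch.rand(8, 3, 240, 320)))  # [240, 320]
```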
@@ -246,7 +250,7 @@ def convert_color_space(
):
if old_color_space is None:
raise RuntimeError(
"In order to convert the color space of simple tensor images, "
"In order to convert the color space of simple tensors, "
"the `old_color_space=...` parameter needs to be passed."
)
return convert_color_space_image_tensor(