Skip to content

add Video feature and kernels #6667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Oct 7, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
4874907
add video feature
pmeier Sep 29, 2022
a1b00b4
add video kernels
pmeier Sep 29, 2022
e7a229c
add video testing utils
pmeier Sep 29, 2022
5d8b8b6
add one kernel info
pmeier Sep 29, 2022
2380f10
Merge branch 'main' into video
pmeier Oct 4, 2022
a04d667
fix kernel names in Video feature
pmeier Oct 4, 2022
35642b9
use only uint8 for video testing
pmeier Oct 4, 2022
ae59458
require at least 4 dims for Video feature
pmeier Oct 4, 2022
0fb1c35
add TODO for image_size -> spatial_size
pmeier Oct 4, 2022
2d1e560
image -> video in feature constructor
pmeier Oct 4, 2022
91e15b2
introduce new combined images and video type
pmeier Oct 4, 2022
81237fe
add video to transform utils
pmeier Oct 5, 2022
aa26292
fix transforms test
pmeier Oct 5, 2022
93d7556
fix auto augment
pmeier Oct 5, 2022
6df2f0f
Merge branch 'main' into video
pmeier Oct 5, 2022
a99765d
cleanup
pmeier Oct 5, 2022
17ee7f7
Merge branch 'main' into video
pmeier Oct 6, 2022
4506cdf
address review comments
pmeier Oct 6, 2022
36f52dc
add remaining video kernel infos
pmeier Oct 6, 2022
0d2ad96
add batch dimension squashing to some kernels
pmeier Oct 6, 2022
f1e2bfa
fix tests and kernel infos
pmeier Oct 6, 2022
93fc321
add xfails for arbitrary batch sizes on some kernels
pmeier Oct 6, 2022
f843612
Merge branch 'main' into video
pmeier Oct 6, 2022
ad4d424
Merge branch 'main' into video
pmeier Oct 7, 2022
d8945e6
fix test setup
pmeier Oct 7, 2022
1c86193
fix equalize_image_tensor for multi batch dims
pmeier Oct 7, 2022
1c2b615
fix adjust_sharpness_image_tensor for multi batch dims
pmeier Oct 7, 2022
7f3a8b7
Merge branch 'video' of https://github.com/pmeier/vision into video
pmeier Oct 7, 2022
2d7b07d
address review comments
pmeier Oct 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions test/prototype_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,19 @@ def _parse_image_size(size, *, name="size"):

def from_loader(loader_fn):
    """Wrap a loader factory so the wrapper directly returns the loaded value.

    ``device`` (default ``"cpu"``) is *popped* from ``kwargs`` before calling
    ``loader_fn`` — the factory does not take a ``device`` parameter; the device
    is only used by ``loader.load``.
    """

    def wrapper(*args, **kwargs):
        # pop, not get: leaving "device" in kwargs would leak it into loader_fn
        device = kwargs.pop("device", "cpu")
        loader = loader_fn(*args, **kwargs)
        return loader.load(device)

    return wrapper


def from_loaders(loaders_fn):
    """Wrap a multi-loader factory so the wrapper yields loaded values.

    Like :func:`from_loader`, but ``loaders_fn`` produces an iterable of
    loaders; each one is loaded onto the popped ``device`` (default ``"cpu"``).
    """

    def wrapper(*args, **kwargs):
        # pop, not get: leaving "device" in kwargs would leak it into loaders_fn
        device = kwargs.pop("device", "cpu")
        loaders = loaders_fn(*args, **kwargs)
        for loader in loaders:
            yield loader.load(device)

    return wrapper

Expand Down Expand Up @@ -527,3 +529,42 @@ def make_mask_loaders(


make_masks = from_loaders(make_mask_loaders)


class VideoLoader(ImageLoader):
    # Marker subclass: video samples reuse all of ImageLoader's behavior; the
    # distinct type lets tests dispatch on loader kind.
    pass


def make_video_loader(
    size="random",
    *,
    color_space=features.ColorSpace.RGB,
    num_frames="random",
    extra_dims=(),
    dtype=torch.float32,
):
    """Create a :class:`VideoLoader` producing a ``features.Video`` of shape
    ``(*extra_dims, num_frames, num_channels, height, width)``.

    ``size`` / ``num_frames`` may be the string ``"random"`` to pick random values.
    """
    size = _parse_image_size(size)
    num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
    # The loader shape must include the channel dimension: `fn` below builds its
    # tensor via make_image, which appends (C, H, W) — omitting C here would make
    # loader.shape disagree with the tensor that is actually loaded.
    num_channels = {
        features.ColorSpace.GRAY: 1,
        features.ColorSpace.GRAY_ALPHA: 2,
        features.ColorSpace.RGB: 3,
        features.ColorSpace.RGBA: 4,
    }[color_space]

    def fn(shape, dtype, device):
        # shape[:-3] strips (C, H, W); make_image re-adds them from color_space/size.
        video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
        return features.Video(video, color_space=color_space)

    return VideoLoader(
        fn, shape=(*extra_dims, num_frames, num_channels, *size), dtype=dtype, color_space=color_space
    )


def make_video_loaders(
    *,
    sizes=DEFAULT_IMAGE_SIZES,
    color_spaces=(
        features.ColorSpace.GRAY,
        features.ColorSpace.RGB,
    ),
    num_frames=(1, 0, "random"),
    extra_dims=DEFAULT_EXTRA_DIMS,
    dtypes=(torch.float32, torch.uint8),
):
    """Yield a :class:`VideoLoader` for every combination of the given parameters."""
    grid = combinations_grid(
        size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes
    )
    yield from (make_video_loader(**params) for params in grid)
17 changes: 16 additions & 1 deletion test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
import torchvision.ops
import torchvision.prototype.transforms.functional as F
from datasets_utils import combinations_grid
from prototype_common_utils import ArgsKwargs, make_bounding_box_loaders, make_image_loaders, make_mask_loaders
from prototype_common_utils import (
ArgsKwargs,
make_bounding_box_loaders,
make_image_loaders,
make_mask_loaders,
make_video_loaders,
)
from torchvision.prototype import features
from torchvision.transforms.functional_tensor import _max_value as get_max_value

Expand Down Expand Up @@ -126,6 +132,11 @@ def sample_inputs_horizontal_flip_mask():
yield ArgsKwargs(image_loader)


def sample_inputs_horizontal_flip_video():
    # One random-sized, random-length float32 video is enough to smoke-test the kernel.
    loaders = make_video_loaders(sizes=["random"], num_frames=["random"], dtypes=[torch.float32])
    yield from (ArgsKwargs(loader) for loader in loaders)


KERNEL_INFOS.extend(
[
KernelInfo(
Expand All @@ -144,6 +155,10 @@ def sample_inputs_horizontal_flip_mask():
F.horizontal_flip_mask,
sample_inputs_fn=sample_inputs_horizontal_flip_mask,
),
KernelInfo(
F.horizontal_flip_video,
sample_inputs_fn=sample_inputs_horizontal_flip_video,
),
]
)

Expand Down
1 change: 1 addition & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_batched_vs_single(self, info, args_kwargs, device):
# type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
# common ground.
features.Mask: 2,
features.Video: 4,
}.get(feature_type)
if data_dims is None:
raise pytest.UsageError(
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
)
from ._label import Label, OneHotLabel
from ._mask import Mask
from ._video import Video
232 changes: 232 additions & 0 deletions torchvision/prototype/features/_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
from __future__ import annotations

import warnings
from typing import Any, cast, List, Optional, Tuple, Union

import torch
from torchvision.transforms.functional import InterpolationMode

from ._feature import _Feature, FillTypeJIT
from ._image import ColorSpace


class Video(_Feature):
    """``_Feature`` tensor subclass for videos shaped ``(..., T, C, H, W)``.

    Each transform method delegates to the corresponding ``*_video`` kernel in
    the functional namespace (``self._F``) and re-wraps the result with
    :meth:`Video.new_like` so the ``color_space`` metadata is preserved.
    """

    # Color space of the frames (e.g. ColorSpace.RGB); set in __new__.
    color_space: ColorSpace

    def __new__(
        cls,
        data: Any,
        *,
        color_space: Optional[Union[ColorSpace, str]] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: bool = False,
    ) -> Video:
        """Wrap ``data`` as a Video, guessing the color space if not given."""
        data = torch.as_tensor(data, dtype=dtype, device=device)  # type: ignore[arg-type]
        if data.ndim < 3:
            raise ValueError(f"Expected data with at least 3 dimensions (C, H, W), but got {data.ndim}")
        elif data.ndim == 3:
            # Promote a single frame (C, H, W) to a one-frame video (1, C, H, W).
            data = data.unsqueeze(0)
        video = super().__new__(cls, data, requires_grad=requires_grad)

        if color_space is None:
            # Infer the color space from the shape; warn when that is ambiguous.
            color_space = ColorSpace.from_tensor_shape(video.shape)  # type: ignore[arg-type]
            if color_space == ColorSpace.OTHER:
                warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
        elif isinstance(color_space, str):
            color_space = ColorSpace.from_str(color_space.upper())
        elif not isinstance(color_space, ColorSpace):
            raise ValueError(f"`color_space` should be a `ColorSpace` or `str`, but got {type(color_space)} instead")
        video.color_space = color_space

        return video

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr(color_space=self.color_space)

    @classmethod
    def new_like(
        cls, other: Video, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
    ) -> Video:
        """Create a Video from ``data``, inheriting ``other``'s color space unless overridden."""
        return super().new_like(
            other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
        )

    @property
    def image_size(self) -> Tuple[int, int]:
        # (height, width) of the frames.
        return cast(Tuple[int, int], tuple(self.shape[-2:]))

    @property
    def num_channels(self) -> int:
        return self.shape[-3]

    @property
    def num_frames(self) -> int:
        return self.shape[-4]

    def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Video:
        """Convert the video to ``color_space``, returning a new Video."""
        if isinstance(color_space, str):
            color_space = ColorSpace.from_str(color_space.upper())

        return Video.new_like(
            self,
            self._F.convert_color_space_video(
                self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
            ),
            color_space=color_space,
        )

    # NOTE: kernel names below use the `*_video` convention, matching the
    # registered kernels (e.g. F.horizontal_flip_video) and
    # convert_color_space_video above.

    def horizontal_flip(self) -> Video:
        output = self._F.horizontal_flip_video(self)
        return Video.new_like(self, output)

    def vertical_flip(self) -> Video:
        output = self._F.vertical_flip_video(self)
        return Video.new_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: bool = False,
    ) -> Video:
        output = self._F.resize_video(self, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
        return Video.new_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Video:
        output = self._F.crop_video(self, top, left, height, width)
        return Video.new_like(self, output)

    def center_crop(self, output_size: List[int]) -> Video:
        output = self._F.center_crop_video(self, output_size=output_size)
        return Video.new_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: bool = False,
    ) -> Video:
        output = self._F.resized_crop_video(
            self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
        )
        return Video.new_like(self, output)

    def pad(
        self,
        padding: Union[int, List[int]],
        fill: FillTypeJIT = None,
        padding_mode: str = "constant",
    ) -> Video:
        output = self._F.pad_video(self, padding, fill=fill, padding_mode=padding_mode)
        return Video.new_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        expand: bool = False,
        fill: FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Video:
        output = self._F.rotate_video(self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
        return Video.new_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Video:
        output = self._F.affine_video(
            self,
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
        return Video.new_like(self, output)

    def perspective(
        self,
        perspective_coeffs: List[float],
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        fill: FillTypeJIT = None,
    ) -> Video:
        output = self._F.perspective_video(self, perspective_coeffs, interpolation=interpolation, fill=fill)
        return Video.new_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        fill: FillTypeJIT = None,
    ) -> Video:
        output = self._F.elastic_video(self, displacement, interpolation=interpolation, fill=fill)
        return Video.new_like(self, output)

    def adjust_brightness(self, brightness_factor: float) -> Video:
        output = self._F.adjust_brightness_video(self, brightness_factor=brightness_factor)
        return Video.new_like(self, output)

    def adjust_saturation(self, saturation_factor: float) -> Video:
        output = self._F.adjust_saturation_video(self, saturation_factor=saturation_factor)
        return Video.new_like(self, output)

    def adjust_contrast(self, contrast_factor: float) -> Video:
        output = self._F.adjust_contrast_video(self, contrast_factor=contrast_factor)
        return Video.new_like(self, output)

    def adjust_sharpness(self, sharpness_factor: float) -> Video:
        output = self._F.adjust_sharpness_video(self, sharpness_factor=sharpness_factor)
        return Video.new_like(self, output)

    def adjust_hue(self, hue_factor: float) -> Video:
        output = self._F.adjust_hue_video(self, hue_factor=hue_factor)
        return Video.new_like(self, output)

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Video:
        output = self._F.adjust_gamma_video(self, gamma=gamma, gain=gain)
        return Video.new_like(self, output)

    def posterize(self, bits: int) -> Video:
        output = self._F.posterize_video(self, bits=bits)
        return Video.new_like(self, output)

    def solarize(self, threshold: float) -> Video:
        output = self._F.solarize_video(self, threshold=threshold)
        return Video.new_like(self, output)

    def autocontrast(self) -> Video:
        output = self._F.autocontrast_video(self)
        return Video.new_like(self, output)

    def equalize(self) -> Video:
        output = self._F.equalize_video(self)
        return Video.new_like(self, output)

    def invert(self) -> Video:
        output = self._F.invert_video(self)
        return Video.new_like(self, output)

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
        output = self._F.gaussian_blur_video(self, kernel_size=kernel_size, sigma=sigma)
        return Video.new_like(self, output)
Loading