Skip to content

Add typehints for torchvision.io #2543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

15 commits were merged into the base branch on Sep 14, 2020.
6 changes: 5 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pretty = True

;ignore_errors = True

[mypy-torchvision.io.*]
[mypy-torchvision.io._video_opt.*]

ignore_errors = True

Expand Down Expand Up @@ -51,3 +51,7 @@ ignore_missing_imports = True
[mypy-accimage.*]

ignore_missing_imports = True

[mypy-av.*]

ignore_missing_imports = True
21 changes: 8 additions & 13 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import torch
from torch import nn, Tensor

import os
import os.path as osp
import importlib
import importlib.machinery

_HAS_IMAGE_OPT = False

Expand All @@ -15,7 +14,7 @@
importlib.machinery.EXTENSION_SUFFIXES
)

extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) # type: ignore[arg-type]
ext_specs = extfinder.find_spec("image")
if ext_specs is not None:
torch.ops.load_library(ext_specs.origin)
Expand All @@ -24,8 +23,7 @@
pass


def decode_png(input):
# type: (Tensor) -> Tensor
def decode_png(input: torch.Tensor) -> torch.Tensor:
"""
Decodes a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Expand All @@ -37,7 +35,7 @@ def decode_png(input):
Returns:
output (Tensor[image_width, image_height, 3])
"""
if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1:
if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1: # type: ignore[attr-defined]
raise ValueError("Expected a non empty 1-dimensional tensor.")

if not input.dtype == torch.uint8:
Expand All @@ -46,8 +44,7 @@ def decode_png(input):
return output


def read_png(path):
# type: (str) -> Tensor
def read_png(path: str) -> torch.Tensor:
"""
Reads a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Expand All @@ -68,8 +65,7 @@ def read_png(path):
return decode_png(data)


def decode_jpeg(input):
# type: (Tensor) -> Tensor
def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
"""
Decodes a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Expand All @@ -79,7 +75,7 @@ def decode_jpeg(input):
Returns:
output (Tensor[image_width, image_height, 3])
"""
if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1:
if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1: # type: ignore[attr-defined]
raise ValueError("Expected a non empty 1-dimensional tensor.")

if not input.dtype == torch.uint8:
Expand All @@ -89,8 +85,7 @@ def decode_jpeg(input):
return output


def read_jpeg(path):
# type: (str) -> Tensor
def read_jpeg(path: str) -> torch.Tensor:
"""
Reads a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Expand Down
51 changes: 33 additions & 18 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
import re
import warnings
from typing import List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -35,12 +35,12 @@
)


def _check_av_available():
def _check_av_available() -> None:
if isinstance(av, Exception):
raise av


def _av_available():
def _av_available() -> bool:
return not isinstance(av, Exception)


Expand All @@ -49,7 +49,13 @@ def _av_available():
_GC_COLLECTION_INTERVAL = 10


def write_video(filename, video_array, fps: Union[int, float], video_codec="libx264", options=None):
def write_video(
filename: str,
video_array: torch.Tensor,
fps: float,
video_codec: str = "libx264",
options: Optional[Dict[str, Any]] = None,
) -> None:
"""
Writes a 4d tensor in [T, H, W, C] format in a video file

Expand Down Expand Up @@ -89,8 +95,13 @@ def write_video(filename, video_array, fps: Union[int, float], video_codec="libx


def _read_from_stream(
container, start_offset, end_offset, pts_unit, stream, stream_name
):
container: "av.container.Container",
start_offset: float,
end_offset: float,
pts_unit: str,
stream: "av.stream.Stream",
stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
_CALLED_TIMES += 1
if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
Expand Down Expand Up @@ -166,7 +177,9 @@ def _read_from_stream(
return result


def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
def _align_audio_frames(
aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
) -> torch.Tensor:
start, end = audio_frames[0].pts, audio_frames[-1].pts
total_aframes = aframes.shape[1]
step_per_aframe = (end - start + 1) / total_aframes
Expand All @@ -179,7 +192,9 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
return aframes[:, s_idx:e_idx]


def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
def read_video(
filename: str, start_pts: int = 0, end_pts: Optional[float] = None, pts_unit: str = "pts"
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Reads a video from a file, returning both the video frames as well as
the audio frames
Expand Down Expand Up @@ -260,16 +275,16 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
# TODO raise a warning?
pass

vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes = [frame.to_ndarray() for frame in audio_frames]
vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes_list = [frame.to_ndarray() for frame in audio_frames]

if vframes:
vframes = torch.as_tensor(np.stack(vframes))
if vframes_list:
vframes = torch.as_tensor(np.stack(vframes_list))
else:
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

if aframes:
aframes = np.concatenate(aframes, 1)
if aframes_list:
aframes = np.concatenate(aframes_list, 1)
aframes = torch.as_tensor(aframes)
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
else:
Expand All @@ -278,7 +293,7 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
return vframes, aframes, info


def _can_read_timestamps_from_packets(container):
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
extradata = container.streams[0].codec_context.extradata
if extradata is None:
return False
Expand All @@ -287,15 +302,15 @@ def _can_read_timestamps_from_packets(container):
return False


def _decode_video_timestamps(container):
def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
if _can_read_timestamps_from_packets(container):
# fast path
return [x.pts for x in container.demux(video=0) if x.pts is not None]
else:
return [x.pts for x in container.decode(video=0) if x.pts is not None]


def read_video_timestamps(filename, pts_unit="pts"):
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
"""
List the video frames timestamps.

Expand All @@ -313,7 +328,7 @@ def read_video_timestamps(filename, pts_unit="pts"):
pts : List[int] if pts_unit = 'pts'
List[Fraction] if pts_unit = 'sec'
presentation timestamps for each one of the frames in the video.
video_fps : int
video_fps : float, optional
the frame rate for the video

"""
Expand Down