Added typing annotations to io/__init__ #4224


Merged · 6 commits · Aug 31, 2021
17 changes: 9 additions & 8 deletions torchvision/io/__init__.py
@@ -1,4 +1,5 @@
 import torch
+from typing import Any, Dict, Iterator

 from ._video_opt import (
     Timebase,
@@ -33,13 +34,13 @@
 if _HAS_VIDEO_OPT:

-    def _has_video_opt():
+    def _has_video_opt() -> bool:
         return True


 else:

-    def _has_video_opt():
+    def _has_video_opt() -> bool:
         return False

@@ -99,7 +100,7 @@ class VideoReader:
             Currently available options include ``['video', 'audio']``
     """

-    def __init__(self, path, stream="video"):
+    def __init__(self, path: str, stream: str = "video") -> None:
         if not _has_video_opt():
             raise RuntimeError(
                 "Not compiled with video_reader support, "
@@ -109,7 +110,7 @@ def __init__(self, path, stream="video"):
             )
         self._c = torch.classes.torchvision.Video(path, stream)

-    def __next__(self):
+    def __next__(self) -> Dict[str, Any]:
         """Decodes and returns the next frame of the current stream.
         Frames are encoded as a dict with mandatory
         data and pts fields, where data is a tensor, and pts is a
@@ -126,10 +127,10 @@ def __next__(self):
             raise StopIteration
         return {"data": frame, "pts": pts}

-    def __iter__(self):
+    def __iter__(self) -> Iterator['VideoReader']:
Contributor

See mypy doc.
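For anyone unfamiliar with the idiom being referenced: a class becomes a typed iterator once both __iter__ and __next__ carry annotations. A minimal sketch of that idiom with a hypothetical stand-in class (the names and the element type are illustrative, not the torchvision code):

from typing import Any, Dict, Iterator


class FrameSource:
    """Hypothetical stand-in class, not the torchvision VideoReader."""

    def __init__(self) -> None:
        self._frames = [{"data": None, "pts": 0.0}]
        self._pos = 0

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        # __next__ below returns Dict[str, Any], so the class is a structural
        # subtype of Iterator[Dict[str, Any]] and returning self type-checks.
        return self

    def __next__(self) -> Dict[str, Any]:
        if self._pos >= len(self._frames):
            raise StopIteration
        frame = self._frames[self._pos]
        self._pos += 1
        return frame

With both methods annotated, mypy accepts the object directly in for-loops without casts.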

         return self

-    def seek(self, time_s: float):
+    def seek(self, time_s: float) -> 'VideoReader':
Contributor

Could be replaced with VideoReader once we drop Python 3.6 support. See here.

Contributor

Nice.
I learn something new every day thanks to reviews from the torchvision maintainers 😇

There are a couple more places in torchvision where I have used this workaround for typing. I will keep a note of this point.

Contributor Author

@datumbox yup, fully agree! But for now (this PR), I guess we keep the string version?

Contributor

Yes, that's right, nothing we can do for now.
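For reference, the alternative discussed above relies on PEP 563 (postponed evaluation of annotations), available from Python 3.7 via a __future__ import; with it, the class name can be used directly in its own method signatures instead of the quoted 'VideoReader'. A minimal sketch under that assumption, not the actual torchvision source:

from __future__ import annotations  # PEP 563, requires Python 3.7+


class VideoReader:
    def seek(self, time_s: float) -> VideoReader:  # bare name, no quotes needed
        # Illustrative body only; the real method delegates to self._c.seek(time_s).
        return self

Until Python 3.6 support is dropped, the quoted forward reference used in this PR remains the portable spelling.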

"""Seek within current stream.

Args:
Expand All @@ -144,15 +145,15 @@ def seek(self, time_s: float):
self._c.seek(time_s)
return self

-    def get_metadata(self):
+    def get_metadata(self) -> Dict[str, Any]:
"""Returns video metadata

Returns:
(dict): dictionary containing duration and frame rate for every stream
"""
return self._c.get_metadata()

-    def set_current_stream(self, stream: str):
+    def set_current_stream(self, stream: str) -> bool:
"""Set current stream.
Explicitly define the stream we are operating on.

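To close out, a hedged usage sketch of the API these annotations describe; it assumes torchvision was built with video_reader support and that a local file video.mp4 exists, neither of which is guaranteed here:

from torchvision.io import VideoReader

reader = VideoReader("video.mp4", stream="video")  # __init__(path: str, stream: str = "video")
print(reader.get_metadata())                       # Dict[str, Any]: duration / frame rate per stream

reader.seek(2.0)                                   # returns the reader itself ('VideoReader')
for frame in reader:                               # each frame is a Dict[str, Any]
    data, pts = frame["data"], frame["pts"]        # tensor and timestamp in seconds
    break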