diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 742344e6b0f..d0ec1b406f3 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -1,4 +1,5 @@ import torch +from typing import Any, Dict, Iterator from ._video_opt import ( Timebase, @@ -33,13 +34,13 @@ if _HAS_VIDEO_OPT: - def _has_video_opt(): + def _has_video_opt() -> bool: return True else: - def _has_video_opt(): + def _has_video_opt() -> bool: return False @@ -99,7 +100,7 @@ class VideoReader: Currently available options include ``['video', 'audio']`` """ - def __init__(self, path, stream="video"): + def __init__(self, path: str, stream: str = "video") -> None: if not _has_video_opt(): raise RuntimeError( "Not compiled with video_reader support, " @@ -109,7 +110,7 @@ def __init__(self, path, stream="video"): ) self._c = torch.classes.torchvision.Video(path, stream) - def __next__(self): + def __next__(self) -> Dict[str, Any]: """Decodes and returns the next frame of the current stream. Frames are encoded as a dict with mandatory data and pts fields, where data is a tensor, and pts is a @@ -126,10 +127,10 @@ def __next__(self): raise StopIteration return {"data": frame, "pts": pts} - def __iter__(self): + def __iter__(self) -> Iterator[Dict[str, Any]]: return self - def seek(self, time_s: float): + def seek(self, time_s: float) -> 'VideoReader': """Seek within current stream. Args: @@ -144,7 +145,7 @@ def __next__(self): self._c.seek(time_s) return self - def get_metadata(self): + def get_metadata(self) -> Dict[str, Any]: """Returns video metadata Returns: @@ -152,7 +153,7 @@ def get_metadata(self): """ return self._c.get_metadata() - def set_current_stream(self, stream: str): + def set_current_stream(self, stream: str) -> bool: """Set current stream. Explicitly define the stream we are operating on.