From 7779b12c5e01fbf1062e1ef6f782471568a1550c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 4 Oct 2021 09:15:41 +0100 Subject: [PATCH 1/2] Allow redefinition for mypy --- mypy.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/mypy.ini b/mypy.ini index dac60e11ce0..4820cdf1661 100644 --- a/mypy.ini +++ b/mypy.ini @@ -3,6 +3,7 @@ files = torchvision show_error_codes = True pretty = True +allow_redefinition = True [mypy-torchvision.io._video_opt.*] From 626372afd418a741c0bd7bf34103f56365ec053a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Oct 2021 08:22:19 +0200 Subject: [PATCH 2/2] appease mypy --- torchvision/datasets/celeba.py | 8 +++----- torchvision/datasets/imagenet.py | 2 +- torchvision/prototype/datasets/_builtin/caltech.py | 8 ++++---- torchvision/prototype/datasets/_folder.py | 4 ++-- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index 91f0fc3f919..302f75087b7 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -1,7 +1,6 @@ import csv import os from collections import namedtuple -from functools import partial from typing import Any, Callable, List, Optional, Union, Tuple import PIL @@ -115,15 +114,14 @@ def _load_csv( filename: str, header: Optional[int] = None, ) -> CSV: - data, indices, headers = [], [], [] - - fn = partial(os.path.join, self.root, self.base_folder) - with open(fn(filename)) as csv_file: + with open(os.path.join(self.root, self.base_folder, filename)) as csv_file: data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True)) if header is not None: headers = data[header] data = data[header + 1 :] + else: + headers = [] indices = [row[0] for row in data] data = [row[1:] for row in data] diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 624294571aa..0fdb3395a5e 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -119,7 +119,7 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None: """ import scipy.io as sio - def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]: + def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, Tuple[str, ...]]]: metafile = os.path.join(devkit_root, "data", "meta.mat") meta = sio.loadmat(metafile, squeeze_me=True)["synsets"] nums_children = list(zip(*meta))[4] diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index d2ce41c0d0f..b1d9970bd94 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -117,11 +117,11 @@ def _make_datapipe( images_dp, anns_dp = resource_dps images_dp = TarArchiveReader(images_dp) - images_dp = Filter(images_dp, self._is_not_background_image) + images_dp: IterDataPipe = Filter(images_dp, self._is_not_background_image) images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE) anns_dp = TarArchiveReader(anns_dp) - anns_dp = Filter(anns_dp, self._is_ann) + anns_dp: IterDataPipe = Filter(anns_dp, self._is_ann) dp = KeyZipper( images_dp, @@ -136,7 +136,7 @@ def _make_datapipe( def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None: dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) dp = TarArchiveReader(dp) - dp = Filter(dp, self._is_not_background_image) + dp: IterDataPipe = Filter(dp, self._is_not_background_image) dir_names = {pathlib.Path(path).parent.name for path, _ in dp} create_categories_file(HERE, self.name, sorted(dir_names)) @@ -185,7 +185,7 @@ def _make_datapipe( ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = TarArchiveReader(dp) - dp = Filter(dp, self._is_not_rogue_file) + dp: IterDataPipe = Filter(dp, self._is_not_rogue_file) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index 55e48387d6a..581c59167ed 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -49,8 +49,8 @@ def from_data_folder( root = pathlib.Path(root).expanduser().resolve() categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir()) masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else "" - dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks) - dp = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root)) + dp = FileLister(str(root), recursive=recursive, masks=masks) + dp: IterDataPipe = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root)) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = FileLoader(dp) return (