Skip to content

Commit 1dd1753

Browse files
committed
Fixed mypy errors. Addressed download checks.
1 parent ec550e8 commit 1dd1753

File tree

1 file changed

+47
-25
lines changed

1 file changed

+47
-25
lines changed

torchvision/datasets/_stereo_matching.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -60,34 +60,36 @@ def __init__(self, root: str, transforms: Optional[Callable] = None):
6060
super().__init__(root=root)
6161
self.transforms = transforms
6262

63-
self._images: List[Tuple] = []
64-
self._disparities: List[Tuple] = []
63+
self._images: List[Tuple[str, str]] = []
64+
self._disparities: List[Tuple[str, str]] = []
6565

6666
def _read_img(self, file_path: str) -> Image.Image:
6767
img = Image.open(file_path)
6868
if img.mode != "RGB":
6969
img = img.convert("RGB")
7070
return img
7171

72-
def _scan_pairs(self, left_pattern: str, right_pattern: str, fill_empty: bool = False) -> List[Tuple[str, str]]:
73-
left_paths = sorted(glob(left_pattern))
74-
right_paths = sorted(glob(right_pattern))
72+
def _scan_pairs(
73+
self, paths_left_pattern: str, paths_right_pattern: str, fill_empty: bool = False
74+
) -> List[Tuple[str, str]]:
75+
left_paths: List[str] = sorted(glob(paths_left_pattern))
76+
right_paths: List[str] = sorted(glob(paths_right_pattern))
7577

7678
# used when dealing with inexistent disparity for the right image
7779
if fill_empty:
7880
right_paths = list("" for _ in left_paths)
7981

8082
if not left_paths:
81-
raise FileNotFoundError(f"Could not find any files matching the patterns: {left_pattern}")
83+
raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}")
8284

8385
if not right_paths:
84-
raise FileNotFoundError(f"Could not find any files matching the patterns: {right_pattern}")
86+
raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}")
8587

8688
if len(left_paths) != len(right_paths):
8789
raise ValueError(
8890
f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n "
89-
f"left pattern: {left_pattern}\n"
90-
f"right pattern: {right_pattern}\n"
91+
f"left pattern: {paths_left_pattern}\n"
92+
f"right pattern: {paths_right_pattern}\n"
9193
)
9294

9395
images = list((left, right) for left, right in zip(left_paths, right_paths))
@@ -387,6 +389,7 @@ def __init__(
387389
self._download_dataset(root)
388390

389391
root = Path(root) / "Middlebury2014"
392+
self.split = split
390393

391394
if not os.path.exists(root / split):
392395
raise FileNotFoundError(f"The {split} directory was not found in the provided root directory")
@@ -457,19 +460,26 @@ def _download_dataset(self, root: str):
457460
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
458461
# train and additional splits have 2 different calibration settings
459462
root = Path(root) / "Middlebury2014"
460-
for split_name, split_scenes in self.splits.items():
463+
download_split = self.split
464+
465+
for split_name, split_scenes in (download_split, self.splits[download_split]):
461466
if split_name == "test":
462467
continue
463468
split_root = root / split_name
464469
for scene in split_scenes:
465470
for calibration in ["perfect", "imperfect"]:
466471
scene_name = f"{scene}-{calibration}"
467472
scene_url = f"{base_url}/{scene_name}.zip"
468-
download_and_extract_archive(
469-
url=scene_url, filename=f"{scene_name}.zip", download_root=str(split_root), remove_finished=True
470-
)
471-
472-
if any(s not in os.listdir(root) for s in self.splits["test"]):
473+
# download the scene only if it doesn't exist
474+
if not os.path.exists(split_root / scene_name):
475+
download_and_extract_archive(
476+
url=scene_url,
477+
filename=f"{scene_name}.zip",
478+
download_root=str(split_root),
479+
remove_finished=True,
480+
)
481+
482+
if any(s not in os.listdir(root / "test") for s in self.splits["test"]):
473483
# test split is downloaded from a different location
474484
test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip"
475485

@@ -550,13 +560,13 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
550560

551561
left_img_pattern = str(root / img_dir / "*" / "im0.png")
552562
right_img_pattern = str(root / img_dir / "*" / "im1.png")
553-
self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
563+
self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
554564

555565
if split == "test":
556566
self._disparities = list(("", "") for _ in self._images)
557567
else:
558568
disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
559-
self._disparities = self._scan_pairs(disparity_pattern, "", fill_empty=True)
569+
self._disparities += self._scan_pairs(disparity_pattern, "", fill_empty=True)
560570

561571
def _read_disparity(self, file_path: str) -> Tuple:
562572
if not os.path.exists(file_path):
@@ -605,11 +615,11 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
605615

606616
left_img_pattern = str(root / "colored_0" / "*_10.png")
607617
right_img_pattern = str(root / "colored_1" / "*_10.png")
608-
self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
618+
self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
609619

610620
if split == "train":
611621
disparity_pattern = str(root / "disp_noc" / "*.png")
612-
self._disparities = self._scan_pairs(disparity_pattern, "", fill_empty=True)
622+
self._disparities += self._scan_pairs(disparity_pattern, "", fill_empty=True)
613623
else:
614624
self._disparities = list(("", "") for _ in self._images)
615625

@@ -676,12 +686,12 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
676686
root = Path(root) / "Kitti2015" / (split + "ing")
677687
left_img_pattern = str(root / "image_2" / "*.png")
678688
right_img_pattern = str(root / "image_3" / "*.png")
679-
self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
689+
self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
680690

681691
if split == "train":
682692
left_disparity_pattern = str(root / "disp_occ_0" / "*.png")
683693
right_disparity_pattern = str(root / "disp_occ_1" / "*.png")
684-
self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
694+
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
685695
else:
686696
self._disparities = list(("", "") for _ in self._images)
687697

@@ -750,21 +760,33 @@ def __init__(self, root: str, transforms: Optional[Callable] = None):
750760

751761
left_img_pattern = str(root / "training" / "final_left" / "*" / "*.png")
752762
right_img_pattern = str(root / "training" / "final_right" / "*" / "*.png")
753-
self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
763+
self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
754764

755765
disparity_pattern = str(root / "training" / "disparities" / "*" / "*.png")
756-
self._disparities = self._scan_pairs(disparity_pattern, "", fill_empty=True)
766+
self._disparities += self._scan_pairs(disparity_pattern, "", fill_empty=True)
757767

758-
def _get_oclussion_mask_paths(self, file_path: str) -> List[str]:
768+
def _get_oclussion_mask_paths(self, file_path: str) -> Tuple[str, str]:
759769
path_tokens = file_path.split(os.sep)
770+
rets = None
771+
760772
for idx in range(len(path_tokens) - 1):
761773
if path_tokens[idx] == "training" and path_tokens[idx + 1] == "disparities":
762774
pre_tokens = path_tokens[: idx + 1]
763775
post_tokens = path_tokens[idx + 2 :]
764-
return (
776+
rets = (
765777
"/".join(pre_tokens + ["occlusions"] + post_tokens),
766778
"/".join(pre_tokens + ["outofframe"] + post_tokens),
767779
)
780+
break
781+
782+
if rets is None:
783+
raise ValueError("Malformed file path: {}".format(file_path))
784+
785+
for path in rets:
786+
if not os.path.exists(path):
787+
raise ValueError(f"Could not find file {path}")
788+
789+
return rets
768790

769791
def _read_disparity(self, file_path: str) -> Tuple:
770792
if not os.path.exists(file_path):

0 commit comments

Comments
 (0)