@@ -60,34 +60,36 @@ def __init__(self, root: str, transforms: Optional[Callable] = None):
6060 super ().__init__ (root = root )
6161 self .transforms = transforms
6262
63- self ._images : List [Tuple ] = []
64- self ._disparities : List [Tuple ] = []
63+ self ._images : List [Tuple [ str , str ] ] = []
64+ self ._disparities : List [Tuple [ str , str ] ] = []
6565
6666 def _read_img (self , file_path : str ) -> Image .Image :
6767 img = Image .open (file_path )
6868 if img .mode != "RGB" :
6969 img = img .convert ("RGB" )
7070 return img
7171
72- def _scan_pairs (self , left_pattern : str , right_pattern : str , fill_empty : bool = False ) -> List [Tuple [str , str ]]:
73- left_paths = sorted (glob (left_pattern ))
74- right_paths = sorted (glob (right_pattern ))
72+ def _scan_pairs (
73+ self , paths_left_pattern : str , paths_right_pattern : str , fill_empty : bool = False
74+ ) -> List [Tuple [str , str ]]:
75+ left_paths : List [str ] = sorted (glob (paths_left_pattern ))
76+ right_paths : List [str ] = sorted (glob (paths_right_pattern ))
7577
7678 # used when dealing with inexistent disparity for the right image
7779 if fill_empty :
7880 right_paths = list ("" for _ in left_paths )
7981
8082 if not left_paths :
81- raise FileNotFoundError (f"Could not find any files matching the patterns: { left_pattern } " )
83+ raise FileNotFoundError (f"Could not find any files matching the patterns: { paths_left_pattern } " )
8284
8385 if not right_paths :
84- raise FileNotFoundError (f"Could not find any files matching the patterns: { right_pattern } " )
86+ raise FileNotFoundError (f"Could not find any files matching the patterns: { paths_right_pattern } " )
8587
8688 if len (left_paths ) != len (right_paths ):
8789 raise ValueError (
8890 f"Found { len (left_paths )} left files but { len (right_paths )} right files using:\n "
89- f"left pattern: { left_pattern } \n "
90- f"right pattern: { right_pattern } \n "
91+ f"left pattern: { paths_left_pattern } \n "
92+ f"right pattern: { paths_right_pattern } \n "
9193 )
9294
9395 images = list ((left , right ) for left , right in zip (left_paths , right_paths ))
@@ -387,6 +389,7 @@ def __init__(
387389 self ._download_dataset (root )
388390
389391 root = Path (root ) / "Middlebury2014"
392+ self .split = split
390393
391394 if not os .path .exists (root / split ):
392395 raise FileNotFoundError (f"The { split } directory was not found in the provided root directory" )
@@ -457,19 +460,26 @@ def _download_dataset(self, root: str):
457460 base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
458461 # train and additional splits have 2 different calibration settings
459462 root = Path (root ) / "Middlebury2014"
460- for split_name , split_scenes in self .splits .items ():
463+ download_split = self .split
464+
465+ for split_name , split_scenes in (download_split , self .splits [download_split ]):
461466 if split_name == "test" :
462467 continue
463468 split_root = root / split_name
464469 for scene in split_scenes :
465470 for calibration in ["perfect" , "imperfect" ]:
466471 scene_name = f"{ scene } -{ calibration } "
467472 scene_url = f"{ base_url } /{ scene_name } .zip"
468- download_and_extract_archive (
469- url = scene_url , filename = f"{ scene_name } .zip" , download_root = str (split_root ), remove_finished = True
470- )
471-
472- if any (s not in os .listdir (root ) for s in self .splits ["test" ]):
473+ # download the scene only if it doesn't exist
474+ if not os .path .exists (split_root / scene_name ):
475+ download_and_extract_archive (
476+ url = scene_url ,
477+ filename = f"{ scene_name } .zip" ,
478+ download_root = str (split_root ),
479+ remove_finished = True ,
480+ )
481+
482+ if any (s not in os .listdir (root / "test" ) for s in self .splits ["test" ]):
473483 # test split is downloaded from a different location
474484 test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip"
475485
@@ -550,13 +560,13 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
550560
551561 left_img_pattern = str (root / img_dir / "*" / "im0.png" )
552562 right_img_pattern = str (root / img_dir / "*" / "im1.png" )
553- self ._images = self ._scan_pairs (left_img_pattern , right_img_pattern )
563+ self ._images + = self ._scan_pairs (left_img_pattern , right_img_pattern )
554564
555565 if split == "test" :
556566 self ._disparities = list (("" , "" ) for _ in self ._images )
557567 else :
558568 disparity_pattern = str (root / anot_dir / "*" / "disp0GT.pfm" )
559- self ._disparities = self ._scan_pairs (disparity_pattern , "" , fill_empty = True )
569+ self ._disparities + = self ._scan_pairs (disparity_pattern , "" , fill_empty = True )
560570
561571 def _read_disparity (self , file_path : str ) -> Tuple :
562572 if not os .path .exists (file_path ):
@@ -605,11 +615,11 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
605615
606616 left_img_pattern = str (root / "colored_0" / "*_10.png" )
607617 right_img_pattern = str (root / "colored_1" / "*_10.png" )
608- self ._images = self ._scan_pairs (left_img_pattern , right_img_pattern )
618+ self ._images + = self ._scan_pairs (left_img_pattern , right_img_pattern )
609619
610620 if split == "train" :
611621 disparity_pattern = str (root / "disp_noc" / "*.png" )
612- self ._disparities = self ._scan_pairs (disparity_pattern , "" , fill_empty = True )
622+ self ._disparities + = self ._scan_pairs (disparity_pattern , "" , fill_empty = True )
613623 else :
614624 self ._disparities = list (("" , "" ) for _ in self ._images )
615625
@@ -676,12 +686,12 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
676686 root = Path (root ) / "Kitti2015" / (split + "ing" )
677687 left_img_pattern = str (root / "image_2" / "*.png" )
678688 right_img_pattern = str (root / "image_3" / "*.png" )
679- self ._images = self ._scan_pairs (left_img_pattern , right_img_pattern )
689+ self ._images + = self ._scan_pairs (left_img_pattern , right_img_pattern )
680690
681691 if split == "train" :
682692 left_disparity_pattern = str (root / "disp_occ_0" / "*.png" )
683693 right_disparity_pattern = str (root / "disp_occ_1" / "*.png" )
684- self ._disparities = self ._scan_pairs (left_disparity_pattern , right_disparity_pattern )
694+ self ._disparities + = self ._scan_pairs (left_disparity_pattern , right_disparity_pattern )
685695 else :
686696 self ._disparities = list (("" , "" ) for _ in self ._images )
687697
@@ -750,21 +760,33 @@ def __init__(self, root: str, transforms: Optional[Callable] = None):
750760
751761 left_img_pattern = str (root / "training" / "final_left" / "*" / "*.png" )
752762 right_img_pattern = str (root / "training" / "final_right" / "*" / "*.png" )
753- self ._images = self ._scan_pairs (left_img_pattern , right_img_pattern )
763+ self ._images + = self ._scan_pairs (left_img_pattern , right_img_pattern )
754764
755765 disparity_pattern = str (root / "training" / "disparities" / "*" / "*.png" )
756- self ._disparities = self ._scan_pairs (disparity_pattern , "" , fill_empty = True )
766+ self ._disparities + = self ._scan_pairs (disparity_pattern , "" , fill_empty = True )
757767
758- def _get_oclussion_mask_paths (self , file_path : str ) -> List [ str ]:
768+ def _get_oclussion_mask_paths (self , file_path : str ) -> Tuple [ str , str ]:
759769 path_tokens = file_path .split (os .sep )
770+ rets = None
771+
760772 for idx in range (len (path_tokens ) - 1 ):
761773 if path_tokens [idx ] == "training" and path_tokens [idx + 1 ] == "disparities" :
762774 pre_tokens = path_tokens [: idx + 1 ]
763775 post_tokens = path_tokens [idx + 2 :]
764- return (
776+ rets = (
765777 "/" .join (pre_tokens + ["occlusions" ] + post_tokens ),
766778 "/" .join (pre_tokens + ["outofframe" ] + post_tokens ),
767779 )
780+ break
781+
782+ if rets is None :
783+ raise ValueError ("Malformed file path: {}" .format (file_path ))
784+
785+ for path in rets :
786+ if not os .path .exists (path ):
787+ raise ValueError (f"Could not find file { path } " )
788+
789+ return rets
768790
769791 def _read_disparity (self , file_path : str ) -> Tuple :
770792 if not os .path .exists (file_path ):
0 commit comments