@@ -363,6 +363,94 @@ def __getitem__(self, index: int) -> Tuple:
363363 return super ().__getitem__ (index )
364364
365365
class CREStereo(StereoMatchingDataset):
    """Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
    Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.

    The dataset is expected to have the following structure: ::

        root
            CREStereo
                tree
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    img2_left.jpg
                    img2_right.jpg
                    img2_left.disp.png
                    img2_right.disp.png
                    ...
                shapenet
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                reflective
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                hole
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...

    Args:
        root (str): Root directory of the dataset.
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    # NOTE(review): presumably makes the base class always return a 4-tuple
    # (with a possibly-``None`` mask) -- confirm against StereoMatchingDataset.
    _has_built_in_disparity_mask = True

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "CREStereo"

        # Sub-directories of the dataset; each holds stereo image pairs and
        # their matching disparity maps side by side.
        dirs = ["shapenet", "reflective", "tree", "hole"]

        for s in dirs:
            left_image_pattern = str(root / s / "*_left.jpg")
            right_image_pattern = str(root / s / "*_right.jpg")
            imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
            self._images += imgs

            # Disparity maps are stored as PNG files next to the JPEG images.
            left_disparity_pattern = str(root / s / "*_left.disp.png")
            right_disparity_pattern = str(root / s / "*_right.disp.png")
            disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
            self._disparities += disparities

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        """Load a disparity map from ``file_path``.

        Args:
            file_path (str): Path to a ``*.disp.png`` disparity file.

        Returns:
            tuple: ``(disparity_map, valid_mask)`` where ``disparity_map`` is a
            ``float32`` numpy array with a leading channel axis and
            ``valid_mask`` is always ``None``.
        """
        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
        # unsqueeze the disparity map into (C, H, W) format.
        # The official CREStereo repository decodes the stored PNG values by
        # dividing by 32 (not 256 as for KITTI-style disparity files).
        disparity_map = disparity_map[None, :, :] / 32.0
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
        """
        return super().__getitem__(index)
452+
453+
366454class FallingThingsStereo (StereoMatchingDataset ):
367455 """`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.
368456
0 commit comments