Skip to content

Commit 15a9a93

Browse files
Added CREStereo dataset (#6351)
Co-authored-by: Joao Gomes <[email protected]>
1 parent 1d6a259 commit 15a9a93

File tree

4 files changed

+122
-0
lines changed

4 files changed

+122
-0
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ Stereo Matching
111111
CarlaStereo
112112
Kitti2012Stereo
113113
Kitti2015Stereo
114+
CREStereo
114115
FallingThingsStereo
115116
SceneFlowStereo
116117
SintelStereo

test/test_datasets.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,6 +2841,37 @@ def test_train_splits(self):
28412841
datasets_utils.shape_test_for_stereo(left, right, disparity)
28422842

28432843

2844+
class CREStereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.CREStereo`` that runs against a generated fake dataset tree."""

    DATASET_CLASS = datasets.CREStereo
    # One entry per element of a sample: left image, right image, disparity, valid mask (None).
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, np.ndarray, type(None))

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir/CREStereo`` with fake stereo pairs and disparity files.

        Args:
            tmpdir: Directory in which the fake ``CREStereo`` folder is created.
            config: Dataset configuration; not used by this test case.

        Returns:
            int: Total number of examples created across all categories.
        """
        dataset_root = pathlib.Path(tmpdir) / "CREStereo"
        os.makedirs(dataset_root, exist_ok=True)

        # Number of examples per category, declared in the order the folders are created.
        samples_per_category = {"shapenet": 3, "reflective": 6, "tree": 2, "hole": 5}

        for category, sample_count in samples_per_category.items():
            category_dir = dataset_root / category
            os.makedirs(category_dir, exist_ok=True)

            for sample_id in range(sample_count):
                file_specs = (
                    (f"{sample_id}_left.jpg", (100, 100)),
                    (f"{sample_id}_right.jpg", (100, 100)),
                    # these are going to end up being gray scale images
                    (f"{sample_id}_left.disp.png", (1, 100, 100)),
                    (f"{sample_id}_right.disp.png", (1, 100, 100)),
                )
                for file_name, image_size in file_specs:
                    datasets_utils.create_image_file(root=category_dir, name=file_name, size=image_size)

        return sum(samples_per_category.values())

    def test_splits(self):
        """Check that every sample is a 4-tuple with a ``None`` mask and passes the stereo shape test."""
        with self.create_dataset() as (dataset, _):
            for sample in dataset:
                left, right, disparity, mask = sample
                assert mask is None
                datasets_utils.shape_test_for_stereo(left, right, disparity)
2873+
2874+
28442875
class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase):
28452876
DATASET_CLASS = datasets.FallingThingsStereo
28462877
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(variant=("single", "mixed", "both"))

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
22
from ._stereo_matching import (
33
CarlaStereo,
4+
CREStereo,
45
ETH3DStereo,
56
FallingThingsStereo,
67
InStereo2k,
@@ -118,6 +119,7 @@
118119
"Kitti2012Stereo",
119120
"Kitti2015Stereo",
120121
"CarlaStereo",
122+
"CREStereo",
121123
"FallingThingsStereo",
122124
"SceneFlowStereo",
123125
"SintelStereo",

torchvision/datasets/_stereo_matching.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,94 @@ def __getitem__(self, index: int) -> Tuple:
363363
return super().__getitem__(index)
364364

365365

366+
class CREStereo(StereoMatchingDataset):
    """Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
    Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.

    The dataset is expected to have the following structure: ::

        root
            CREStereo
                tree
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    img2_left.jpg
                    img2_right.jpg
                    img2_left.disp.png
                    img2_right.disp.png
                    ...
                shapenet
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                reflective
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                hole
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...

    Args:
        root (str): Root directory of the dataset.
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    # NOTE(review): presumably makes the base class always return a 4-tuple even though
    # ``_read_disparity`` yields no mask -- confirm against ``StereoMatchingDataset``.
    _has_built_in_disparity_mask = True

    # Divisor that converts raw PNG values to disparities. The reference CREStereo
    # reader divides by 32; verify against the official repository.
    _DISPARITY_SCALE = 32.0

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "CREStereo"

        dirs = ["shapenet", "reflective", "tree", "hole"]

        for s in dirs:
            # Left/right colour images are stored as JPEG files.
            left_image_pattern = str(root / s / "*_left.jpg")
            right_image_pattern = str(root / s / "*_right.jpg")
            imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
            self._images += imgs

            # Left/right disparity maps are stored as PNG files next to the images.
            left_disparity_pattern = str(root / s / "*_left.disp.png")
            right_disparity_pattern = str(root / s / "*_right.disp.png")
            disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
            self._disparities += disparities

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        """Load one disparity PNG and rescale it to disparity values.

        Args:
            file_path (str): Path of the disparity image to read.

        Returns:
            tuple: ``(disparity_map, valid_mask)`` where ``disparity_map`` is a float32
            array with a leading channel axis and ``valid_mask`` is always ``None``.
        """
        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
        # unsqueeze the disparity map into (C, H, W) format and undo the storage scaling
        disparity_map = disparity_map[None, :, :] / self._DISPARITY_SCALE
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
        """
        return super().__getitem__(index)
452+
453+
366454
class FallingThingsStereo(StereoMatchingDataset):
367455
"""`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.
368456

0 commit comments

Comments
 (0)