Skip to content

Commit 9662001

Browse files
authored
Added ETH3D stereo dataset (#6349)
* Added ETH3D stereo dataset * Small doc-reformating * Removed assertions with no use, changed np conversion * Added ETH3D stereo dataset * Removed assertions with no use, changed np conversion * rebased on main * Revert "Removed assertions with no use, changed np conversion" This reverts commit 1478a8c. * Update to np.bool instead of np.bool_ * lint and mypy nit fix * test nit
1 parent 330b6c9 commit 9662001

File tree

4 files changed

+204
-17
lines changed

4 files changed

+204
-17
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ Stereo Matching
115115
SceneFlowStereo
116116
SintelStereo
117117
InStereo2k
118+
ETH3DStereo
118119

119120
Image pairs
120121
~~~~~~~~~~~

test/test_datasets.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2872,15 +2872,26 @@ def inject_fake_data(self, tmpdir, config):
28722872
os.makedirs(fallingthings_dir, exist_ok=True)
28732873

28742874
num_examples = {"single": 2, "mixed": 3, "both": 4}.get(config["variant"], 0)
2875+
28752876
variants = {
28762877
"single": ["single"],
28772878
"mixed": ["mixed"],
28782879
"both": ["single", "mixed"],
28792880
}.get(config["variant"], [])
28802881

2882+
variant_dir_prefixes = {
2883+
"single": 1,
2884+
"mixed": 0,
2885+
}
2886+
28812887
for variant_name in variants:
28822888
variant_dir = pathlib.Path(fallingthings_dir) / variant_name
28832889
os.makedirs(variant_dir, exist_ok=True)
2890+
2891+
for i in range(variant_dir_prefixes[variant_name]):
2892+
variant_dir = variant_dir / f"{i:02d}"
2893+
os.makedirs(variant_dir, exist_ok=True)
2894+
28842895
for i in range(num_examples):
28852896
self._make_scene_folder(
28862897
root=variant_dir,
@@ -3109,5 +3120,72 @@ def test_bad_input(self):
31093120
pass
31103121

31113122

3123+
class ETH3DStereoestCase(datasets_utils.ImageDatasetTestCase):
3124+
DATASET_CLASS = datasets.ETH3DStereo
3125+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
3126+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
3127+
3128+
@staticmethod
3129+
def _create_scene_folder(num_examples: int, root_dir: str):
3130+
# make the root_dir if it does not exits
3131+
root_dir = pathlib.Path(root_dir)
3132+
os.makedirs(root_dir, exist_ok=True)
3133+
3134+
for i in range(num_examples):
3135+
scene_dir = root_dir / f"scene_{i}"
3136+
os.makedirs(scene_dir, exist_ok=True)
3137+
# populate with left right images
3138+
datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(100, 100))
3139+
datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(100, 100))
3140+
3141+
@staticmethod
3142+
def _create_annotation_folder(num_examples: int, root_dir: str):
3143+
# make the root_dir if it does not exits
3144+
root_dir = pathlib.Path(root_dir)
3145+
os.makedirs(root_dir, exist_ok=True)
3146+
3147+
# create scene directories
3148+
for i in range(num_examples):
3149+
scene_dir = root_dir / f"scene_{i}"
3150+
os.makedirs(scene_dir, exist_ok=True)
3151+
# populate with a random png file for occlusion mask, and a pfm file for disparity
3152+
datasets_utils.create_image_file(root=scene_dir, name="mask0nocc.png", size=(1, 100, 100))
3153+
3154+
pfm_path = scene_dir / "disp0GT.pfm"
3155+
datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=pfm_path)
3156+
3157+
def inject_fake_data(self, tmpdir, config):
3158+
eth3d_dir = pathlib.Path(tmpdir) / "ETH3D"
3159+
3160+
num_examples = 2 if config["split"] == "train" else 3
3161+
3162+
split_name = "two_view_training" if config["split"] == "train" else "two_view_test"
3163+
split_dir = eth3d_dir / split_name
3164+
self._create_scene_folder(num_examples, split_dir)
3165+
3166+
if config["split"] == "train":
3167+
annot_dir = eth3d_dir / "two_view_training_gt"
3168+
self._create_annotation_folder(num_examples, annot_dir)
3169+
3170+
return num_examples
3171+
3172+
def test_training_splits(self):
3173+
with self.create_dataset(split="train") as (dataset, _):
3174+
for left, right, disparity, valid_mask in dataset:
3175+
datasets_utils.shape_test_for_stereo(left, right, disparity, valid_mask)
3176+
3177+
def test_testing_splits(self):
3178+
with self.create_dataset(split="test") as (dataset, _):
3179+
assert all(d == (None, None) for d in dataset._disparities)
3180+
for left, right, disparity, valid_mask in dataset:
3181+
assert valid_mask is None
3182+
datasets_utils.shape_test_for_stereo(left, right, disparity)
3183+
3184+
def test_bad_input(self):
3185+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
3186+
with self.create_dataset(split="bad"):
3187+
pass
3188+
3189+
31123190
if __name__ == "__main__":
31133191
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
22
from ._stereo_matching import (
33
CarlaStereo,
4+
ETH3DStereo,
45
FallingThingsStereo,
56
InStereo2k,
67
Kitti2012Stereo,
@@ -121,4 +122,5 @@
121122
"SceneFlowStereo",
122123
"SintelStereo",
123124
"InStereo2k",
125+
"ETH3DStereo",
124126
)

torchvision/datasets/_stereo_matching.py

Lines changed: 123 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -371,19 +371,20 @@ class FallingThingsStereo(StereoMatchingDataset):
371371
root
372372
FallingThings
373373
single
374-
scene1
375-
_object_settings.json
376-
_camera_settings.json
377-
image1.left.depth.png
378-
image1.right.depth.png
379-
image1.left.jpg
380-
image1.right.jpg
381-
image2.left.depth.png
382-
image2.right.depth.png
383-
image2.left.jpg
384-
image2.right
385-
...
386-
scene2
374+
dir1
375+
scene1
376+
_object_settings.json
377+
_camera_settings.json
378+
image1.left.depth.png
379+
image1.right.depth.png
380+
image1.left.jpg
381+
image1.right.jpg
382+
image2.left.depth.png
383+
image2.right.depth.png
384+
image2.left.jpg
385+
image2.right
386+
...
387+
scene2
387388
...
388389
mixed
389390
scene1
@@ -420,13 +421,18 @@ def __init__(self, root: str, variant: str = "single", transforms: Optional[Call
420421
"both": ["single", "mixed"],
421422
}[variant]
422423

424+
split_prefix = {
425+
"single": Path("*") / "*",
426+
"mixed": Path("*"),
427+
}
428+
423429
for s in variants:
424-
left_img_pattern = str(root / s / "*" / "*.left.jpg")
425-
right_img_pattern = str(root / s / "*" / "*.right.jpg")
430+
left_img_pattern = str(root / s / split_prefix[s] / "*.left.jpg")
431+
right_img_pattern = str(root / s / split_prefix[s] / "*.right.jpg")
426432
self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
427433

428-
left_disparity_pattern = str(root / s / "*" / "*.left.depth.png")
429-
right_disparity_pattern = str(root / s / "*" / "*.right.depth.png")
434+
left_disparity_pattern = str(root / s / split_prefix[s] / "*.left.depth.png")
435+
right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
430436
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
431437

432438
def _read_disparity(self, file_path: str) -> Tuple:
@@ -762,3 +768,103 @@ def __getitem__(self, index: int) -> Tuple:
762768
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
763769
"""
764770
return super().__getitem__(index)
771+
772+
773+
class ETH3DStereo(StereoMatchingDataset):
774+
"""ETH3D `Low-Res Two-View <https://www.eth3d.net/datasets>`_ dataset.
775+
776+
The dataset is expected to have the following structure: ::
777+
778+
root
779+
ETH3D
780+
two_view_training
781+
scene1
782+
im1.png
783+
im0.png
784+
images.txt
785+
cameras.txt
786+
calib.txt
787+
scene2
788+
im1.png
789+
im0.png
790+
images.txt
791+
cameras.txt
792+
calib.txt
793+
...
794+
two_view_training_gt
795+
scene1
796+
disp0GT.pfm
797+
mask0nocc.png
798+
scene2
799+
disp0GT.pfm
800+
mask0nocc.png
801+
...
802+
two_view_testing
803+
scene1
804+
im1.png
805+
im0.png
806+
images.txt
807+
cameras.txt
808+
calib.txt
809+
scene2
810+
im1.png
811+
im0.png
812+
images.txt
813+
cameras.txt
814+
calib.txt
815+
...
816+
817+
Args:
818+
root (string): Root directory of the ETH3D Dataset.
819+
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
820+
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
821+
"""
822+
823+
_has_built_in_disparity_mask = True
824+
825+
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
826+
super().__init__(root, transforms)
827+
828+
verify_str_arg(split, "split", valid_values=("train", "test"))
829+
830+
root = Path(root) / "ETH3D"
831+
832+
img_dir = "two_view_training" if split == "train" else "two_view_test"
833+
anot_dir = "two_view_training_gt"
834+
835+
left_img_pattern = str(root / img_dir / "*" / "im0.png")
836+
right_img_pattern = str(root / img_dir / "*" / "im1.png")
837+
self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
838+
839+
if split == "test":
840+
self._disparities = list((None, None) for _ in self._images)
841+
else:
842+
disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
843+
self._disparities = self._scan_pairs(disparity_pattern, None)
844+
845+
def _read_disparity(self, file_path: str) -> Tuple:
846+
# test split has no disparity maps
847+
if file_path is None:
848+
return None, None
849+
850+
disparity_map = _read_pfm_file(file_path)
851+
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
852+
mask_path = Path(file_path).parent / "mask0nocc.png"
853+
valid_mask = Image.open(mask_path)
854+
valid_mask = np.asarray(valid_mask).astype(bool)
855+
return disparity_map, valid_mask
856+
857+
def __getitem__(self, index: int) -> Tuple:
858+
"""Return example at given index.
859+
860+
Args:
861+
index(int): The index of the example to retrieve
862+
863+
Returns:
864+
tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
865+
The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
866+
``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
867+
generate a valid mask.
868+
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
869+
"""
870+
return super().__getitem__(index)

0 commit comments

Comments
 (0)