Skip to content

Commit db71802

Browse files
authored
Add Sintel Stereo dataset (#6348)
* added SceneFlow variant datasets * Changed split name to variant name * removed trailing commented code line * Added Sintel Stereo dataset * small refactor in tests * Fixed doc formatting. * candidate fix for FileNotFound on windows test * Addressing comments * Added Sintel Stereo dataset * small refactor in tests * Fixed doc formatting. * candidate fix for FileNotFound on windows test * Addressing comments * rebased on main * lint fix
1 parent 701d773 commit db71802

File tree

4 files changed

+203
-1
lines changed

4 files changed

+203
-1
lines changed

docs/source/datasets.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ Stereo Matching
112112
Kitti2012Stereo
113113
Kitti2015Stereo
114114
SceneFlowStereo
115+
SintelStereo
116+
115117

116118
Image pairs
117119
~~~~~~~~~~~

test/test_datasets.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2916,5 +2916,81 @@ def test_bad_input(self):
29162916
pass
29172917

29182918

2919+
class SintelStereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Checks ``datasets.SintelStereo`` against a synthetic Sintel directory tree."""

    DATASET_CLASS = datasets.SintelStereo
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(pass_name=("final", "clean", "both"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir`` with fake Sintel data and return the expected sample count.

        Left/right image folders are written only for the passes selected by
        ``config["pass_name"]``; the ``occlusions``, ``outofframe`` and
        ``disparities`` folders are always written.
        """
        training_dir = pathlib.Path(tmpdir) / "Sintel" / "training"
        # makedirs also creates the intermediate "Sintel" directory
        os.makedirs(training_dir, exist_ok=True)

        def frame_name(idx):
            return f"{idx:06d}.png"

        # the dataset has no splits, only rendering passes with their own frame counts
        frames_per_pass = {"final": 2, "clean": 3}

        requested = config["pass_name"]
        if requested == "both":
            selected_passes = ["final", "clean"]
        elif requested in frames_per_pass:
            selected_passes = [requested]
        else:
            # unknown pass names (see test_bad_input) get no image folders at all
            selected_passes = []

        for pass_label in selected_passes:
            for side in ("left", "right"):
                view_root = training_dir / f"{pass_label}_{side}"
                os.makedirs(view_root, exist_ok=True)

                datasets_utils.create_image_folder(
                    root=view_root,
                    name="scene1",
                    file_name_fn=frame_name,
                    num_examples=frames_per_pass[pass_label],
                    size=(3, 100, 200),
                )

        # ground-truth folders are sized for the largest pass
        most_frames = max(frames_per_pass.values())
        for folder, channels in (("occlusions", 1), ("outofframe", 1), ("disparities", 3)):
            datasets_utils.create_image_folder(
                root=training_dir / folder,
                name="scene1",
                file_name_fn=frame_name,
                num_examples=most_frames,
                size=(channels, 100, 200),
            )

        # one sample per frame of every selected pass; zero for an unknown pass name
        return sum(frames_per_pass[pass_label] for pass_label in selected_passes)

    def test_splits(self):
        """Every sample of every pass configuration goes through the stereo shape check."""
        for pass_name in ("final", "clean", "both"):
            with self.create_dataset(pass_name=pass_name) as (dataset, _):
                for sample in dataset:
                    left, right, disparity, valid_mask = sample
                    datasets_utils.shape_test_for_stereo(left, right, disparity, valid_mask)

    def test_bad_input(self):
        """An unsupported ``pass_name`` must raise a ``ValueError``."""
        with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"), self.create_dataset(
            pass_name="bad"
        ):
            pass
2993+
2994+
29192995
# Script entry point: hand control to unittest when this file is executed directly.
if __name__ == "__main__":
29202996
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
2-
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo, SceneFlowStereo
2+
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo, SceneFlowStereo, SintelStereo
33
from .caltech import Caltech101, Caltech256
44
from .celeba import CelebA
55
from .cifar import CIFAR10, CIFAR100
@@ -110,4 +110,5 @@
110110
"Kitti2015Stereo",
111111
"CarlaStereo",
112112
"SceneFlowStereo",
113+
"SintelStereo",
113114
)

torchvision/datasets/_stereo_matching.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import os
23
from abc import ABC, abstractmethod
34
from glob import glob
45
from pathlib import Path
@@ -465,3 +466,125 @@ def __getitem__(self, index: int) -> Tuple:
465466
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
466467
"""
467468
return super().__getitem__(index)
469+
470+
471+
class SintelStereo(StereoMatchingDataset):
    """Sintel `Stereo Dataset <http://sintel.is.tue.mpg.de/stereo>`_.

    The dataset is expected to have the following structure: ::

        root
            Sintel
                training
                    final_left
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    final_right
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    clean_left
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    clean_right
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    disparities
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    occlusions
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    outofframe
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...

    Only the ``<pass>_left`` / ``<pass>_right`` folders of the selected pass(es) are scanned.

    Args:
        root (string): Root directory where Sintel Stereo is located.
        pass_name (string): The name of the pass to use, either "final", "clean" or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    _has_built_in_disparity_mask = True

    def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None):
        super().__init__(root, transforms)

        verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))

        training_dir = Path(root) / "Sintel" / "training"
        # pass_name was validated above, so anything other than "both" is a single pass
        selected_passes = ("final", "clean") if pass_name == "both" else (pass_name,)
        disparity_glob = str(training_dir / "disparities" / "*" / "*.png")

        for render_pass in selected_passes:
            left_glob = str(training_dir / f"{render_pass}_left" / "*" / "*.png")
            right_glob = str(training_dir / f"{render_pass}_right" / "*" / "*.png")
            self._images += self._scan_pairs(left_glob, right_glob)

            # The same disparity files are appended once per selected pass.
            # NOTE(review): presumably so indices line up with ``_images`` when both
            # passes are loaded - confirm against the base class.
            self._disparities += self._scan_pairs(disparity_glob, None)

    def _get_occlussion_mask_paths(self, file_path: str) -> Tuple[str, str]:
        """Return the occlusion and out-of-frame mask paths matching a disparity file.

        A disparity path looks like ``.../training/disparities/scene1/img1.png``; the masks
        live at ``.../training/occlusions/scene1/img1.png`` and
        ``.../training/outofframe/scene1/img1.png``.

        Raises:
            FileNotFoundError: If either mask file does not exist.
        """
        disparity_file = Path(file_path)
        scene_dir = disparity_file.parent
        # two levels above the scene folder: scene -> "disparities" -> "training"
        training_dir = scene_dir.parent.parent

        occlusion_path = str(training_dir / "occlusions" / scene_dir.name / disparity_file.name)
        outofframe_path = str(training_dir / "outofframe" / scene_dir.name / disparity_file.name)

        # the occlusion mask is checked first, then the out-of-frame mask
        for label, mask_path in (("Occlusion", occlusion_path), ("Out of frame", outofframe_path)):
            if not os.path.exists(mask_path):
                raise FileNotFoundError(f"{label} mask {mask_path} does not exist")

        return occlusion_path, outofframe_path

    def _read_disparity(self, file_path: str) -> Tuple:
        """Decode a Sintel disparity PNG and build its validity mask.

        Returns ``(None, None)`` when ``file_path`` is ``None``; otherwise returns the
        disparity map in (C, H, W) layout together with a boolean mask that is True
        where both the occlusion and the out-of-frame mask are 0.
        """
        if file_path is None:
            return None, None

        # disparity decoding as per Sintel instructions in the README provided with the dataset
        encoded = np.asarray(Image.open(file_path), dtype=np.float32)
        red, green, blue = np.split(encoded, 3, axis=-1)
        disparity_map = red * 4 + green / (2**6) + blue / (2**14)
        # move the channel axis to the front: (H, W, C) -> (C, H, W)
        disparity_map = disparity_map.transpose(2, 0, 1)

        # locate the two masks that belong to this disparity file
        occlusion_path, outofframe_path = self._get_occlussion_mask_paths(file_path)
        # a pixel is usable only when it is flagged in neither mask
        not_occluded = np.asarray(Image.open(occlusion_path)) == 0
        in_frame = np.asarray(Image.open(outofframe_path)) == 0
        return disparity_map, np.logical_and(in_frame, not_occluded)

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
            the valid_mask is a numpy array of shape (H, W).
        """
        return super().__getitem__(index)

0 commit comments

Comments
 (0)