Skip to content

Commit 13f7a71

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Added InStereo2k dataset (#6347)
Summary: * added SceneFLow variant datasets * Changed split name to variant name * removed trailing commented code line * Added InStereo2k dataset * Added Sintel Stereo dataset * small refactor in tests * Fixed doc formatting. * candidate fix for FileNotFound on windows test * Adressing comments * Added Sintel Stereo dataset * small refactor in tests * Fixed doc formatting. * candidate fix for FileNotFound on windows test * Adressing comments * rebased on main * lint fix * Added InStereo2k dataset Reviewed By: datumbox Differential Revision: D38824220 fbshipit-source-id: fa9ef743affa5f10a4a0a1ba130deccb6df859ca
1 parent 4e602a8 commit 13f7a71

File tree

4 files changed

+136
-6
lines changed

4 files changed

+136
-6
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ Stereo Matching
113113
Kitti2015Stereo
114114
SceneFlowStereo
115115
SintelStereo
116-
116+
InStereo2k
117117

118118
Image pairs
119119
~~~~~~~~~~~

test/test_datasets.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,6 +2863,11 @@ def inject_fake_data(self, tmpdir, config):
28632863
os.makedirs(scene_flow_dir, exist_ok=True)
28642864

28652865
variant_dir = scene_flow_dir / config["variant"]
2866+
variant_dir_prefixes = {
2867+
"Monkaa": 0,
2868+
"Driving": 2,
2869+
"FlyingThings3D": 2,
2870+
}
28662871
os.makedirs(variant_dir, exist_ok=True)
28672872

28682873
num_examples = {"FlyingThings3D": 4, "Driving": 6, "Monkaa": 5}.get(config["variant"], 0)
@@ -2880,6 +2885,12 @@ def inject_fake_data(self, tmpdir, config):
28802885
os.makedirs(pass_dir, exist_ok=True)
28812886
os.makedirs(disp_dir, exist_ok=True)
28822887

2888+
for i in range(variant_dir_prefixes.get(config["variant"], 0)):
2889+
pass_dir = pass_dir / str(i)
2890+
disp_dir = disp_dir / str(i)
2891+
os.makedirs(pass_dir, exist_ok=True)
2892+
os.makedirs(disp_dir, exist_ok=True)
2893+
28832894
for direction in ["left", "right"]:
28842895
for scene_idx in range(num_examples):
28852896
os.makedirs(pass_dir / f"scene_{scene_idx:06d}", exist_ok=True)
@@ -2916,6 +2927,49 @@ def test_bad_input(self):
29162927
pass
29172928

29182929

2930+
class InStereo2k(datasets_utils.ImageDatasetTestCase):
2931+
DATASET_CLASS = datasets.InStereo2k
2932+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
2933+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
2934+
2935+
@staticmethod
2936+
def _make_scene_folder(root: str, name: str, size: Tuple[int, int]):
2937+
root = pathlib.Path(root) / name
2938+
os.makedirs(root, exist_ok=True)
2939+
2940+
datasets_utils.create_image_file(root=root, name="left.png", size=(3, size[0], size[1]))
2941+
datasets_utils.create_image_file(root=root, name="right.png", size=(3, size[0], size[1]))
2942+
datasets_utils.create_image_file(root=root, name="left_disp.png", size=(1, size[0], size[1]))
2943+
datasets_utils.create_image_file(root=root, name="right_disp.png", size=(1, size[0], size[1]))
2944+
2945+
def inject_fake_data(self, tmpdir, config):
2946+
in_stereo_dir = pathlib.Path(tmpdir) / "InStereo2k"
2947+
os.makedirs(in_stereo_dir, exist_ok=True)
2948+
2949+
split_dir = pathlib.Path(in_stereo_dir) / config["split"]
2950+
os.makedirs(split_dir, exist_ok=True)
2951+
2952+
num_examples = {"train": 4, "test": 5}.get(config["split"], 0)
2953+
2954+
for i in range(num_examples):
2955+
self._make_scene_folder(split_dir, f"scene_{i:06d}", (100, 200))
2956+
2957+
return num_examples
2958+
2959+
def test_splits(self):
2960+
for split_name in ["train", "test"]:
2961+
with self.create_dataset(split=split_name) as (dataset, _):
2962+
for left, right, disparity in dataset:
2963+
datasets_utils.shape_test_for_stereo(left, right, disparity)
2964+
2965+
def test_bad_input(self):
2966+
with pytest.raises(
2967+
ValueError, match="Unknown value 'bad' for argument split. Valid values are {'train', 'test'}."
2968+
):
2969+
with self.create_dataset(split="bad"):
2970+
pass
2971+
2972+
29192973
class SintelStereoTestCase(datasets_utils.ImageDatasetTestCase):
29202974
DATASET_CLASS = datasets.SintelStereo
29212975
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(pass_name=("final", "clean", "both"))

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
2-
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo, SceneFlowStereo, SintelStereo
2+
from ._stereo_matching import CarlaStereo, InStereo2k, Kitti2012Stereo, Kitti2015Stereo, SceneFlowStereo, SintelStereo
33
from .caltech import Caltech101, Caltech256
44
from .celeba import CelebA
55
from .cifar import CIFAR10, CIFAR100
@@ -111,4 +111,5 @@
111111
"CarlaStereo",
112112
"SceneFlowStereo",
113113
"SintelStereo",
114+
"InStereo2k",
114115
)

torchvision/datasets/_stereo_matching.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,13 +438,19 @@ def __init__(
438438

439439
root = root / variant
440440

441+
prefix_directories = {
442+
"Monkaa": Path("*"),
443+
"FlyingThings3D": Path("*") / "*" / "*",
444+
"Driving": Path("*") / "*" / "*",
445+
}
446+
441447
for p in passes:
442-
left_image_pattern = str(root / p / "*" / "left" / "*.png")
443-
right_image_pattern = str(root / p / "*" / "right" / "*.png")
448+
left_image_pattern = str(root / p / prefix_directories[variant] / "left" / "*.png")
449+
right_image_pattern = str(root / p / prefix_directories[variant] / "right" / "*.png")
444450
self._images += self._scan_pairs(left_image_pattern, right_image_pattern)
445451

446-
left_disparity_pattern = str(root / "disparity" / "*" / "left" / "*.pfm")
447-
right_disparity_pattern = str(root / "disparity" / "*" / "right" / "*.pfm")
452+
left_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "left" / "*.pfm")
453+
right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
448454
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
449455

450456
def _read_disparity(self, file_path: str) -> Tuple:
@@ -588,3 +594,72 @@ def __getitem__(self, index: int) -> Tuple:
588594
the valid_mask is a numpy array of shape (H, W).
589595
"""
590596
return super().__getitem__(index)
597+
598+
599+
class InStereo2k(StereoMatchingDataset):
600+
"""`InStereo2k <https://github.com/YuhuaXu/StereoDataset>`_ dataset.
601+
602+
The dataset is expected to have the following structre: ::
603+
604+
root
605+
InStereo2k
606+
train
607+
scene1
608+
left.png
609+
right.png
610+
left_disp.png
611+
right_disp.png
612+
...
613+
scene2
614+
...
615+
test
616+
scene1
617+
left.png
618+
right.png
619+
left_disp.png
620+
right_disp.png
621+
...
622+
scene2
623+
...
624+
625+
Args:
626+
root (string): Root directory where InStereo2k is located.
627+
split (string): Either "train" or "test".
628+
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
629+
"""
630+
631+
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
632+
super().__init__(root, transforms)
633+
634+
root = Path(root) / "InStereo2k" / split
635+
636+
verify_str_arg(split, "split", valid_values=("train", "test"))
637+
638+
left_img_pattern = str(root / "*" / "left.png")
639+
right_img_pattern = str(root / "*" / "right.png")
640+
self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
641+
642+
left_disparity_pattern = str(root / "*" / "left_disp.png")
643+
right_disparity_pattern = str(root / "*" / "right_disp.png")
644+
self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
645+
646+
def _read_disparity(self, file_path: str) -> Tuple:
647+
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
648+
# unsqueeze disparity to (C, H, W)
649+
disparity_map = disparity_map[None, :, :] / 1024.0
650+
valid_mask = None
651+
return disparity_map, valid_mask
652+
653+
def __getitem__(self, index: int) -> Tuple:
654+
"""Return example at given index.
655+
656+
Args:
657+
index(int): The index of the example to retrieve
658+
659+
Returns:
660+
tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
661+
The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
662+
If a ``valid_mask`` is generated within the ``transforms`` parameter,
663+
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
664+
"""
665+
return super().__getitem__(index)

0 commit comments

Comments
 (0)