Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c057189
added SceneFLow variant datasets
TeodorPoncu Aug 2, 2022
996df9b
Changed split name to variant name
TeodorPoncu Aug 2, 2022
9f85dc1
removed trailing commented code line
TeodorPoncu Aug 2, 2022
3993bad
Added InStereo2k dataset
TeodorPoncu Aug 2, 2022
c5a154d
Added Sintel Stereo dataset
TeodorPoncu Aug 2, 2022
95db561
small refactor in tests
TeodorPoncu Aug 2, 2022
43c41af
Fixed doc formatting.
TeodorPoncu Aug 2, 2022
1f2c72e
candidate fix for FileNotFound on windows test
TeodorPoncu Aug 3, 2022
633e50b
Adressing comments
TeodorPoncu Aug 3, 2022
62bbe44
Merge branch 'main' into add-stereo-flyingthings
TeodorPoncu Aug 3, 2022
a3c6cb9
Merge branch 'main' into add-sintel-stereo-dataset
TeodorPoncu Aug 3, 2022
35eba77
Added Sintel Stereo dataset
TeodorPoncu Aug 2, 2022
7f3f32a
small refactor in tests
TeodorPoncu Aug 2, 2022
78a5f2d
Fixed doc formatting.
TeodorPoncu Aug 2, 2022
4cc28cb
candidate fix for FileNotFound on windows test
TeodorPoncu Aug 3, 2022
5138c56
Adressing comments
TeodorPoncu Aug 3, 2022
34b262c
rebased on main
TeodorPoncu Aug 4, 2022
d84b0ec
Merge branch 'add-sintel-stereo-dataset' of https://github.com/pytorc…
TeodorPoncu Aug 4, 2022
17bd858
Merge branch 'main' into add-sintel-stereo-dataset
TeodorPoncu Aug 4, 2022
7dc5399
Merge branch 'main' into add-sintel-stereo-dataset
TeodorPoncu Aug 8, 2022
70681aa
Merge branch 'main' into add-sintel-stereo-dataset
TeodorPoncu Aug 16, 2022
e144f48
lint fix
TeodorPoncu Aug 16, 2022
9395716
Added InStereo2k dataset
TeodorPoncu Aug 2, 2022
845f9fd
Merge branch 'add-instereo2k-dataset' of https://github.com/pytorch/v…
TeodorPoncu Aug 17, 2022
20ac644
Merge branch 'main' into add-instereo2k-dataset
TeodorPoncu Aug 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Stereo Matching
CarlaStereo
Kitti2012Stereo
Kitti2015Stereo
InStereo2k

Image pairs
~~~~~~~~~~~
Expand Down
43 changes: 42 additions & 1 deletion test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import unittest
import xml.etree.ElementTree as ET
import zipfile
from typing import Union
from typing import Tuple, Union

import datasets_utils
import numpy as np
Expand Down Expand Up @@ -2841,5 +2841,46 @@ def test_train_splits(self):
datasets_utils.shape_test_for_stereo(left, right, disparity)


class InStereo2k(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.InStereo2k`` covering the "train" and "test" splits."""

    DATASET_CLASS = datasets.InStereo2k
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))

    @staticmethod
    def _make_scene_folder(root: str, name: str, size: Tuple[int, int]):
        """Create ``root/name`` and fill it with the four images of one fake scene.

        Args:
            root: Directory under which the scene folder is created.
            name: Name of the scene folder.
            size: Two spatial dimensions, forwarded in order after the channel count.
        """
        scene_dir = pathlib.Path(root) / name
        scene_dir.mkdir(parents=True, exist_ok=True)

        # (file name, channel count): colour images use 3 channels, disparities 1.
        scene_files = (
            ("left.png", 3),
            ("right.png", 3),
            ("left_disp.png", 1),
            ("right_disp.png", 1),
        )
        for file_name, channels in scene_files:
            datasets_utils.create_image_file(root=scene_dir, name=file_name, size=(channels, size[0], size[1]))

    def inject_fake_data(self, tmpdir, config):
        """Build a fake ``InStereo2k/<split>`` tree in ``tmpdir`` and return the scene count."""
        split = config["split"]

        dataset_dir = pathlib.Path(tmpdir) / "InStereo2k"
        dataset_dir.mkdir(parents=True, exist_ok=True)

        split_dir = dataset_dir / split
        split_dir.mkdir(parents=True, exist_ok=True)

        # Different counts per split; an unrecognised split gets no scenes.
        scenes_per_split = {"train": 4, "test": 5}
        num_examples = scenes_per_split.get(split, 0)

        for scene_idx in range(num_examples):
            self._make_scene_folder(split_dir, f"scene_{scene_idx:06d}", (100, 200))

        return num_examples

    def test_splits(self):
        """Iterate every sample of both splits and run the stereo shape checks."""
        for split_name in ("train", "test"):
            with self.create_dataset(split=split_name) as (dataset, _):
                for left, right, disparity in dataset:
                    datasets_utils.shape_test_for_stereo(left, right, disparity)

    def test_bad_input(self):
        """An unknown split name must be rejected with a ``ValueError``."""
        expected_message = "Unknown value 'bad' for argument split"
        with pytest.raises(ValueError, match=expected_message), self.create_dataset(split="bad"):
            pass


# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo
from ._stereo_matching import CarlaStereo, InStereo2k, Kitti2012Stereo, Kitti2015Stereo
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
Expand Down Expand Up @@ -109,4 +109,5 @@
"Kitti2012Stereo",
"Kitti2015Stereo",
"CarlaStereo",
"InStereo2k",
)
69 changes: 69 additions & 0 deletions torchvision/datasets/_stereo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,72 @@ def __getitem__(self, index: int) -> Tuple:
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
"""
return super().__getitem__(index)


class InStereo2k(StereoMatchingDataset):
    """`InStereo2k <https://github.com/YuhuaXu/StereoDataset>`_ dataset.

    The dataset is expected to have the following structure: ::

        root
            InStereo2k
                train
                    scene1
                        left.png
                        right.png
                        left_disp.png
                        right_disp.png
                        ...
                    scene2
                    ...
                test
                    scene1
                        left.png
                        right.png
                        left_disp.png
                        right_disp.png
                        ...
                    scene2
                    ...

    Args:
        root (string): Root directory where InStereo2k is located.
        split (string): Either "train" or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.

    Raises:
        ValueError: If ``split`` is not one of the accepted values.
    """

    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
        super().__init__(root, transforms)

        # Validate ``split`` before it is used to build a path, so that an invalid
        # value is reported by ``verify_str_arg`` rather than failing while joining paths.
        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "InStereo2k" / split

        # Every scene directory directly under the split folder contributes one
        # left/right image pair and one left/right disparity pair.
        left_img_pattern = str(root / "*" / "left.png")
        right_img_pattern = str(root / "*" / "right.png")
        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)

        left_disparity_pattern = str(root / "*" / "left_disp.png")
        right_disparity_pattern = str(root / "*" / "right_disp.png")
        self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

    def _read_disparity(self, file_path: str) -> Tuple:
        """Load a disparity image from ``file_path``.

        Args:
            file_path (str): Path of the disparity image to open.

        Returns:
            tuple: ``(disparity_map, valid_mask)`` where ``disparity_map`` is a
            ``float32`` numpy array with a leading singleton axis and
            ``valid_mask`` is always ``None``.
        """
        # NOTE(review): pixel values are returned exactly as stored in the file;
        # confirm whether the dataset encodes disparities with a scale factor.
        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
        # unsqueeze disparity to (C, H, W)
        disparity_map = disparity_map[None, :, :]
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            If a ``valid_mask`` is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        return super().__getitem__(index)