Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Stereo Matching
CarlaStereo
Kitti2012Stereo
Kitti2015Stereo
SintelStereo

Image pairs
~~~~~~~~~~~
Expand Down
76 changes: 76 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2841,5 +2841,81 @@ def test_train_splits(self):
datasets_utils.shape_test_for_stereo(left, right, disparity)


class SintelStereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Exercise ``datasets.SintelStereo`` for every supported ``pass_name`` on generated fake data."""

    DATASET_CLASS = datasets.SintelStereo
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(pass_name=("final", "clean", "both"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir/Sintel/training`` with fake images, disparities and masks.

        Args:
            tmpdir: Directory in which the fake dataset is created.
            config (dict): Test configuration; only ``config["pass_name"]`` is read.

        Returns:
            int: Number of left/right image pairs created for the requested pass(es);
            ``0`` when ``pass_name`` is not one of "final", "clean" or "both".
        """
        sintel_dir = pathlib.Path(tmpdir) / "Sintel"
        os.makedirs(sintel_dir, exist_ok=True)

        split_dir = sintel_dir / "training"
        os.makedirs(split_dir, exist_ok=True)

        # a single setting, since there are no splits
        # the passes deliberately get different sizes so that mixing them up changes the expected length
        num_examples = {"final": 2, "clean": 3}

        pass_names = {
            "final": ["final"],
            "clean": ["clean"],
            "both": ["final", "clean"],
        }.get(config["pass_name"], [])

        def make_scene(root, count, channels):
            # every fake folder holds one scene with zero-padded, sequentially numbered PNGs
            datasets_utils.create_image_folder(
                root=root,
                name="scene1",
                file_name_fn=lambda i: f"{i:06d}.png",
                num_examples=count,
                size=(channels, 100, 200),
            )

        for p in pass_names:
            for view in [f"{p}_left", f"{p}_right"]:
                root = split_dir / view
                os.makedirs(root, exist_ok=True)
                make_scene(root, num_examples[p], channels=3)

        # masks and disparities are shared by all passes, so create enough files for the largest pass
        shared_count = max(num_examples.values())
        make_scene(split_dir / "occlusions", shared_count, channels=1)
        make_scene(split_dir / "outofframe", shared_count, channels=1)
        make_scene(split_dir / "disparities", shared_count, channels=3)

        # "both" yields the sum of the two passes; an unknown pass yields an empty list and thus 0
        return sum(num_examples[p] for p in pass_names)

    def test_splits(self):
        """Check the shapes of every sample for each supported ``pass_name``."""
        for pass_name in ["final", "clean", "both"]:
            with self.create_dataset(pass_name=pass_name) as (dataset, _):
                for left, right, disparity, valid_mask in dataset:
                    datasets_utils.shape_test_for_stereo(left, right, disparity, valid_mask)

    def test_bad_input(self):
        """An unsupported ``pass_name`` must be rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
            with self.create_dataset(pass_name="bad"):
                pass


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo, SintelStereo
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
Expand Down Expand Up @@ -109,4 +109,5 @@
"Kitti2012Stereo",
"Kitti2015Stereo",
"CarlaStereo",
"SintelStereo",
)
123 changes: 123 additions & 0 deletions torchvision/datasets/_stereo_matching.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import os
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
Expand Down Expand Up @@ -359,3 +360,125 @@ def __getitem__(self, index: int) -> Tuple:
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
"""
return super().__getitem__(index)


class SintelStereo(StereoMatchingDataset):
    """Sintel `Stereo Dataset <http://sintel.is.tue.mpg.de/stereo>`_.

    The dataset is expected to have the following structure: ::

        root
            Sintel
                training
                    final_left
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    final_right
                        scene2
                            img1.png
                            img2.png
                            ...
                        ...
                    disparities
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    occlusions
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    outofframe
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...

    The "clean" pass is read from ``clean_left`` / ``clean_right`` directories laid out
    like ``final_left`` / ``final_right``.

    Args:
        root (string): Root directory where Sintel Stereo is located.
        pass_name (string): The name of the pass to use, either "final", "clean", or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    # NOTE(review): consumed by StereoMatchingDataset (not visible here); presumably signals that
    # ``_read_disparity`` returns its own validity mask -- confirm against the base class.
    _has_built_in_disparity_mask = True

    def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None):
        """Index the image pairs and disparity maps of the requested pass(es).

        Raises:
            ValueError: If ``pass_name`` is not one of "final", "clean" or "both".
        """
        super().__init__(root, transforms)

        verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))

        training_dir = Path(root) / "Sintel" / "training"
        pass_names = {
            "final": ["final"],
            "clean": ["clean"],
            "both": ["final", "clean"],
        }[pass_name]

        # the disparity maps do not depend on the rendering pass, so the pattern is loop-invariant
        disparity_pattern = str(training_dir / "disparities" / "*" / "*.png")

        for p in pass_names:
            left_img_pattern = str(training_dir / f"{p}_left" / "*" / "*.png")
            right_img_pattern = str(training_dir / f"{p}_right" / "*" / "*.png")
            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)

            # the same disparities are appended once per pass, alongside that pass's images
            self._disparities += self._scan_pairs(disparity_pattern, None)

    def _get_oclussion_mask_paths(self, file_path: str) -> Tuple[str, str]:
        """Map a disparity file path to its occlusion and out-of-frame mask paths.

        Args:
            file_path (str): Path of a disparity map, e.g. ``.../training/disparities/scene1/img1.png``.

        Returns:
            tuple: ``(occlusion_path, outofframe_path)``, e.g. ``.../training/occlusions/scene1/img1.png``
            and ``.../training/outofframe/scene1/img1.png``.

        Raises:
            FileNotFoundError: If either mask file does not exist.
        """
        fpath = Path(file_path)
        basename = fpath.name
        scenedir = fpath.parent
        # scenedir.parent is the "disparities" directory; one level further up is the split
        # directory ("training") that also holds the "occlusions" and "outofframe" directories
        sampledir = scenedir.parent.parent

        occlusion_path = str(sampledir / "occlusions" / scenedir.name / basename)
        outofframe_path = str(sampledir / "outofframe" / scenedir.name / basename)

        if not os.path.exists(occlusion_path):
            raise FileNotFoundError(f"Occlusion mask {occlusion_path} does not exist")

        if not os.path.exists(outofframe_path):
            raise FileNotFoundError(f"Out of frame mask {outofframe_path} does not exist")

        return occlusion_path, outofframe_path

    def _read_disparity(self, file_path: Optional[str]) -> Tuple:
        """Decode a Sintel disparity PNG and build its validity mask.

        Args:
            file_path (str, optional): Path of the disparity PNG, or ``None`` when there is none.

        Returns:
            tuple: ``(disparity_map, valid_mask)``, or ``(None, None)`` if ``file_path`` is ``None``.
            ``valid_mask`` is ``True`` where both the occlusion mask and the out-of-frame mask are 0.

        Raises:
            FileNotFoundError: If the occlusion or out-of-frame mask for ``file_path`` is missing.
        """
        if file_path is None:
            return None, None

        # disparity decoding as per Sintel instructions in the README provided with the dataset
        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
        r, g, b = np.split(disparity_map, 3, axis=-1)
        disparity_map = r * 4 + g / (2**6) + b / (2**14)
        # reshape into (C, H, W) format
        disparity_map = np.transpose(disparity_map, (2, 0, 1))
        # find the appropriate file paths
        occluded_mask_path, out_of_frame_mask_path = self._get_oclussion_mask_paths(file_path)
        # occlusion masks
        valid_mask = np.asarray(Image.open(occluded_mask_path)) == 0
        # out of frame masks
        off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0
        # combine the masks together
        valid_mask = np.logical_and(off_mask, valid_mask)
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
            the valid_mask is a numpy array of shape (H, W).
        """
        return super().__getitem__(index)