Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Stereo Matching
CarlaStereo
Kitti2012Stereo
Kitti2015Stereo
CREStereo

Image pairs
~~~~~~~~~~~
Expand Down
31 changes: 31 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2841,5 +2841,36 @@ def test_train_splits(self):
datasets_utils.shape_test_for_stereo(left, right, disparity)


class CREStereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CREStereo
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, np.ndarray, type(None))

def inject_fake_data(self, tmpdir, config):
crestereo_dir = pathlib.Path(tmpdir) / "CREStereo"
os.makedirs(crestereo_dir, exist_ok=True)

examples = {"tree": 2, "shapenet": 3, "reflective": 6, "hole": 5}

for category_name in ["shapenet", "reflective", "tree", "hole"]:
split_dir = crestereo_dir / category_name
os.makedirs(split_dir, exist_ok=True)
num_examples = examples[category_name]

for idx in range(num_examples):
datasets_utils.create_image_file(root=split_dir, name=f"{idx}_left.jpg", size=(100, 100))
datasets_utils.create_image_file(root=split_dir, name=f"{idx}_right.jpg", size=(100, 100))
# these are going to end up being gray scale images
datasets_utils.create_image_file(root=split_dir, name=f"{idx}_left.disp.png", size=(1, 100, 100))
datasets_utils.create_image_file(root=split_dir, name=f"{idx}_right.disp.png", size=(1, 100, 100))

return sum(examples.values())

def test_splits(self):
with self.create_dataset() as (dataset, _):
for left, right, disparity, mask in dataset:
assert mask is None
datasets_utils.shape_test_for_stereo(left, right, disparity)


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo
from ._stereo_matching import CarlaStereo, CREStereo, Kitti2012Stereo, Kitti2015Stereo
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
Expand Down Expand Up @@ -109,4 +109,5 @@
"Kitti2012Stereo",
"Kitti2015Stereo",
"CarlaStereo",
"CREStereo",
)
88 changes: 88 additions & 0 deletions torchvision/datasets/_stereo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,91 @@ def __getitem__(self, index: int) -> Tuple:
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
"""
return super().__getitem__(index)


class CREStereo(StereoMatchingDataset):
"""Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.

The dataset is expected to have the following structure: ::

root
CREStereo
tree
img1_left.jpg
img1_right.jpg
img1_left.disp.jpg
img1_right.disp.jpg
img2_left.jpg
img2_right.jpg
img2_left.disp.jpg
img2_right.disp.jpg
...
shapenet
img1_left.jpg
img1_right.jpg
img1_left.disp.jpg
img1_right.disp.jpg
...
reflective
img1_left.jpg
img1_right.jpg
img1_left.disp.jpg
img1_right.disp.jpg
...
hole
img1_left.jpg
img1_right.jpg
img1_left.disp.jpg
img1_right.disp.jpg
...

Args:
root (str): Root directory of the dataset.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

_has_built_in_disparity_mask = True

def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
):
super().__init__(root, transforms)

root = Path(root) / "CREStereo"

dirs = ["shapenet", "reflective", "tree", "hole"]

for s in dirs:
left_image_pattern = str(root / s / "*_left.jpg")
right_image_pattern = str(root / s / "*_right.jpg")
imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
self._images += imgs

left_disparity_pattern = str(root / s / "*_left.disp.png")
right_disparity_pattern = str(root / s / "*_right.disp.png")
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities += disparities

def _read_disparity(self, file_path: str) -> Tuple:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze the disparity map into (C, H, W) format
disparity_map = disparity_map[None, :, :]
valid_mask = None
return disparity_map, valid_mask

def __getitem__(self, index: int) -> Tuple:
"""Return example at given index.

Args:
index(int): The index of the example to retrieve

Returns:
tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
generate a valid mask.
"""
return super().__getitem__(index)