Skip to content

Commit 419d872

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Splitting Stereo Dataset PR(#6269) (#6311)
Summary: * Broken down PR(#6269). Added an additional dataset * Removed some types. Store None instead of "". Merged test util functions. * minor mypy fixes. minor doc fixes * reformated docstring * Added additional line-skips Reviewed By: NicolasHug Differential Revision: D38351752 fbshipit-source-id: 376714fcdd49cb474670ce8e6e959507a517ee46
1 parent e2e1db6 commit 419d872

File tree

5 files changed

+576
-0
lines changed

5 files changed

+576
-0
lines changed

docs/source/datasets.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,17 @@ Optical Flow
101101
KittiFlow
102102
Sintel
103103

104+
Stereo Matching
105+
~~~~~~~~~~~~~~~
106+
107+
.. autosummary::
108+
:toctree: generated/
109+
:template: class_dataset.rst
110+
111+
CarlaStereo
112+
Kitti2012Stereo
113+
Kitti2015Stereo
114+
104115
Image pairs
105116
~~~~~~~~~~~
106117

test/datasets_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616
from collections import defaultdict
1717
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
1818

19+
import numpy as np
20+
1921
import PIL
2022
import PIL.Image
2123
import pytest
2224
import torch
2325
import torchvision.datasets
2426
import torchvision.io
2527
from common_utils import disable_console_output, get_tmp_dir
28+
from torchvision.transforms.functional import get_dimensions
2629

2730

2831
__all__ = [
@@ -748,6 +751,33 @@ def size(idx: int) -> Tuple[int, int, int]:
748751
]
749752

750753

754+
def shape_test_for_stereo(
755+
left: PIL.Image.Image,
756+
right: PIL.Image.Image,
757+
disparity: Optional[np.ndarray] = None,
758+
valid_mask: Optional[np.ndarray] = None,
759+
):
760+
left_dims = get_dimensions(left)
761+
right_dims = get_dimensions(right)
762+
c, h, w = left_dims
763+
# check that left and right are the same size
764+
assert left_dims == right_dims
765+
assert c == 3
766+
767+
# check that the disparity has the same spatial dimensions
768+
# as the input
769+
if disparity is not None:
770+
assert disparity.ndim == 3
771+
assert disparity.shape == (1, h, w)
772+
773+
if valid_mask is not None:
774+
# check that valid mask is the same size as the disparity
775+
_, dh, dw = disparity.shape
776+
mh, mw = valid_mask.shape
777+
assert dh == mh
778+
assert dw == mw
779+
780+
751781
@requires_lazy_imports("av")
752782
def create_video_file(
753783
root: Union[pathlib.Path, str],

test/test_datasets.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import unittest
1414
import xml.etree.ElementTree as ET
1515
import zipfile
16+
from typing import Union
1617

1718
import datasets_utils
1819
import numpy as np
@@ -2671,5 +2672,174 @@ def inject_fake_data(self, tmpdir: str, config):
26712672
return len(sampled_classes) * num_images_per_class[config["split"]]
26722673

26732674

2675+
class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase):
2676+
DATASET_CLASS = datasets.Kitti2012Stereo
2677+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
2678+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
2679+
2680+
def inject_fake_data(self, tmpdir, config):
2681+
kitti_dir = pathlib.Path(tmpdir) / "Kitti2012"
2682+
os.makedirs(kitti_dir, exist_ok=True)
2683+
2684+
split_dir = kitti_dir / (config["split"] + "ing")
2685+
os.makedirs(split_dir, exist_ok=True)
2686+
2687+
num_examples = {"train": 4, "test": 3}.get(config["split"], 0)
2688+
2689+
datasets_utils.create_image_folder(
2690+
root=split_dir,
2691+
name="colored_0",
2692+
file_name_fn=lambda i: f"{i:06d}_10.png",
2693+
num_examples=num_examples,
2694+
size=(3, 100, 200),
2695+
)
2696+
datasets_utils.create_image_folder(
2697+
root=split_dir,
2698+
name="colored_1",
2699+
file_name_fn=lambda i: f"{i:06d}_10.png",
2700+
num_examples=num_examples,
2701+
size=(3, 100, 200),
2702+
)
2703+
2704+
if config["split"] == "train":
2705+
datasets_utils.create_image_folder(
2706+
root=split_dir,
2707+
name="disp_noc",
2708+
file_name_fn=lambda i: f"{i:06d}.png",
2709+
num_examples=num_examples,
2710+
# Kitti2012 uses a single channel image for disparities
2711+
size=(1, 100, 200),
2712+
)
2713+
2714+
return num_examples
2715+
2716+
def test_train_splits(self):
2717+
for split in ["train"]:
2718+
with self.create_dataset(split=split) as (dataset, _):
2719+
for left, right, disparity, mask in dataset:
2720+
assert mask is None
2721+
datasets_utils.shape_test_for_stereo(left, right, disparity)
2722+
2723+
def test_test_split(self):
2724+
for split in ["test"]:
2725+
with self.create_dataset(split=split) as (dataset, _):
2726+
for left, right, disparity, mask in dataset:
2727+
assert mask is None
2728+
assert disparity is None
2729+
datasets_utils.shape_test_for_stereo(left, right)
2730+
2731+
def test_bad_input(self):
2732+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
2733+
with self.create_dataset(split="bad"):
2734+
pass
2735+
2736+
2737+
class Kitti2015StereoTestCase(datasets_utils.ImageDatasetTestCase):
2738+
DATASET_CLASS = datasets.Kitti2015Stereo
2739+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
2740+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
2741+
2742+
def inject_fake_data(self, tmpdir, config):
2743+
kitti_dir = pathlib.Path(tmpdir) / "Kitti2015"
2744+
os.makedirs(kitti_dir, exist_ok=True)
2745+
2746+
split_dir = kitti_dir / (config["split"] + "ing")
2747+
os.makedirs(split_dir, exist_ok=True)
2748+
2749+
num_examples = {"train": 4, "test": 6}.get(config["split"], 0)
2750+
2751+
datasets_utils.create_image_folder(
2752+
root=split_dir,
2753+
name="image_2",
2754+
file_name_fn=lambda i: f"{i:06d}_10.png",
2755+
num_examples=num_examples,
2756+
size=(3, 100, 200),
2757+
)
2758+
datasets_utils.create_image_folder(
2759+
root=split_dir,
2760+
name="image_3",
2761+
file_name_fn=lambda i: f"{i:06d}_10.png",
2762+
num_examples=num_examples,
2763+
size=(3, 100, 200),
2764+
)
2765+
2766+
if config["split"] == "train":
2767+
datasets_utils.create_image_folder(
2768+
root=split_dir,
2769+
name="disp_occ_0",
2770+
file_name_fn=lambda i: f"{i:06d}.png",
2771+
num_examples=num_examples,
2772+
# Kitti2015 uses a single channel image for disparities
2773+
size=(1, 100, 200),
2774+
)
2775+
2776+
datasets_utils.create_image_folder(
2777+
root=split_dir,
2778+
name="disp_occ_1",
2779+
file_name_fn=lambda i: f"{i:06d}.png",
2780+
num_examples=num_examples,
2781+
# Kitti2015 uses a single channel image for disparities
2782+
size=(1, 100, 200),
2783+
)
2784+
2785+
return num_examples
2786+
2787+
def test_train_splits(self):
2788+
for split in ["train"]:
2789+
with self.create_dataset(split=split) as (dataset, _):
2790+
for left, right, disparity, mask in dataset:
2791+
assert mask is None
2792+
datasets_utils.shape_test_for_stereo(left, right, disparity)
2793+
2794+
def test_test_split(self):
2795+
for split in ["test"]:
2796+
with self.create_dataset(split=split) as (dataset, _):
2797+
for left, right, disparity, mask in dataset:
2798+
assert mask is None
2799+
assert disparity is None
2800+
datasets_utils.shape_test_for_stereo(left, right)
2801+
2802+
def test_bad_input(self):
2803+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
2804+
with self.create_dataset(split="bad"):
2805+
pass
2806+
2807+
2808+
class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase):
2809+
DATASET_CLASS = datasets.CarlaStereo
2810+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, None))
2811+
2812+
@staticmethod
2813+
def _create_scene_folders(num_examples: int, root_dir: Union[str, pathlib.Path]):
2814+
# make the root_dir if it does not exits
2815+
os.makedirs(root_dir, exist_ok=True)
2816+
2817+
for i in range(num_examples):
2818+
scene_dir = pathlib.Path(root_dir) / f"scene_{i}"
2819+
os.makedirs(scene_dir, exist_ok=True)
2820+
# populate with left right images
2821+
datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(100, 100))
2822+
datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(100, 100))
2823+
datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp0GT.pfm"))
2824+
datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp1GT.pfm"))
2825+
2826+
def inject_fake_data(self, tmpdir, config):
2827+
carla_dir = pathlib.Path(tmpdir) / "carla-highres"
2828+
os.makedirs(carla_dir, exist_ok=True)
2829+
2830+
split_dir = pathlib.Path(carla_dir) / "trainingF"
2831+
os.makedirs(split_dir, exist_ok=True)
2832+
2833+
num_examples = 6
2834+
self._create_scene_folders(num_examples=num_examples, root_dir=split_dir)
2835+
2836+
return num_examples
2837+
2838+
def test_train_splits(self):
2839+
with self.create_dataset() as (dataset, _):
2840+
for left, right, disparity in dataset:
2841+
datasets_utils.shape_test_for_stereo(left, right, disparity)
2842+
2843+
26742844
if __name__ == "__main__":
26752845
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
2+
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo
23
from .caltech import Caltech101, Caltech256
34
from .celeba import CelebA
45
from .cifar import CIFAR10, CIFAR100
@@ -105,4 +106,7 @@
105106
"FGVCAircraft",
106107
"EuroSAT",
107108
"RenderedSST2",
109+
"Kitti2012Stereo",
110+
"Kitti2015Stereo",
111+
"CarlaStereo",
108112
)

0 commit comments

Comments
 (0)