Skip to content

Commit e326295

Browse files
committed
Broken down PR(#6269). Added an additional dataset
1 parent 9b84859 commit e326295

File tree

4 files changed

+546
-0
lines changed

4 files changed

+546
-0
lines changed

test/datasets_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616
from collections import defaultdict
1717
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
1818

19+
import numpy as np
20+
1921
import PIL
2022
import PIL.Image
2123
import pytest
2224
import torch
2325
import torchvision.datasets
2426
import torchvision.io
2527
from common_utils import disable_console_output, get_tmp_dir
28+
from torchvision.transforms.functional import get_dimensions
2629

2730

2831
__all__ = [
@@ -748,6 +751,49 @@ def size(idx: int) -> Tuple[int, int, int]:
748751
]
749752

750753

754+
def shape_test_for_stereo_gt_w_mask(
    left: PIL.Image.Image, right: PIL.Image.Image, disparity: np.ndarray, valid_mask: np.ndarray
):
    """Assert the shape contract of a stereo sample that has ground truth and a validity mask.

    Checks, in order, that both views report the same dimensions, that the
    first reported dimension is 3, that ``disparity`` is a 3-dimensional array
    of shape ``(1, h, w)`` where ``h`` and ``w`` are the remaining two
    dimensions of ``left``, and that the two dimensions of ``valid_mask`` equal
    the last two dimensions of ``disparity``.

    Args:
        left: left view of the stereo pair.
        right: right view of the stereo pair.
        disparity: disparity map belonging to the pair.
        valid_mask: 2-dimensional mask accompanying ``disparity``.

    Raises:
        AssertionError: if any of the shape checks fails.
    """
    dims_left = get_dimensions(left)
    dims_right = get_dimensions(right)
    channels, height, width = dims_left
    # the two views have to agree on every reported dimension
    assert dims_left == dims_right
    # three-channel views, one-channel disparity of matching spatial size
    assert channels == 3
    assert disparity.ndim == 3
    assert disparity.shape == (1, height, width)
    # the mask has no leading channel axis, so compare it to the trailing disparity dims
    _, disp_height, disp_width = disparity.shape
    mask_height, mask_width = valid_mask.shape
    assert disp_height == mask_height
    assert disp_width == mask_width
772+
773+
774+
def shape_test_for_stereo_gt_no_mask(left: PIL.Image.Image, right: PIL.Image.Image, disparity: np.ndarray):
    """Assert the shape contract of a stereo sample that has ground truth but no mask.

    Checks, in order, that both views report the same dimensions, that the
    first reported dimension is 3, and that ``disparity`` is a 3-dimensional
    array of shape ``(1, h, w)`` where ``h`` and ``w`` are the remaining two
    dimensions of ``left``.

    Args:
        left: left view of the stereo pair.
        right: right view of the stereo pair.
        disparity: disparity map belonging to the pair.

    Raises:
        AssertionError: if any of the shape checks fails.
    """
    dims_left = get_dimensions(left)
    dims_right = get_dimensions(right)
    channels, height, width = dims_left
    # the two views have to agree on every reported dimension
    assert dims_left == dims_right
    # three-channel views, one-channel disparity of matching spatial size
    assert channels == 3
    assert disparity.ndim == 3
    assert disparity.shape == (1, height, width)
784+
785+
786+
def shape_test_for_stereo_no_gt(left: PIL.Image.Image, right: PIL.Image.Image, disparity: None):
    """Assert the shape contract of a stereo sample that carries no ground truth.

    Checks, in order, that both views report the same dimensions, that the
    first reported dimension is 3, and that ``disparity`` is ``None``.

    Args:
        left: left view of the stereo pair.
        right: right view of the stereo pair.
        disparity: expected to be ``None`` for samples without ground truth.

    Raises:
        AssertionError: if any of the checks fails.
    """
    dims_left = get_dimensions(left)
    dims_right = get_dimensions(right)
    channels, _height, _width = dims_left
    # the two views have to agree on every reported dimension
    assert dims_left == dims_right
    # three-channel views and no disparity at all
    assert channels == 3
    assert disparity is None
795+
796+
751797
@requires_lazy_imports("av")
752798
def create_video_file(
753799
root: Union[pathlib.Path, str],

test/test_datasets.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import unittest
1414
import xml.etree.ElementTree as ET
1515
import zipfile
16+
from typing import Union
1617

1718
import datasets_utils
1819
import numpy as np
@@ -2671,5 +2672,172 @@ def inject_fake_data(self, tmpdir: str, config):
26712672
return len(sampled_classes) * num_images_per_class[config["split"]]
26722673

26732674

2675+
class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests ``datasets.Kitti2012Stereo`` against a generated directory layout.

    The fake data lives in ``Kitti2012/<split>ing``. The two views are written
    to ``colored_0`` and ``colored_1``; the ``train`` split additionally gets a
    ``disp_noc`` folder of single-channel images.
    """

    DATASET_CLASS = datasets.Kitti2012Stereo
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))

    def inject_fake_data(self, tmpdir, config):
        """Create the fake Kitti2012 layout under ``tmpdir`` and return the number of examples."""
        dataset_root = pathlib.Path(tmpdir) / "Kitti2012"
        os.makedirs(dataset_root, exist_ok=True)

        split = config["split"]
        split_root = dataset_root / (split + "ing")
        os.makedirs(split_root, exist_ok=True)

        # splits missing from the table (e.g. "bad" in test_bad_input) get zero examples
        sample_count = {"train": 4, "test": 3}.get(split, 0)

        def view_file_name(idx):
            return f"{idx:06d}_10.png"

        # both views share the same file names and image size
        for view_folder in ("colored_0", "colored_1"):
            datasets_utils.create_image_folder(
                root=split_root,
                name=view_folder,
                file_name_fn=view_file_name,
                num_examples=sample_count,
                size=(3, 100, 200),
            )

        if split == "train":
            datasets_utils.create_image_folder(
                root=split_root,
                name="disp_noc",
                file_name_fn=lambda idx: f"{idx:06d}.png",
                num_examples=sample_count,
                # Kitti2012 uses a single channel image for disparities
                size=(1, 100, 200),
            )

        return sample_count

    def test_train_splits(self):
        """Train samples carry a disparity map and no mask."""
        with self.create_dataset(split="train") as (dataset, _):
            for left, right, disparity, mask in dataset:
                assert mask is None
                datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity)

    def test_test_split(self):
        """Test samples carry neither a disparity map nor a mask."""
        with self.create_dataset(split="test") as (dataset, _):
            for left, right, disparity, mask in dataset:
                assert mask is None
                datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity)

    def test_bad_input(self):
        """An unknown split name is rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"), self.create_dataset(
            split="bad"
        ):
            pass
2734+
2735+
2736+
class Kitti2015StereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests ``datasets.Kitti2015Stereo`` against a generated directory layout.

    The fake data lives in ``Kitti2015/<split>ing``. The two views are written
    to ``image_2`` and ``image_3``; the ``train`` split additionally gets
    ``disp_occ_0`` and ``disp_occ_1`` folders of single-channel images.
    """

    DATASET_CLASS = datasets.Kitti2015Stereo
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))

    def inject_fake_data(self, tmpdir, config):
        """Create the fake Kitti2015 layout under ``tmpdir`` and return the number of examples."""
        dataset_root = pathlib.Path(tmpdir) / "Kitti2015"
        os.makedirs(dataset_root, exist_ok=True)

        split = config["split"]
        split_root = dataset_root / (split + "ing")
        os.makedirs(split_root, exist_ok=True)

        # splits missing from the table (e.g. "bad" in test_bad_input) get zero examples
        sample_count = {"train": 4, "test": 6}.get(split, 0)

        def view_file_name(idx):
            return f"{idx:06d}_10.png"

        def disparity_file_name(idx):
            return f"{idx:06d}.png"

        # both views share the same file names and image size
        for view_folder in ("image_2", "image_3"):
            datasets_utils.create_image_folder(
                root=split_root,
                name=view_folder,
                file_name_fn=view_file_name,
                num_examples=sample_count,
                size=(3, 100, 200),
            )

        if split == "train":
            for disparity_folder in ("disp_occ_0", "disp_occ_1"):
                datasets_utils.create_image_folder(
                    root=split_root,
                    name=disparity_folder,
                    file_name_fn=disparity_file_name,
                    num_examples=sample_count,
                    # Kitti2015 uses a single channel image for disparities
                    size=(1, 100, 200),
                )

        return sample_count

    def test_train_splits(self):
        """Train samples carry a disparity map and no mask."""
        with self.create_dataset(split="train") as (dataset, _):
            for left, right, disparity, mask in dataset:
                assert mask is None
                datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity)

    def test_test_split(self):
        """Test samples carry neither a disparity map nor a mask."""
        with self.create_dataset(split="test") as (dataset, _):
            for left, right, disparity, mask in dataset:
                assert mask is None
                datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity)

    def test_bad_input(self):
        """An unknown split name is rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"), self.create_dataset(
            split="bad"
        ):
            pass
2804+
2805+
2806+
class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests ``datasets.CarlaStereo`` against a generated directory layout.

    The fake data lives in ``carla-highres/trainingF`` with one ``scene_<i>``
    folder per example, each holding ``im0.png``, ``im1.png``, ``disp0GT.pfm``
    and ``disp1GT.pfm``.
    """

    DATASET_CLASS = datasets.CarlaStereo
    # ``type(None)`` rather than ``None``: isinstance() only accepts types in its
    # classinfo tuple, so a bare ``None`` raises TypeError as soon as a feature
    # is not an ndarray instead of producing a clean assertion failure.
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    @staticmethod
    def _create_scene_folders(num_examples: int, root_dir: Union[str, pathlib.Path]):
        """Create ``num_examples`` scene folders below ``root_dir`` and fill them with fake files.

        Args:
            num_examples: number of ``scene_<i>`` folders to create.
            root_dir: directory that receives the scene folders; created if missing.
        """
        # make the root_dir if it does not exist
        os.makedirs(root_dir, exist_ok=True)

        for i in range(num_examples):
            scene_dir = pathlib.Path(root_dir) / f"scene_{i}"
            os.makedirs(scene_dir, exist_ok=True)
            # populate with left right images
            datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(100, 100))
            datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(100, 100))
            # one fake disparity file per view
            datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp0GT.pfm"))
            datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp1GT.pfm"))

    def inject_fake_data(self, tmpdir, config):
        """Create the fake Carla layout under ``tmpdir`` and return the number of examples."""
        carla_dir = pathlib.Path(tmpdir) / "carla-highres"
        os.makedirs(carla_dir, exist_ok=True)

        split_dir = pathlib.Path(carla_dir) / "trainingF"
        os.makedirs(split_dir, exist_ok=True)

        num_examples = 6
        self._create_scene_folders(num_examples=num_examples, root_dir=split_dir)

        return num_examples

    def test_train_splits(self):
        """Every sample is a (left, right, disparity) triple with consistent shapes."""
        with self.create_dataset() as (dataset, _):
            for left, right, disparity in dataset:
                datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity)
2840+
2841+
26742842
if __name__ == "__main__":
26752843
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
2+
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo
23
from .caltech import Caltech101, Caltech256
34
from .celeba import CelebA
45
from .cifar import CIFAR10, CIFAR100
@@ -105,4 +106,7 @@
105106
"FGVCAircraft",
106107
"EuroSAT",
107108
"RenderedSST2",
109+
"Kitti2012Stereo",
110+
"Kitti2015Stereo",
111+
"CarlaStereo",
108112
)

0 commit comments

Comments
 (0)