Skip to content

Commit 6624ed5

Browse files
NicolasHug authored and facebook-github-bot committed
[fbsync] Added SceneFlow variant datasets (#6345)
Summary: * added SceneFlow variant datasets * Changed split name to variant name * removed trailing commented code line Reviewed By: datumbox Differential Revision: D38824231 fbshipit-source-id: 14dc283f11df26287fe6446946b441f51eb82181
1 parent 6bfca02 commit 6624ed5

File tree

4 files changed

+185
-2
lines changed

4 files changed

+185
-2
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ Stereo Matching
111111
CarlaStereo
112112
Kitti2012Stereo
113113
Kitti2015Stereo
114+
SceneFlowStereo
114115

115116
Image pairs
116117
~~~~~~~~~~~

test/test_datasets.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import unittest
1414
import xml.etree.ElementTree as ET
1515
import zipfile
16-
from typing import Union
16+
from typing import Callable, Tuple, Union
1717

1818
import datasets_utils
1919
import numpy as np
@@ -2841,5 +2841,80 @@ def test_train_splits(self):
28412841
datasets_utils.shape_test_for_stereo(left, right, disparity)
28422842

28432843

2844+
class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.SceneFlowStereo`` backed by a generated fake directory tree.

    Every combination of the three variants and the three pass names is
    registered through ``ADDITIONAL_CONFIGS``.
    """

    DATASET_CLASS = datasets.SceneFlowStereo
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        variant=("FlyingThings3D", "Driving", "Monkaa"), pass_name=("clean", "final", "both")
    )
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    @staticmethod
    def _create_pfm_folder(
        root: str, name: str, file_name_fn: Callable[..., str], num_examples: int, size: Tuple[int, int]
    ) -> None:
        """Create the folder ``root / name`` and fill it with fake PFM files.

        Args:
            root: Parent directory of the folder to create.
            name: Name of the folder created under ``root``.
            file_name_fn: Called with each index in ``range(num_examples)``; returns the file name.
            num_examples: Number of PFM files to write.
            size: Two values forwarded, in order, as the first two arguments of
                ``datasets_utils.make_fake_pfm_file``.
        """
        folder = pathlib.Path(root) / name
        os.makedirs(folder, exist_ok=True)

        # map() is lazy, so each name is still produced right before its file is written.
        for file_name in map(file_name_fn, range(num_examples)):
            datasets_utils.make_fake_pfm_file(size[0], size[1], folder / file_name)

    def inject_fake_data(self, tmpdir, config):
        """Build a fake ``SceneFlow/<variant>`` tree under ``tmpdir`` for the given ``config``.

        For each requested pass, every scene gets a ``left`` and a ``right`` image
        folder with one PNG each, and the shared ``disparity`` directory gets the
        matching ``left`` / ``right`` folders with one PFM file each.

        Returns:
            The per-variant scene count, doubled when ``config["pass_name"]`` is ``"both"``.
        """

        def png_name(i):
            return f"{i:06d}.png"

        def pfm_name(i):
            return f"{i:06d}.pfm"

        scene_flow_dir = pathlib.Path(tmpdir) / "SceneFlow"
        os.makedirs(scene_flow_dir, exist_ok=True)

        variant_dir = scene_flow_dir / config["variant"]
        os.makedirs(variant_dir, exist_ok=True)

        scenes_per_variant = {"FlyingThings3D": 4, "Driving": 6, "Monkaa": 5}
        num_examples = scenes_per_variant.get(config["variant"], 0)

        pass_dirs_by_name = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }

        for pass_dir_name in pass_dirs_by_name.get(config["pass_name"], []):
            # create pass directories
            pass_dir = variant_dir / pass_dir_name
            disp_dir = variant_dir / "disparity"
            os.makedirs(pass_dir, exist_ok=True)
            os.makedirs(disp_dir, exist_ok=True)

            # product() walks all scenes for "left" first, then all scenes for "right".
            for direction, scene_idx in itertools.product(["left", "right"], range(num_examples)):
                image_scene_dir = pass_dir / f"scene_{scene_idx:06d}"
                os.makedirs(image_scene_dir, exist_ok=True)
                datasets_utils.create_image_folder(
                    root=image_scene_dir,
                    name=direction,
                    file_name_fn=png_name,
                    num_examples=1,
                    size=(3, 200, 100),
                )

                disparity_scene_dir = disp_dir / f"scene_{scene_idx:06d}"
                os.makedirs(disparity_scene_dir, exist_ok=True)
                self._create_pfm_folder(
                    root=disparity_scene_dir,
                    name=direction,
                    file_name_fn=pfm_name,
                    num_examples=1,
                    size=(100, 200),
                )

        return num_examples * 2 if config["pass_name"] == "both" else num_examples

    def test_splits(self):
        """Run the stereo shape check on every sample of each variant for the clean and final passes."""
        for variant_name in ["FlyingThings3D", "Driving", "Monkaa"]:
            for pass_name in ["clean", "final"]:
                with self.create_dataset(variant=variant_name, pass_name=pass_name) as (dataset, _):
                    for left, right, disparity in dataset:
                        datasets_utils.shape_test_for_stereo(left, right, disparity)

    def test_bad_input(self):
        """An unknown ``variant`` must be rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="Unknown value 'bad' for argument variant"):
            with self.create_dataset(variant="bad"):
                pass
2917+
2918+
28442919
if __name__ == "__main__":
28452920
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
2-
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo
2+
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo, SceneFlowStereo
33
from .caltech import Caltech101, Caltech256
44
from .celeba import CelebA
55
from .cifar import CIFAR10, CIFAR100
@@ -109,4 +109,5 @@
109109
"Kitti2012Stereo",
110110
"Kitti2015Stereo",
111111
"CarlaStereo",
112+
"SceneFlowStereo",
112113
)

torchvision/datasets/_stereo_matching.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,3 +359,109 @@ def __getitem__(self, index: int) -> Tuple:
359359
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
360360
"""
361361
return super().__getitem__(index)
362+
363+
364+
class SceneFlowStereo(StereoMatchingDataset):
    """Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets.
    This interface provides access to the `FlyingThings3D`, `Monkaa` and `Driving` datasets.

    The dataset is expected to have the following structure: ::

        root
            SceneFlow
                Monkaa
                    frames_cleanpass
                        scene1
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                        scene2
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                    frames_finalpass
                        scene1
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                        ...
                        ...
                    disparity
                        scene1
                            left
                                img1.pfm
                                img2.pfm
                            right
                                img1.pfm
                                img2.pfm
                FlyingThings3D
                    ...
                    ...

    Args:
        root (string): Root directory where SceneFlow is located.
        variant (string): Which dataset variant to use, "FlyingThings3D" (default), "Monkaa" or "Driving".
        pass_name (string): Which pass to use, "clean" (default), "final" or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.

    """

    def __init__(
        self,
        root: str,
        variant: str = "FlyingThings3D",
        pass_name: str = "clean",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "SceneFlow"

        verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa"))
        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))

        passes = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }[pass_name]

        root = root / variant

        # The disparity files live in one directory shared by all passes, so the
        # patterns (and the scan result) do not depend on the pass being processed.
        left_disparity_pattern = str(root / "disparity" / "*" / "left" / "*.pfm")
        right_disparity_pattern = str(root / "disparity" / "*" / "right" / "*.pfm")
        disparities = None

        for p in passes:
            left_image_pattern = str(root / p / "*" / "left" / "*.png")
            right_image_pattern = str(root / p / "*" / "right" / "*.png")
            self._images += self._scan_pairs(left_image_pattern, right_image_pattern)

            # Scan the disparity tree only once (pass_name="both" previously scanned it twice).
            # The scan stays inside the loop, after the first image scan, so any error
            # raised by _scan_pairs still surfaces in the same order as before.
            if disparities is None:
                disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
            # One disparity entry is appended per image pair of every pass.
            self._disparities += disparities

    def _read_disparity(self, file_path: str) -> Tuple:
        """Read a PFM disparity file.

        Args:
            file_path (str): Path of the ``.pfm`` file to read.

        Returns:
            tuple: ``(disparity_map, valid_mask)`` where ``valid_mask`` is always ``None``.
        """
        disparity_map = _read_pfm_file(file_path)
        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            If a ``valid_mask`` is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        return super().__getitem__(index)

0 commit comments

Comments
 (0)