
Commit ccb7f45

pmeier and fmassa authored
Add tests for VOC(Segmentation|Detection) and fix existing bugs (#3415)
* use common download utils in VOC and SBDataset
* add tests for VOC
* use common base class for VOC datasets
* remove old voc test and fake data generation

Co-authored-by: Francisco Massa <[email protected]>
1 parent 7b7cfdd commit ccb7f45

File tree

4 files changed: +216 −165 lines


test/fakedata_generation.py

Lines changed: 0 additions & 13 deletions
@@ -369,19 +369,6 @@ def _make_mat(file):
         yield root
 
 
-@contextlib.contextmanager
-def voc_root():
-    with get_tmp_dir() as tmp_dir:
-        voc_dir = os.path.join(tmp_dir, 'VOCdevkit',
-                               'VOC2012', 'ImageSets', 'Main')
-        os.makedirs(voc_dir)
-        train_file = os.path.join(voc_dir, 'train.txt')
-        with open(train_file, 'w') as f:
-            f.write('test')
-
-        yield tmp_dir
-
-
 @contextlib.contextmanager
 def ucf101_root():
     with get_tmp_dir() as tmp_dir:

test/test_datasets.py

Lines changed: 116 additions & 33 deletions
@@ -11,7 +11,7 @@
 from torchvision.datasets import utils
 from common_utils import get_tmp_dir
 from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
-    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root
+    cityscapes_root, svhn_root, ucf101_root, places365_root, widerface_root, stl10_root
 import xml.etree.ElementTree as ET
 from urllib.request import Request, urlopen
 import itertools
@@ -20,6 +20,7 @@
 import pickle
 from torchvision import datasets
 import torch
+import shutil
 
 
 try:
@@ -259,38 +260,6 @@ def test_svhn(self, mock_check):
         dataset = torchvision.datasets.SVHN(root, split="extra")
         self.generic_classification_dataset_test(dataset, num_images=2)
 
-    @mock.patch('torchvision.datasets.voc.download_extract')
-    def test_voc_parse_xml(self, mock_download_extract):
-        with voc_root() as root:
-            dataset = torchvision.datasets.VOCDetection(root)
-
-            single_object_xml = """<annotation>
-            <object>
-                <name>cat</name>
-            </object>
-            </annotation>"""
-            multiple_object_xml = """<annotation>
-            <object>
-                <name>cat</name>
-            </object>
-            <object>
-                <name>dog</name>
-            </object>
-            </annotation>"""
-
-            single_object_parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml))
-            multiple_object_parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))
-
-            self.assertEqual(single_object_parsed, {'annotation': {'object': [{'name': 'cat'}]}})
-            self.assertEqual(multiple_object_parsed,
-                             {'annotation': {
-                                 'object': [{
-                                     'name': 'cat'
-                                 }, {
-                                     'name': 'dog'
-                                 }]
-                             }})
-
     @unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
     def test_ucf101(self):
         cached_meta_data = None
@@ -756,5 +725,119 @@ def test_attr_names(self):
         self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
 
 
+class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.VOCSegmentation
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
+
+    CONFIGS = (
+        *datasets_utils.combinations_grid(
+            year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
+        ),
+        dict(year="2007", image_set="test"),
+        dict(year="2007-test", image_set="test"),
+    )
+
+    def inject_fake_data(self, tmpdir, config):
+        year, is_test_set = (
+            ("2007", True)
+            if config["year"] == "2007-test" or config["image_set"] == "test"
+            else (config["year"], False)
+        )
+        image_set = config["image_set"]
+
+        base_dir = pathlib.Path(tmpdir)
+        if year == "2011":
+            base_dir /= "TrainVal"
+        base_dir = base_dir / "VOCdevkit" / f"VOC{year}"
+        os.makedirs(base_dir)
+
+        num_images, num_images_per_image_set = self._create_image_set_files(base_dir, "ImageSets", is_test_set)
+        datasets_utils.create_image_folder(base_dir, "JPEGImages", lambda idx: f"{idx:06d}.jpg", num_images)
+
+        datasets_utils.create_image_folder(base_dir, "SegmentationClass", lambda idx: f"{idx:06d}.png", num_images)
+        annotation = self._create_annotation_files(base_dir, "Annotations", num_images)
+
+        return dict(num_examples=num_images_per_image_set[image_set], annotation=annotation)
+
+    def _create_image_set_files(self, root, name, is_test_set):
+        root = pathlib.Path(root) / name
+        src = pathlib.Path(root) / "Main"
+        os.makedirs(src, exist_ok=True)
+
+        idcs = dict(train=(0, 1, 2), val=(3, 4), test=(5,))
+        idcs["trainval"] = (*idcs["train"], *idcs["val"])
+
+        for image_set in ("test",) if is_test_set else ("train", "val", "trainval"):
+            self._create_image_set_file(src, image_set, idcs[image_set])
+
+        shutil.copytree(src, root / "Segmentation")
+
+        num_images = max(itertools.chain(*idcs.values())) + 1
+        num_images_per_image_set = dict([(image_set, len(idcs_)) for image_set, idcs_ in idcs.items()])
+        return num_images, num_images_per_image_set
+
+    def _create_image_set_file(self, root, image_set, idcs):
+        with open(pathlib.Path(root) / f"{image_set}.txt", "w") as fh:
+            fh.writelines([f"{idx:06d}\n" for idx in idcs])
+
+    def _create_annotation_files(self, root, name, num_images):
+        root = pathlib.Path(root) / name
+        os.makedirs(root)
+
+        for idx in range(num_images):
+            annotation = self._create_annotation_file(root, f"{idx:06d}.xml")
+
+        return annotation
+
+    def _create_annotation_file(self, root, name):
+        def add_child(parent, name, text=None):
+            child = ET.SubElement(parent, name)
+            child.text = text
+            return child
+
+        def add_name(obj, name="dog"):
+            add_child(obj, "name", name)
+            return name
+
+        def add_bndbox(obj, bndbox=None):
+            if bndbox is None:
+                bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"}
+
+            obj = add_child(obj, "bndbox")
+            for name, text in bndbox.items():
+                add_child(obj, name, text)
+
+            return bndbox
+
+        annotation = ET.Element("annotation")
+        obj = add_child(annotation, "object")
+        data = dict(name=add_name(obj), bndbox=add_bndbox(obj))
+
+        with open(pathlib.Path(root) / name, "wb") as fh:
+            fh.write(ET.tostring(annotation))
+
+        return data
+
+
+class VOCDetectionTestCase(VOCSegmentationTestCase):
+    DATASET_CLASS = datasets.VOCDetection
+    FEATURE_TYPES = (PIL.Image.Image, dict)
+
+    def test_annotations(self):
+        with self.create_dataset() as (dataset, info):
+            _, target = dataset[0]
+
+            self.assertIn("annotation", target)
+            annotation = target["annotation"]
+
+            self.assertIn("object", annotation)
+            objects = annotation["object"]
+
+            self.assertEqual(len(objects), 1)
+            object = objects[0]
+
+            self.assertEqual(object, info["annotation"])
+
+
 if __name__ == "__main__":
     unittest.main()

torchvision/datasets/sbd.py

Lines changed: 2 additions & 3 deletions
@@ -6,8 +6,7 @@
 import numpy as np
 
 from PIL import Image
-from .utils import download_url, verify_str_arg
-from .voc import download_extract
+from .utils import download_url, verify_str_arg, download_and_extract_archive
 
 
 class SBDataset(VisionDataset):
@@ -77,7 +76,7 @@ def __init__(
         mask_dir = os.path.join(sbd_root, 'cls')
 
         if download:
-            download_extract(self.url, self.root, self.filename, self.md5)
+            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
         extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
         for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
            old_path = os.path.join(extracted_ds_root, f)