11 | 11 | from torchvision.datasets import utils
12 | 12 | from common_utils import get_tmp_dir
13 | 13 | from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
14 |    | -    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root
   | 14 | +    cityscapes_root, svhn_root, ucf101_root, places365_root, widerface_root, stl10_root
15 | 15 | import xml.etree.ElementTree as ET
16 | 16 | from urllib.request import Request, urlopen
17 | 17 | import itertools
20 | 20 | import pickle
21 | 21 | from torchvision import datasets
22 | 22 | import torch
   | 23 | +import shutil
23 | 24 |
24 | 25 |
25 | 26 | try:
@@ -259,38 +260,6 @@ def test_svhn(self, mock_check):
259 | 260 |             dataset = torchvision.datasets.SVHN(root, split="extra")
260 | 261 |             self.generic_classification_dataset_test(dataset, num_images=2)
261 | 262 |
262 |     | -    @mock.patch('torchvision.datasets.voc.download_extract')
263 |     | -    def test_voc_parse_xml(self, mock_download_extract):
264 |     | -        with voc_root() as root:
265 |     | -            dataset = torchvision.datasets.VOCDetection(root)
266 |     | -
267 |     | -            single_object_xml = """<annotation>
268 |     | -            <object>
269 |     | -                <name>cat</name>
270 |     | -            </object>
271 |     | -            </annotation>"""
272 |     | -            multiple_object_xml = """<annotation>
273 |     | -            <object>
274 |     | -                <name>cat</name>
275 |     | -            </object>
276 |     | -            <object>
277 |     | -                <name>dog</name>
278 |     | -            </object>
279 |     | -            </annotation>"""
280 |     | -
281 |     | -            single_object_parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml))
282 |     | -            multiple_object_parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))
283 |     | -
284 |     | -            self.assertEqual(single_object_parsed, {'annotation': {'object': [{'name': 'cat'}]}})
285 |     | -            self.assertEqual(multiple_object_parsed,
286 |     | -                             {'annotation': {
287 |     | -                                 'object': [{
288 |     | -                                     'name': 'cat'
289 |     | -                                 }, {
290 |     | -                                     'name': 'dog'
291 |     | -                                 }]
292 |     | -                             }})
293 |     | -
294 | 263 |     @unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
295 | 264 |     def test_ucf101(self):
296 | 265 |         cached_meta_data = None
@@ -756,5 +725,119 @@ def test_attr_names(self):
756 | 725 |             self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
757 | 726 |
758 | 727 |
    | 728 | +class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
    | 729 | +    DATASET_CLASS = datasets.VOCSegmentation
    | 730 | +    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
    | 731 | +
    | 732 | +    CONFIGS = (
    | 733 | +        *datasets_utils.combinations_grid(
    | 734 | +            year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
    | 735 | +        ),
    | 736 | +        dict(year="2007", image_set="test"),
    | 737 | +        dict(year="2007-test", image_set="test"),
    | 738 | +    )
    | 739 | +
    | 740 | +    def inject_fake_data(self, tmpdir, config):
    | 741 | +        year, is_test_set = (
    | 742 | +            ("2007", True)
    | 743 | +            if config["year"] == "2007-test" or config["image_set"] == "test"
    | 744 | +            else (config["year"], False)
    | 745 | +        )
    | 746 | +        image_set = config["image_set"]
    | 747 | +
    | 748 | +        base_dir = pathlib.Path(tmpdir)
    | 749 | +        if year == "2011":
    | 750 | +            base_dir /= "TrainVal"
    | 751 | +        base_dir = base_dir / "VOCdevkit" / f"VOC{year}"
    | 752 | +        os.makedirs(base_dir)
    | 753 | +
    | 754 | +        num_images, num_images_per_image_set = self._create_image_set_files(base_dir, "ImageSets", is_test_set)
    | 755 | +        datasets_utils.create_image_folder(base_dir, "JPEGImages", lambda idx: f"{idx:06d}.jpg", num_images)
    | 756 | +
    | 757 | +        datasets_utils.create_image_folder(base_dir, "SegmentationClass", lambda idx: f"{idx:06d}.png", num_images)
    | 758 | +        annotation = self._create_annotation_files(base_dir, "Annotations", num_images)
    | 759 | +
    | 760 | +        return dict(num_examples=num_images_per_image_set[image_set], annotation=annotation)
    | 761 | +
    | 762 | +    def _create_image_set_files(self, root, name, is_test_set):
    | 763 | +        root = pathlib.Path(root) / name
    | 764 | +        src = pathlib.Path(root) / "Main"
    | 765 | +        os.makedirs(src, exist_ok=True)
    | 766 | +
    | 767 | +        idcs = dict(train=(0, 1, 2), val=(3, 4), test=(5,))
    | 768 | +        idcs["trainval"] = (*idcs["train"], *idcs["val"])
    | 769 | +
    | 770 | +        for image_set in ("test",) if is_test_set else ("train", "val", "trainval"):
    | 771 | +            self._create_image_set_file(src, image_set, idcs[image_set])
    | 772 | +
    | 773 | +        shutil.copytree(src, root / "Segmentation")
    | 774 | +
    | 775 | +        num_images = max(itertools.chain(*idcs.values())) + 1
    | 776 | +        num_images_per_image_set = dict([(image_set, len(idcs_)) for image_set, idcs_ in idcs.items()])
    | 777 | +        return num_images, num_images_per_image_set
    | 778 | +
    | 779 | +    def _create_image_set_file(self, root, image_set, idcs):
    | 780 | +        with open(pathlib.Path(root) / f"{image_set}.txt", "w") as fh:
    | 781 | +            fh.writelines([f"{idx:06d}\n" for idx in idcs])
    | 782 | +
    | 783 | +    def _create_annotation_files(self, root, name, num_images):
    | 784 | +        root = pathlib.Path(root) / name
    | 785 | +        os.makedirs(root)
    | 786 | +
    | 787 | +        for idx in range(num_images):
    | 788 | +            annotation = self._create_annotation_file(root, f"{idx:06d}.xml")
    | 789 | +
    | 790 | +        return annotation
    | 791 | +
    | 792 | +    def _create_annotation_file(self, root, name):
    | 793 | +        def add_child(parent, name, text=None):
    | 794 | +            child = ET.SubElement(parent, name)
    | 795 | +            child.text = text
    | 796 | +            return child
    | 797 | +
    | 798 | +        def add_name(obj, name="dog"):
    | 799 | +            add_child(obj, "name", name)
    | 800 | +            return name
    | 801 | +
    | 802 | +        def add_bndbox(obj, bndbox=None):
    | 803 | +            if bndbox is None:
    | 804 | +                bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"}
    | 805 | +
    | 806 | +            obj = add_child(obj, "bndbox")
    | 807 | +            for name, text in bndbox.items():
    | 808 | +                add_child(obj, name, text)
    | 809 | +
    | 810 | +            return bndbox
    | 811 | +
    | 812 | +        annotation = ET.Element("annotation")
    | 813 | +        obj = add_child(annotation, "object")
    | 814 | +        data = dict(name=add_name(obj), bndbox=add_bndbox(obj))
    | 815 | +
    | 816 | +        with open(pathlib.Path(root) / name, "wb") as fh:
    | 817 | +            fh.write(ET.tostring(annotation))
    | 818 | +
    | 819 | +        return data
    | 820 | +
    | 821 | +
    | 822 | +class VOCDetectionTestCase(VOCSegmentationTestCase):
    | 823 | +    DATASET_CLASS = datasets.VOCDetection
    | 824 | +    FEATURE_TYPES = (PIL.Image.Image, dict)
    | 825 | +
    | 826 | +    def test_annotations(self):
    | 827 | +        with self.create_dataset() as (dataset, info):
    | 828 | +            _, target = dataset[0]
    | 829 | +
    | 830 | +            self.assertIn("annotation", target)
    | 831 | +            annotation = target["annotation"]
    | 832 | +
    | 833 | +            self.assertIn("object", annotation)
    | 834 | +            objects = annotation["object"]
    | 835 | +
    | 836 | +            self.assertEqual(len(objects), 1)
    | 837 | +            object = objects[0]
    | 838 | +
    | 839 | +            self.assertEqual(object, info["annotation"])
    | 840 | +
    | 841 | +
759 | 842 | if __name__ == "__main__":
760 | 843 |     unittest.main()
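
Note on the CONFIGS grid above: `datasets_utils.combinations_grid` expands its keyword arguments into one config dict per element of their cartesian product, so the six years (2007-2012) crossed with three image sets yield eighteen configs before the two explicit test-set entries are appended. A minimal sketch of a helper with that behavior (assumed to match the torchvision test utility):

import itertools

def combinations_grid(**kwargs):
    # One dict per combination of the given options, e.g.
    # combinations_grid(year=("2007",), image_set=("train", "val"))
    # -> [{'year': '2007', 'image_set': 'train'}, {'year': '2007', 'image_set': 'val'}]
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]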
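
The new `VOCDetectionTestCase.test_annotations` pins down the same parsing behavior the removed `test_voc_parse_xml` covered: `VOCDetection` converts each annotation XML file into nested dicts, with the repeated `object` tag always collected into a list. For the single-object file written by `_create_annotation_file`, the full target returned by `dataset[0]` should therefore look like this (a sketch assembled from the assertions above, not taken verbatim from the test suite):

# Expected detection target for one generated annotation file.
expected_target = {
    "annotation": {
        "object": [
            {
                "name": "dog",
                "bndbox": {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"},
            }
        ]
    }
}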