|
9 | 9 | import torchvision
|
10 | 10 | from torchvision.datasets import utils
|
11 | 11 | from common_utils import get_tmp_dir
|
| 12 | +from fakedata_generation import places365_root |
12 | 13 | import xml.etree.ElementTree as ET
|
13 | 14 | from urllib.request import Request, urlopen
|
14 | 15 | import itertools
|
|
40 | 41 | HAS_PYAV = False
|
41 | 42 |
|
42 | 43 |
|
class DatasetTestcase(unittest.TestCase):
    """Base test case offering reusable sanity checks for image datasets."""

    def generic_classification_dataset_test(self, dataset, num_images=1):
        """Check the size of ``dataset`` and the types of its first sample.

        Args:
            dataset: Sized, indexable object whose items unpack into an
                ``(image, target)`` pair.
            num_images: Number of samples ``dataset`` is expected to contain.

        Raises:
            AssertionError: If the length differs from ``num_images``, the
                image is not a ``PIL.Image.Image`` or the target is not an
                ``int``.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        # assertIsInstance names the offending type in the failure message,
        # which assertTrue(isinstance(...)) cannot do.
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

    def generic_segmentation_dataset_test(self, dataset, num_images=1):
        """Check the size of ``dataset`` and the types of its first sample.

        Args:
            dataset: Sized, indexable object whose items unpack into an
                ``(image, target)`` pair.
            num_images: Number of samples ``dataset`` is expected to contain.

        Raises:
            AssertionError: If the length differs from ``num_images`` or
                either the image or the target is not a ``PIL.Image.Image``.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, PIL.Image.Image)
| 56 | + |
| 57 | + |
class Tester(DatasetTestcase):
    """Tests for ``torchvision.datasets.Places365`` backed by ``places365_root`` fake data."""

    def test_places365(self):
        """Run the generic classification check for every split / size combination."""
        splits = ("train-standard", "train-challenge", "val")
        for split, small in itertools.product(splits, (False, True)):
            # subTest keeps the remaining configurations running after a failure
            # and labels the failing one, matching the other tests in this class.
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as (root, data):
                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
                    self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))

    def test_places365_transforms(self):
        """Check that ``transform`` and ``target_transform`` outputs are what indexing returns."""
        expected_image = "image"
        expected_target = "target"

        def transform(image):
            return expected_image

        def target_transform(target):
            return expected_target

        with places365_root() as (root, _):
            dataset = torchvision.datasets.Places365(
                root, transform=transform, target_transform=target_transform, download=True
            )
            actual_image, actual_target = dataset[0]

            self.assertEqual(actual_image, expected_image)
            self.assertEqual(actual_target, expected_target)

    def test_places365_devkit_download(self):
        """Compare ``classes``, ``class_to_idx`` and ``imgs`` with the generated fake data."""
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split):
                with places365_root(split=split) as (root, data):
                    dataset = torchvision.datasets.Places365(root, split=split, download=True)

                    with self.subTest("classes"):
                        self.assertSequenceEqual(dataset.classes, data["classes"])

                    with self.subTest("class_to_idx"):
                        self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])

                    with self.subTest("imgs"):
                        self.assertSequenceEqual(dataset.imgs, data["imgs"])

    def test_places365_devkit_no_download(self):
        """Expect ``RuntimeError`` when the dataset is built with ``download=False``."""
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split):
                with places365_root(split=split) as (root, _):
                    with self.assertRaises(RuntimeError):
                        torchvision.datasets.Places365(root, split=split, download=False)

    def test_places365_images_download(self):
        """Check that every image path listed in ``dataset.imgs`` exists on disk."""
        splits = ("train-standard", "train-challenge", "val")
        for split, small in itertools.product(splits, (False, True)):
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as (root, _):
                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)

                    # A unittest assertion (not a bare ``assert``) still runs under
                    # ``python -O`` and reports exactly which files are missing.
                    missing = [item[0] for item in dataset.imgs if not os.path.exists(item[0])]
                    self.assertEqual(missing, [])

    def test_places365_images_download_preexisting(self):
        """Expect ``RuntimeError`` when the image folder already exists and ``download=True``."""
        split = "train-standard"
        small = False
        images_dir = "data_large_standard"

        with places365_root(split=split, small=small) as (root, _):
            # Create the image folder up front so the dataset finds it in place.
            os.mkdir(os.path.join(root, images_dir))

            with self.assertRaises(RuntimeError):
                torchvision.datasets.Places365(root, split=split, small=small, download=True)

    def test_places365_repr_smoke(self):
        """Smoke test: ``repr`` of the dataset returns a string."""
        with places365_root() as (root, _):
            dataset = torchvision.datasets.Places365(root, download=True)
            self.assertIsInstance(repr(dataset), str)
| 142 | + |
| 143 | + |
43 | 144 | class STL10TestCase(datasets_utils.ImageDatasetTestCase):
|
44 | 145 | DATASET_CLASS = datasets.STL10
|
45 | 146 | ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
|
@@ -1662,96 +1763,5 @@ def inject_fake_data(self, tmpdir, config):
|
1662 | 1763 | return num_examples
|
1663 | 1764 |
|
1664 | 1765 |
|
1665 |
class Places365TestCase(datasets_utils.ImageDatasetTestCase):
    """Places365 tests that write fake devkit text files and images into a directory."""

    DATASET_CLASS = datasets.Places365
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train-standard", "train-challenge", "val"),
        small=(False, True),
    )

    # Name of the categories text file.
    _CATEGORIES = "categories_places365.txt"

    # {split: file}
    _FILE_LISTS = {
        "train-standard": "places365_train_standard.txt",
        "train-challenge": "places365_train_challenge.txt",
        "val": "places365_val.txt",
    }

    # {(split, small): folder_name}
    _IMAGES = {
        ("train-standard", False): "data_large_standard",
        ("train-challenge", False): "data_large_challenge",
        ("val", False): "val_large",
        ("train-standard", True): "data_256_standard",
        ("train-challenge", True): "data_256_challenge",
        ("val", True): "val_256",
    }

    # (class, idx)
    _CATEGORIES_CONTENT = (
        ("/a/airfield", 0),
        ("/a/apartment_building/outdoor", 8),
        ("/b/badlands", 30),
    )

    # (file, idx): one validation-style entry followed by one training entry per category.
    _FILE_LIST_CONTENT = (("Places365_val_00000001.png", 0),) + tuple(
        (f"{category}/Places365_train_00000001.png", idx) for category, idx in _CATEGORIES_CONTENT
    )

    @staticmethod
    def _make_txt(root, name, seq):
        """Write each ``(text, idx)`` pair of ``seq`` as a ``"text idx"`` line to ``root/name``."""
        with open(os.path.join(root, name), "w") as fh:
            fh.writelines(f"{text} {idx}\n" for text, idx in seq)

    @staticmethod
    def _make_categories_txt(root, name):
        """Write the fake categories listing to ``root/name``."""
        content = Places365TestCase._CATEGORIES_CONTENT
        Places365TestCase._make_txt(root, name, content)

    @staticmethod
    def _make_file_list_txt(root, name):
        """Write the fake file listing to ``root/name``."""
        content = Places365TestCase._FILE_LIST_CONTENT
        Places365TestCase._make_txt(root, name, content)

    @staticmethod
    def _make_image(file_name, size):
        """Save an all-zero ``uint8`` image of shape ``(*size, 3)`` at ``file_name``."""
        parent = os.path.dirname(file_name)
        os.makedirs(parent, exist_ok=True)
        pixels = np.zeros((*size, 3), dtype=np.uint8)
        PIL.Image.fromarray(pixels).save(file_name)

    @staticmethod
    def _make_devkit_archive(root, split):
        """Create the categories file and the file list belonging to ``split`` in ``root``."""
        cls = Places365TestCase
        cls._make_categories_txt(root, cls._CATEGORIES)
        cls._make_file_list_txt(root, cls._FILE_LISTS[split])

    @staticmethod
    def _make_images_archive(root, split, small):
        """Create the fake images for ``(split, small)`` and return ``(path, idx)`` pairs."""
        cls = Places365TestCase
        folder_name = cls._IMAGES[(split, small)]
        if small:
            image_size = (256, 256)
        else:
            # Second dimension is drawn at random for the large variant.
            image_size = (512, random.randint(512, 1024))

        samples = []
        for file, idx in cls._FILE_LIST_CONTENT:
            # Listed names may start with "/" and use "/" separators; make them
            # relative, OS-specific paths below the image folder.
            relative = file.lstrip("/").replace("/", os.sep)
            path = os.path.join(root, folder_name, relative)
            cls._make_image(path, image_size)
            samples.append((path, idx))
        return samples

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir`` for ``config`` and return the number of images created."""
        split = config['split']
        self._make_devkit_archive(tmpdir, split)
        created = self._make_images_archive(tmpdir, split, config['small'])
        return len(created)

    def test_classes(self):
        """``dataset.classes`` equals the category names of the fake categories content."""
        classes = [category for category, _ in self._CATEGORIES_CONTENT]
        with self.create_dataset() as (dataset, _):
            self.assertEqual(dataset.classes, classes)

    def test_class_to_idx(self):
        """``dataset.class_to_idx`` equals the mapping given by the fake categories content."""
        class_to_idx = {category: idx for category, idx in self._CATEGORIES_CONTENT}
        with self.create_dataset() as (dataset, _):
            self.assertEqual(dataset.class_to_idx, class_to_idx)

    def test_images_download_preexisting(self):
        """Expect ``RuntimeError`` when the dataset is created with ``download=True``."""
        # NOTE(review): presumably raised because the fake image folder already
        # exists when the dataset is created -- confirm against create_dataset.
        with self.assertRaises(RuntimeError), self.create_dataset({'download': True}):
            pass
1754 |
| - |
1755 |
| - |
# Run the unittest command-line runner when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
0 commit comments