Skip to content

Commit d419558

Browse files
authored
Revert "Ported places365 dataset's tests to the new test framework (#3705)" (#3718)
This reverts commit 4b0b332.
1 parent 7be02cb commit d419558

File tree

2 files changed

+201
-91
lines changed

2 files changed

+201
-91
lines changed

test/fakedata_generation.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,103 @@ def _make_annotations_archive(root):
208208
_make_annotations_archive(root_base)
209209

210210
yield root
211+
212+
213+
@contextlib.contextmanager
214+
def places365_root(split="train-standard", small=False):
215+
VARIANTS = {
216+
"train-standard": "standard",
217+
"train-challenge": "challenge",
218+
"val": "standard",
219+
}
220+
# {split: file}
221+
DEVKITS = {
222+
"train-standard": "filelist_places365-standard.tar",
223+
"train-challenge": "filelist_places365-challenge.tar",
224+
"val": "filelist_places365-standard.tar",
225+
}
226+
CATEGORIES = "categories_places365.txt"
227+
# {split: file}
228+
FILE_LISTS = {
229+
"train-standard": "places365_train_standard.txt",
230+
"train-challenge": "places365_train_challenge.txt",
231+
"val": "places365_train_standard.txt",
232+
}
233+
# {(split, small): (archive, folder_default, folder_renamed)}
234+
IMAGES = {
235+
("train-standard", False): ("train_large_places365standard.tar", "data_large", "data_large_standard"),
236+
("train-challenge", False): ("train_large_places365challenge.tar", "data_large", "data_large_challenge"),
237+
("val", False): ("val_large.tar", "val_large", "val_large"),
238+
("train-standard", True): ("train_256_places365standard.tar", "data_256", "data_256_standard"),
239+
("train-challenge", True): ("train_256_places365challenge.tar", "data_256", "data_256_challenge"),
240+
("val", True): ("val_256.tar", "val_256", "val_256"),
241+
}
242+
243+
# (class, idx)
244+
CATEGORIES_CONTENT = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30))
245+
# (file, idx)
246+
FILE_LIST_CONTENT = (
247+
("Places365_val_00000001.png", 0),
248+
*((f"{category}/Places365_train_00000001.png", idx) for category, idx in CATEGORIES_CONTENT),
249+
)
250+
251+
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
252+
return f"{partial}.{attr}"
253+
254+
def make_txt(root, name, seq):
255+
file = os.path.join(root, name)
256+
with open(file, "w") as fh:
257+
for string, idx in seq:
258+
fh.write(f"{string} {idx}\n")
259+
return name, compute_md5(file)
260+
261+
def make_categories_txt(root, name):
262+
return make_txt(root, name, CATEGORIES_CONTENT)
263+
264+
def make_file_list_txt(root, name):
265+
return make_txt(root, name, FILE_LIST_CONTENT)
266+
267+
def make_image(file, size):
268+
os.makedirs(os.path.dirname(file), exist_ok=True)
269+
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)
270+
271+
def make_devkit_archive(stack, root, split):
272+
archive = DEVKITS[split]
273+
files = []
274+
275+
meta = make_categories_txt(root, CATEGORIES)
276+
mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
277+
files.append(meta[0])
278+
279+
meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
280+
mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
281+
files.extend([item[0] for item in meta.values()])
282+
283+
meta = {VARIANTS[split]: make_tar(root, archive, *files)}
284+
mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)
285+
286+
def make_images_archive(stack, root, split, small):
287+
archive, folder_default, folder_renamed = IMAGES[(split, small)]
288+
289+
image_size = (256, 256) if small else (512, random.randint(512, 1024))
290+
files, idcs = zip(*FILE_LIST_CONTENT)
291+
images = [file.lstrip("/").replace("/", os.sep) for file in files]
292+
for image in images:
293+
make_image(os.path.join(root, folder_default, image), image_size)
294+
295+
meta = {(split, small): make_tar(root, archive, folder_default)}
296+
mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)
297+
298+
return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]
299+
300+
with contextlib.ExitStack() as stack, get_tmp_dir() as root:
301+
make_devkit_archive(stack, root, split)
302+
class_to_idx = dict(CATEGORIES_CONTENT)
303+
classes = list(class_to_idx.keys())
304+
305+
data = {"class_to_idx": class_to_idx, "classes": classes}
306+
data["imgs"] = make_images_archive(stack, root, split, small)
307+
308+
clean_dir(root, ".tar$")
309+
310+
yield root, data

test/test_datasets.py

Lines changed: 101 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torchvision
1010
from torchvision.datasets import utils
1111
from common_utils import get_tmp_dir
12+
from fakedata_generation import places365_root
1213
import xml.etree.ElementTree as ET
1314
from urllib.request import Request, urlopen
1415
import itertools
@@ -40,6 +41,106 @@
4041
HAS_PYAV = False
4142

4243

44+
class DatasetTestcase(unittest.TestCase):
45+
def generic_classification_dataset_test(self, dataset, num_images=1):
46+
self.assertEqual(len(dataset), num_images)
47+
img, target = dataset[0]
48+
self.assertTrue(isinstance(img, PIL.Image.Image))
49+
self.assertTrue(isinstance(target, int))
50+
51+
def generic_segmentation_dataset_test(self, dataset, num_images=1):
52+
self.assertEqual(len(dataset), num_images)
53+
img, target = dataset[0]
54+
self.assertTrue(isinstance(img, PIL.Image.Image))
55+
self.assertTrue(isinstance(target, PIL.Image.Image))
56+
57+
58+
class Tester(DatasetTestcase):
59+
def test_places365(self):
60+
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
61+
with places365_root(split=split, small=small) as places365:
62+
root, data = places365
63+
64+
dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
65+
self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))
66+
67+
def test_places365_transforms(self):
68+
expected_image = "image"
69+
expected_target = "target"
70+
71+
def transform(image):
72+
return expected_image
73+
74+
def target_transform(target):
75+
return expected_target
76+
77+
with places365_root() as places365:
78+
root, data = places365
79+
80+
dataset = torchvision.datasets.Places365(
81+
root, transform=transform, target_transform=target_transform, download=True
82+
)
83+
actual_image, actual_target = dataset[0]
84+
85+
self.assertEqual(actual_image, expected_image)
86+
self.assertEqual(actual_target, expected_target)
87+
88+
def test_places365_devkit_download(self):
89+
for split in ("train-standard", "train-challenge", "val"):
90+
with self.subTest(split=split):
91+
with places365_root(split=split) as places365:
92+
root, data = places365
93+
94+
dataset = torchvision.datasets.Places365(root, split=split, download=True)
95+
96+
with self.subTest("classes"):
97+
self.assertSequenceEqual(dataset.classes, data["classes"])
98+
99+
with self.subTest("class_to_idx"):
100+
self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])
101+
102+
with self.subTest("imgs"):
103+
self.assertSequenceEqual(dataset.imgs, data["imgs"])
104+
105+
def test_places365_devkit_no_download(self):
106+
for split in ("train-standard", "train-challenge", "val"):
107+
with self.subTest(split=split):
108+
with places365_root(split=split) as places365:
109+
root, data = places365
110+
111+
with self.assertRaises(RuntimeError):
112+
torchvision.datasets.Places365(root, split=split, download=False)
113+
114+
def test_places365_images_download(self):
115+
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
116+
with self.subTest(split=split, small=small):
117+
with places365_root(split=split, small=small) as places365:
118+
root, data = places365
119+
120+
dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
121+
122+
assert all(os.path.exists(item[0]) for item in dataset.imgs)
123+
124+
def test_places365_images_download_preexisting(self):
125+
split = "train-standard"
126+
small = False
127+
images_dir = "data_large_standard"
128+
129+
with places365_root(split=split, small=small) as places365:
130+
root, data = places365
131+
os.mkdir(os.path.join(root, images_dir))
132+
133+
with self.assertRaises(RuntimeError):
134+
torchvision.datasets.Places365(root, split=split, small=small, download=True)
135+
136+
def test_places365_repr_smoke(self):
137+
with places365_root() as places365:
138+
root, data = places365
139+
140+
dataset = torchvision.datasets.Places365(root, download=True)
141+
self.assertIsInstance(repr(dataset), str)
142+
143+
43144
class STL10TestCase(datasets_utils.ImageDatasetTestCase):
44145
DATASET_CLASS = datasets.STL10
45146
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
@@ -1662,96 +1763,5 @@ def inject_fake_data(self, tmpdir, config):
16621763
return num_examples
16631764

16641765

1665-
class Places365TestCase(datasets_utils.ImageDatasetTestCase):
1666-
DATASET_CLASS = datasets.Places365
1667-
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
1668-
split=("train-standard", "train-challenge", "val"),
1669-
small=(False, True),
1670-
)
1671-
_CATEGORIES = "categories_places365.txt"
1672-
# {split: file}
1673-
_FILE_LISTS = {
1674-
"train-standard": "places365_train_standard.txt",
1675-
"train-challenge": "places365_train_challenge.txt",
1676-
"val": "places365_val.txt",
1677-
}
1678-
# {(split, small): folder_name}
1679-
_IMAGES = {
1680-
("train-standard", False): "data_large_standard",
1681-
("train-challenge", False): "data_large_challenge",
1682-
("val", False): "val_large",
1683-
("train-standard", True): "data_256_standard",
1684-
("train-challenge", True): "data_256_challenge",
1685-
("val", True): "val_256",
1686-
}
1687-
# (class, idx)
1688-
_CATEGORIES_CONTENT = (
1689-
("/a/airfield", 0),
1690-
("/a/apartment_building/outdoor", 8),
1691-
("/b/badlands", 30),
1692-
)
1693-
# (file, idx)
1694-
_FILE_LIST_CONTENT = (
1695-
("Places365_val_00000001.png", 0),
1696-
*((f"{category}/Places365_train_00000001.png", idx)
1697-
for category, idx in _CATEGORIES_CONTENT),
1698-
)
1699-
1700-
@staticmethod
1701-
def _make_txt(root, name, seq):
1702-
file = os.path.join(root, name)
1703-
with open(file, "w") as fh:
1704-
for text, idx in seq:
1705-
fh.write(f"{text} {idx}\n")
1706-
1707-
@staticmethod
1708-
def _make_categories_txt(root, name):
1709-
Places365TestCase._make_txt(root, name, Places365TestCase._CATEGORIES_CONTENT)
1710-
1711-
@staticmethod
1712-
def _make_file_list_txt(root, name):
1713-
Places365TestCase._make_txt(root, name, Places365TestCase._FILE_LIST_CONTENT)
1714-
1715-
@staticmethod
1716-
def _make_image(file_name, size):
1717-
os.makedirs(os.path.dirname(file_name), exist_ok=True)
1718-
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file_name)
1719-
1720-
@staticmethod
1721-
def _make_devkit_archive(root, split):
1722-
Places365TestCase._make_categories_txt(root, Places365TestCase._CATEGORIES)
1723-
Places365TestCase._make_file_list_txt(root, Places365TestCase._FILE_LISTS[split])
1724-
1725-
@staticmethod
1726-
def _make_images_archive(root, split, small):
1727-
folder_name = Places365TestCase._IMAGES[(split, small)]
1728-
image_size = (256, 256) if small else (512, random.randint(512, 1024))
1729-
files, idcs = zip(*Places365TestCase._FILE_LIST_CONTENT)
1730-
images = [f.lstrip("/").replace("/", os.sep) for f in files]
1731-
for image in images:
1732-
Places365TestCase._make_image(os.path.join(root, folder_name, image), image_size)
1733-
1734-
return [(os.path.join(root, folder_name, image), idx) for image, idx in zip(images, idcs)]
1735-
1736-
def inject_fake_data(self, tmpdir, config):
1737-
self._make_devkit_archive(tmpdir, config['split'])
1738-
return len(self._make_images_archive(tmpdir, config['split'], config['small']))
1739-
1740-
def test_classes(self):
1741-
classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT))
1742-
with self.create_dataset() as (dataset, _):
1743-
self.assertEqual(dataset.classes, classes)
1744-
1745-
def test_class_to_idx(self):
1746-
class_to_idx = dict(self._CATEGORIES_CONTENT)
1747-
with self.create_dataset() as (dataset, _):
1748-
self.assertEqual(dataset.class_to_idx, class_to_idx)
1749-
1750-
def test_images_download_preexisting(self):
1751-
with self.assertRaises(RuntimeError):
1752-
with self.create_dataset({'download': True}):
1753-
pass
1754-
1755-
17561766
if __name__ == "__main__":
17571767
unittest.main()

0 commit comments

Comments
 (0)