Skip to content

Ported places365 dataset's tests to the new test framework #3705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 23, 2021
100 changes: 0 additions & 100 deletions test/fakedata_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,103 +208,3 @@ def _make_annotations_archive(root):
_make_annotations_archive(root_base)

yield root


@contextlib.contextmanager
def places365_root(split="train-standard", small=False):
    """Create a temporary fake Places365 download and yield its location.

    A devkit archive and an images archive for the requested configuration are
    written into a temporary directory. The ``_CATEGORIES_META``,
    ``_FILE_LIST_META``, ``_DEVKIT_META`` and ``_IMAGES_META`` attributes of
    ``torchvision.datasets.places365.Places365`` are patched through
    ``mock_class_attribute`` so that they describe the generated files.

    Args:
        split: One of ``"train-standard"``, ``"train-challenge"`` or ``"val"``.
        small: If true, generate 256x256 images and use the ``*_256`` archive
            names; otherwise generate larger images with a random width.

    Yields:
        A ``(root, data)`` tuple. ``data`` is a dict holding the expected
        ``"class_to_idx"``, ``"classes"`` and ``"imgs"`` values.
    """
    # {split: devkit variant}
    VARIANTS = {
        "train-standard": "standard",
        "train-challenge": "challenge",
        "val": "standard",
    }
    # {split: file}
    DEVKITS = {
        "train-standard": "filelist_places365-standard.tar",
        "train-challenge": "filelist_places365-challenge.tar",
        "val": "filelist_places365-standard.tar",
    }
    CATEGORIES = "categories_places365.txt"
    # {split: file}
    FILE_LISTS = {
        "train-standard": "places365_train_standard.txt",
        "train-challenge": "places365_train_challenge.txt",
        # The validation split has its own file list; it previously reused the
        # train-standard name by mistake.
        "val": "places365_val.txt",
    }
    # {(split, small): (archive, folder_default, folder_renamed)}
    IMAGES = {
        ("train-standard", False): ("train_large_places365standard.tar", "data_large", "data_large_standard"),
        ("train-challenge", False): ("train_large_places365challenge.tar", "data_large", "data_large_challenge"),
        ("val", False): ("val_large.tar", "val_large", "val_large"),
        ("train-standard", True): ("train_256_places365standard.tar", "data_256", "data_256_standard"),
        ("train-challenge", True): ("train_256_places365challenge.tar", "data_256", "data_256_challenge"),
        ("val", True): ("val_256.tar", "val_256", "val_256"),
    }

    # (class, idx)
    CATEGORIES_CONTENT = (("/a/airfield", 0), ("/a/apartment_building/outdoor", 8), ("/b/badlands", 30))
    # (file, idx)
    FILE_LIST_CONTENT = (
        ("Places365_val_00000001.png", 0),
        *((f"{category}/Places365_train_00000001.png", idx) for category, idx in CATEGORIES_CONTENT),
    )

    def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
        # Fully qualified name of the class attribute to patch.
        return f"{partial}.{attr}"

    def make_txt(root, name, seq):
        # Write one "<string> <idx>" line per entry and return (name, md5).
        file = os.path.join(root, name)
        with open(file, "w") as fh:
            for string, idx in seq:
                fh.write(f"{string} {idx}\n")
        return name, compute_md5(file)

    def make_categories_txt(root, name):
        return make_txt(root, name, CATEGORIES_CONTENT)

    def make_file_list_txt(root, name):
        return make_txt(root, name, FILE_LIST_CONTENT)

    def make_image(file, size):
        # Save an all-black RGB image, creating parent folders as needed.
        os.makedirs(os.path.dirname(file), exist_ok=True)
        PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)

    def make_devkit_archive(stack, root, split):
        archive = DEVKITS[split]
        files = []

        meta = make_categories_txt(root, CATEGORIES)
        mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
        files.append(meta[0])

        meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
        mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
        files.extend([item[0] for item in meta.values()])

        # Bundle the categories file and the file list into the devkit tar.
        meta = {VARIANTS[split]: make_tar(root, archive, *files)}
        mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)

    def make_images_archive(stack, root, split, small):
        archive, folder_default, folder_renamed = IMAGES[(split, small)]

        image_size = (256, 256) if small else (512, random.randint(512, 1024))
        files, idcs = zip(*FILE_LIST_CONTENT)
        # File-list entries use "/" and may start with one; turn them into
        # relative paths for the current platform.
        images = [file.lstrip("/").replace("/", os.sep) for file in files]
        for image in images:
            make_image(os.path.join(root, folder_default, image), image_size)

        meta = {(split, small): make_tar(root, archive, folder_default)}
        mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)

        # Expected image paths are reported under the renamed folder, not the
        # folder name used inside the archive.
        return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]

    with contextlib.ExitStack() as stack, get_tmp_dir() as root:
        make_devkit_archive(stack, root, split)
        class_to_idx = dict(CATEGORIES_CONTENT)
        classes = list(class_to_idx.keys())

        data = {"class_to_idx": class_to_idx, "classes": classes}
        data["imgs"] = make_images_archive(stack, root, split, small)

        # NOTE(review): presumably removes everything in root except the tar
        # archives, so only the "downloads" remain -- confirm in clean_dir.
        clean_dir(root, ".tar$")

        yield root, data
192 changes: 91 additions & 101 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import places365_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
Expand Down Expand Up @@ -41,106 +40,6 @@
HAS_PYAV = False


class DatasetTestcase(unittest.TestCase):
    """Base test case providing smoke checks shared by dataset tests."""

    def _check_first_sample(self, dataset, num_images, is_valid_target):
        # Shared body of the generic checks: verify the dataset length, then
        # the types of the first (image, target) pair.
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        is_valid_target(target)

    def generic_classification_dataset_test(self, dataset, num_images=1):
        """Check the dataset length and that sample 0 is a (PIL image, int) pair."""
        self._check_first_sample(
            dataset, num_images, lambda target: self.assertIsInstance(target, int)
        )

    def generic_segmentation_dataset_test(self, dataset, num_images=1):
        """Check the dataset length and that sample 0 is a (PIL image, PIL image) pair."""
        self._check_first_sample(
            dataset, num_images, lambda target: self.assertIsInstance(target, PIL.Image.Image)
        )


class Tester(DatasetTestcase):
    """Tests for ``torchvision.datasets.Places365`` backed by generated fake data."""

    def test_places365(self):
        """The dataset loads and yields (image, int) samples for every configuration."""
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            # subTest keeps one failing configuration from hiding the others.
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as (root, data):
                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
                    self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))

    def test_places365_transforms(self):
        """``transform`` and ``target_transform`` are applied to the returned sample."""
        expected_image = "image"
        expected_target = "target"

        def transform(image):
            return expected_image

        def target_transform(target):
            return expected_target

        with places365_root() as (root, _):
            dataset = torchvision.datasets.Places365(
                root, transform=transform, target_transform=target_transform, download=True
            )
            actual_image, actual_target = dataset[0]

            self.assertEqual(actual_image, expected_image)
            self.assertEqual(actual_target, expected_target)

    def test_places365_devkit_download(self):
        """Downloading the devkit populates ``classes``, ``class_to_idx`` and ``imgs``."""
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split):
                with places365_root(split=split) as (root, data):
                    dataset = torchvision.datasets.Places365(root, split=split, download=True)

                    with self.subTest("classes"):
                        self.assertSequenceEqual(dataset.classes, data["classes"])

                    with self.subTest("class_to_idx"):
                        self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])

                    with self.subTest("imgs"):
                        self.assertSequenceEqual(dataset.imgs, data["imgs"])

    def test_places365_devkit_no_download(self):
        """Without ``download=True`` the dataset construction raises ``RuntimeError``."""
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split):
                with places365_root(split=split) as (root, _):
                    with self.assertRaises(RuntimeError):
                        torchvision.datasets.Places365(root, split=split, download=False)

    def test_places365_images_download(self):
        """After download every path listed in ``dataset.imgs`` exists on disk."""
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as (root, _):
                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)

                    # A unittest assertion is used instead of a bare ``assert``
                    # so the check still runs under ``python -O``.
                    self.assertTrue(all(os.path.exists(item[0]) for item in dataset.imgs))

    def test_places365_images_download_preexisting(self):
        """An already existing image folder makes the download raise ``RuntimeError``."""
        split = "train-standard"
        small = False
        images_dir = "data_large_standard"

        with places365_root(split=split, small=small) as (root, _):
            os.mkdir(os.path.join(root, images_dir))

            with self.assertRaises(RuntimeError):
                torchvision.datasets.Places365(root, split=split, small=small, download=True)

    def test_places365_repr_smoke(self):
        """``repr`` of the dataset returns a string."""
        with places365_root() as (root, _):
            dataset = torchvision.datasets.Places365(root, download=True)
            self.assertIsInstance(repr(dataset), str)


class STL10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.STL10
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
Expand Down Expand Up @@ -1763,5 +1662,96 @@ def inject_fake_data(self, tmpdir, config):
return num_examples


class Places365TestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``datasets.Places365`` using fake data written into a temporary folder.

    The fake-data helpers are classmethods that read the class-level constants
    through ``cls``. The tests read the same constants through ``self``, so a
    subclass that overrides a constant generates and asserts against the same
    data.
    """

    DATASET_CLASS = datasets.Places365
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train-standard", "train-challenge", "val"),
        small=(False, True),
    )
    _CATEGORIES = "categories_places365.txt"
    # {split: file}
    _FILE_LISTS = {
        "train-standard": "places365_train_standard.txt",
        "train-challenge": "places365_train_challenge.txt",
        "val": "places365_val.txt",
    }
    # {(split, small): folder_name}
    _IMAGES = {
        ("train-standard", False): "data_large_standard",
        ("train-challenge", False): "data_large_challenge",
        ("val", False): "val_large",
        ("train-standard", True): "data_256_standard",
        ("train-challenge", True): "data_256_challenge",
        ("val", True): "val_256",
    }
    # (class, idx)
    _CATEGORIES_CONTENT = (
        ("/a/airfield", 0),
        ("/a/apartment_building/outdoor", 8),
        ("/b/badlands", 30),
    )
    # (file, idx)
    _FILE_LIST_CONTENT = (
        ("Places365_val_00000001.png", 0),
        *((f"{category}/Places365_train_00000001.png", idx)
          for category, idx in _CATEGORIES_CONTENT),
    )

    @staticmethod
    def _make_txt(root, name, seq):
        """Write one ``"<text> <idx>"`` line per ``(text, idx)`` pair to ``root/name``."""
        file = os.path.join(root, name)
        with open(file, "w") as fh:
            for text, idx in seq:
                fh.write(f"{text} {idx}\n")

    @classmethod
    def _make_categories_txt(cls, root, name):
        """Write the fake categories file."""
        cls._make_txt(root, name, cls._CATEGORIES_CONTENT)

    @classmethod
    def _make_file_list_txt(cls, root, name):
        """Write the fake file list."""
        cls._make_txt(root, name, cls._FILE_LIST_CONTENT)

    @staticmethod
    def _make_image(file_name, size):
        """Save an all-black RGB image of ``size``, creating parent folders as needed."""
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file_name)

    @classmethod
    def _make_devkit_archive(cls, root, split):
        """Write the categories file and the file list for ``split`` directly into ``root``."""
        cls._make_categories_txt(root, cls._CATEGORIES)
        cls._make_file_list_txt(root, cls._FILE_LISTS[split])

    @classmethod
    def _make_images_archive(cls, root, split, small):
        """Create the fake images for a configuration and return ``(path, idx)`` pairs."""
        folder_name = cls._IMAGES[(split, small)]
        image_size = (256, 256) if small else (512, random.randint(512, 1024))
        files, idcs = zip(*cls._FILE_LIST_CONTENT)
        # File-list entries use "/" and may start with one; turn them into
        # relative paths for the current platform.
        images = [f.lstrip("/").replace("/", os.sep) for f in files]
        paths = [os.path.join(root, folder_name, image) for image in images]
        for path in paths:
            cls._make_image(path, image_size)

        return list(zip(paths, idcs))

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir`` for ``config`` and return the number of fake images."""
        self._make_devkit_archive(tmpdir, config['split'])
        return len(self._make_images_archive(tmpdir, config['split'], config['small']))

    def test_classes(self):
        classes = [name for name, _ in self._CATEGORIES_CONTENT]
        with self.create_dataset() as (dataset, _):
            self.assertEqual(dataset.classes, classes)

    def test_class_to_idx(self):
        class_to_idx = dict(self._CATEGORIES_CONTENT)
        with self.create_dataset() as (dataset, _):
            self.assertEqual(dataset.class_to_idx, class_to_idx)

    def test_images_download_preexisting(self):
        # The image folder written by inject_fake_data already exists, so
        # requesting a download is expected to raise.
        with self.assertRaises(RuntimeError):
            with self.create_dataset({'download': True}):
                pass


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()