Skip to content

Ported STL10 dataset's tests to new test framework #3665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 0 additions & 79 deletions test/fakedata_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,82 +308,3 @@ def make_images_archive(stack, root, split, small):
clean_dir(root, ".tar$")

yield root, data


@contextlib.contextmanager
def stl10_root(_extracted=False):
    """Yield a temporary root directory that contains a fake STL10 archive.

    The fake dataset files (zero-filled image and label binaries, a class
    names file and a fold indices file) are written into a ``stl10_binary``
    folder. The folder is packed with ``make_tar(..., compression="gz")`` and
    then removed again, so only the archive is left in the root directory.
    The generated file lists and the archive checksum are handed to
    ``mock_class_attribute`` for ``torchvision.datasets.stl10.STL10``.

    Args:
        _extracted: Not used by this function.

    Yields:
        tuple: ``(root, data)``. ``data`` is a dict with the keys
        ``"num_images_in_folds"`` (number of images per fold),
        ``"num_images_in_split"`` (number of images for the ``train``,
        ``unlabeled`` and ``test`` parts) and ``"archive"`` (the first value
        returned by ``make_tar``).
    """
    CLASS_NAMES = ("airplane", "bird")
    ARCHIVE_NAME = "stl10_binary"
    NUM_FOLDS = 10

    def mock_target(attr, partial="torchvision.datasets.stl10.STL10"):
        # Dotted path of the class attribute that gets mocked.
        return f"{partial}.{attr}"

    def make_binary_file(num_elements, root, name):
        # Write ``num_elements`` zero bytes and return ``(name, md5)``.
        file = os.path.join(root, name)
        np.zeros(num_elements, dtype=np.uint8).tofile(file)
        return name, compute_md5(file)

    def make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
        # One uint8 element per pixel and channel.
        return make_binary_file(num_images * num_channels * height * width, root, name)

    def make_label_file(num_images, root, name):
        # One uint8 element per image.
        return make_binary_file(num_images, root, name)

    def make_class_names_file(root, name="class_names.txt"):
        with open(os.path.join(root, name), "w") as fh:
            # Use a dedicated loop variable so the ``name`` parameter is not
            # shadowed while the file is being written.
            for class_name in CLASS_NAMES:
                fh.write(f"{class_name}\n")

    def make_fold_indices_file(root):
        # Fold ``i`` (0-based) is written as one line of ``i + 1`` consecutive
        # indices, so the fold sizes are ``1, 2, ..., NUM_FOLDS``.
        offset = 0
        with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
            for fold in range(NUM_FOLDS):
                line = " ".join(str(idx) for idx in range(offset, offset + fold + 1))
                fh.write(f"{line}\n")
                offset += fold + 1

        return tuple(range(1, NUM_FOLDS + 1))

    def make_train_files(stack, root, num_unlabeled_images=1):
        num_images_in_fold = make_fold_indices_file(root)
        num_train_images = sum(num_images_in_fold)

        train_list = [
            list(make_image_file(num_train_images, root, "train_X.bin")),
            list(make_label_file(num_train_images, root, "train_y.bin")),
            # Honour ``num_unlabeled_images`` so the file content matches the
            # count reported below (previously hard-coded to a single image).
            list(make_image_file(num_unlabeled_images, root, "unlabeled_X.bin")),
        ]
        mock_class_attribute(stack, target=mock_target("train_list"), new=train_list)

        return num_images_in_fold, dict(train=num_train_images, unlabeled=num_unlabeled_images)

    def make_test_files(stack, root, num_images=2):
        test_list = [
            list(make_image_file(num_images, root, "test_X.bin")),
            list(make_label_file(num_images, root, "test_y.bin")),
        ]
        mock_class_attribute(stack, target=mock_target("test_list"), new=test_list)

        return dict(test=num_images)

    def make_archive(stack, root, name):
        archive, md5 = make_tar(root, name, name, compression="gz")
        mock_class_attribute(stack, target=mock_target("tgz_md5"), new=md5)
        return archive

    with contextlib.ExitStack() as stack, get_tmp_dir() as root:
        archive_folder = os.path.join(root, ARCHIVE_NAME)
        os.mkdir(archive_folder)

        num_images_in_folds, num_images_in_split = make_train_files(stack, archive_folder)
        num_images_in_split.update(make_test_files(stack, archive_folder))

        make_class_names_file(archive_folder)

        archive = make_archive(stack, root, ARCHIVE_NAME)

        # Only the archive should remain in the root directory.
        dir_util.remove_tree(archive_folder)
        data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive)

        yield root, data
125 changes: 69 additions & 56 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import places365_root, widerface_root, stl10_root
from fakedata_generation import places365_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
Expand Down Expand Up @@ -141,76 +141,89 @@ def test_places365_repr_smoke(self):
self.assertIsInstance(repr(dataset), str)


class STL10Tester(DatasetTestcase):
@contextlib.contextmanager
def mocked_root(self):
    """Yield the ``(root, data)`` pair produced by ``stl10_root``."""
    with stl10_root() as (fake_root, fake_data):
        yield fake_root, fake_data
class STL10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.STL10
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train", "test", "unlabeled", "train+unlabeled"))

@contextlib.contextmanager
def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
    """Yield ``(dataset, data)`` for an STL10 dataset built on the mocked root.

    Args:
        pre_extract: If true, the archive named by ``data["archive"]`` is
            extracted before the dataset is instantiated.
        download: Forwarded to ``torchvision.datasets.STL10``.
        **kwargs: Further keyword arguments forwarded to the dataset.
    """
    with self.mocked_root() as (root, data):
        if pre_extract:
            archive_path = os.path.join(root, data["archive"])
            utils.extract_archive(archive_path)
        stl10 = torchvision.datasets.STL10(root, download=download, **kwargs)
        yield stl10, data

def test_not_found(self):
    """Creating the dataset with ``download=False`` must raise ``RuntimeError``."""
    with self.assertRaises(RuntimeError), self.mocked_dataset(download=False):
        pass
@staticmethod
def _make_binary_file(num_elements, root, name):
    """Write ``num_elements`` zero-valued ``uint8`` elements to ``root/name``."""
    target_path = os.path.join(root, name)
    zeros = np.zeros(num_elements, dtype=np.uint8)
    zeros.tofile(target_path)

def test_splits(self):
    """Run the generic classification test for each split and split combination."""
    splits = ("train", "train+unlabeled", "unlabeled", "test")
    for split in splits:
        with self.mocked_dataset(split=split) as (dataset, data):
            # A combined split such as "train+unlabeled" holds the images of
            # all of its parts.
            images_per_part = data["num_images_in_split"]
            expected = sum(images_per_part[part] for part in split.split("+"))
            self.generic_classification_dataset_test(dataset, num_images=expected)
@staticmethod
def _make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
    """Write a zero-filled binary file sized for ``num_images`` images.

    The file holds one element per image, channel and pixel.
    """
    num_elements = num_images * num_channels * height * width
    STL10TestCase._make_binary_file(num_elements, root, name)

def test_folds(self):
    """Each training fold must contain the number of images recorded for it."""
    for fold_idx in range(10):
        with self.mocked_dataset(split="train", folds=fold_idx) as (dataset, data):
            expected = data["num_images_in_folds"][fold_idx]
            self.assertEqual(len(dataset), expected)
@staticmethod
def _make_label_file(num_images, root, name):
    """Write a zero-filled binary file with one element per image."""
    STL10TestCase._make_binary_file(num_elements=num_images, root=root, name=name)

def test_invalid_folds1(self):
    """``folds=10`` must be rejected with ``ValueError``."""
    with self.assertRaises(ValueError), self.mocked_dataset(folds=10):
        pass
@staticmethod
def _make_class_names_file(root, name="class_names.txt"):
    """Write the two fake class names, one per line, to ``root/name``."""
    class_names = ("airplane", "bird")
    with open(os.path.join(root, name), "w") as fh:
        fh.writelines(f"{cname}\n" for cname in class_names)

def test_invalid_folds2(self):
    """A string value for ``folds`` must be rejected with ``ValueError``."""
    with self.assertRaises(ValueError), self.mocked_dataset(folds="0"):
        pass
# Write "fold_indices.txt" into ``root`` and return the number of indices per fold.
# Fold ``i`` (0-based) is written as one line holding ``i + 1`` consecutive,
# space-separated indices, so the returned tuple is ``(1, 2, ..., num_folds)``.
@staticmethod
def _make_fold_indices_file(root):
num_folds = 10
offset = 0
with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
for fold in range(num_folds):
line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
fh.write(f"{line}\n")
# Advance past the indices just written so the next fold starts after them.
offset += fold + 1

# NOTE(review): the ``test_transforms`` / ``transform`` lines below look like
# removed-side lines of the old test class that the diff rendering interleaved
# with ``_make_fold_indices_file`` — confirm against the raw diff.
def test_transforms(self):
expected_image = "image"
expected_target = "target"
return tuple(range(1, num_folds + 1))

def transform(image):
return expected_image
# Create the fold indices file plus the train image/label files and the
# unlabeled image file in ``root``; return the image counts for the "train"
# and "unlabeled" parts.
# NOTE(review): "unlabeled_X.bin" is written with a hard-coded 1 image while the
# returned dict reports ``num_unlabeled_images``; the two only agree for the
# default value.
# NOTE(review): the ``target_transform`` / ``mocked_dataset`` / ``assertEqual``
# lines look like removed-side lines of the old ``test_transforms`` that the
# diff rendering interleaved here — confirm against the raw diff.
@staticmethod
def _make_train_files(root, num_unlabeled_images=1):
num_images_in_fold = STL10TestCase._make_fold_indices_file(root)
# The train split holds the images of all folds.
num_train_images = sum(num_images_in_fold)

def target_transform(target):
return expected_target
STL10TestCase._make_image_file(num_train_images, root, "train_X.bin")
STL10TestCase._make_label_file(num_train_images, root, "train_y.bin")
STL10TestCase._make_image_file(1, root, "unlabeled_X.bin")

with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _):
actual_image, actual_target = dataset[0]
return dict(train=num_train_images, unlabeled=num_unlabeled_images)

self.assertEqual(actual_image, expected_image)
self.assertEqual(actual_target, expected_target)
@staticmethod
def _make_test_files(root, num_images=2):
    """Create the test image and label files in ``root``.

    Returns:
        dict: ``{"test": num_images}``.
    """
    STL10TestCase._make_image_file(num_images, root, "test_X.bin")
    STL10TestCase._make_label_file(num_images, root, "test_y.bin")
    return {"test": num_images}

def inject_fake_data(self, tmpdir, config):
    """Create the fake STL10 files below ``tmpdir`` and return the expected size.

    Args:
        tmpdir: Directory in which the ``stl10_binary`` folder is created.
        config: Mapping whose ``"split"`` entry names one split or several
            splits joined by ``"+"``.

    Returns:
        int: Total number of images over all parts of the configured split.
    """
    root_folder = os.path.join(tmpdir, "stl10_binary")
    os.mkdir(root_folder)

    counts = self._make_train_files(root_folder)
    counts.update(self._make_test_files(root_folder))
    self._make_class_names_file(root_folder)

    total = 0
    for part in config["split"].split("+"):
        total += counts[part]
    return total

def test_folds(self):
    """Training fold ``i`` (0-based) must contain ``i + 1`` images."""
    for fold_idx in range(10):
        expected_len = fold_idx + 1
        with self.create_dataset(split="train", folds=fold_idx) as (dataset, _):
            self.assertEqual(len(dataset), expected_len)

# Check that every sample of the "unlabeled" split carries the target -1.
# NOTE(review): the paired ``with`` and ``assertTrue`` lines look like the old
# and new side of the same diff lines (``mocked_dataset`` -> ``create_dataset``,
# list -> generator) rendered back to back — confirm against the raw diff.
def test_unlabeled(self):
with self.mocked_dataset(split="unlabeled") as (dataset, _):
with self.create_dataset(split="unlabeled") as (dataset, _):
labels = [dataset[idx][1] for idx in range(len(dataset))]
self.assertTrue(all([label == -1 for label in labels]))
self.assertTrue(all(label == -1 for label in labels))

@unittest.mock.patch("torchvision.datasets.stl10.download_and_extract_archive")
def test_download_preexisting(self, mock):
    """With a pre-extracted archive, the patched download helper must not be called."""
    with self.mocked_dataset(pre_extract=True) as (_dataset, _data):
        mock.assert_not_called()
def test_invalid_folds1(self):
    """``folds=10`` must be rejected with ``ValueError``."""
    with self.assertRaises(ValueError), self.create_dataset(folds=10):
        pass

def test_repr_smoke(self):
    """``repr`` of the dataset must return a string."""
    with self.mocked_dataset() as (dataset, _):
        representation = repr(dataset)
        self.assertIsInstance(representation, str)
def test_invalid_folds2(self):
    """A string value for ``folds`` must be rejected with ``ValueError``."""
    with self.assertRaises(ValueError), self.create_dataset(folds="0"):
        pass


class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down