Skip to content

Commit a89da92

Browse files
authored
Ported STL10 dataset's tests to new test framework (#3665)
* Ported STL10 dataset's tests to new test framework * Added additional tests * Removed unused import * Made private methods static and other minor changes
1 parent a18b4af commit a89da92

File tree

2 files changed

+69
-135
lines changed

2 files changed

+69
-135
lines changed

test/fakedata_generation.py

Lines changed: 0 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -308,82 +308,3 @@ def make_images_archive(stack, root, split, small):
308308
clean_dir(root, ".tar$")
309309

310310
yield root, data
311-
312-
313-
@contextlib.contextmanager
314-
def stl10_root(_extracted=False):
315-
CLASS_NAMES = ("airplane", "bird")
316-
ARCHIVE_NAME = "stl10_binary"
317-
NUM_FOLDS = 10
318-
319-
def mock_target(attr, partial="torchvision.datasets.stl10.STL10"):
320-
return f"{partial}.{attr}"
321-
322-
def make_binary_file(num_elements, root, name):
323-
file = os.path.join(root, name)
324-
np.zeros(num_elements, dtype=np.uint8).tofile(file)
325-
return name, compute_md5(file)
326-
327-
def make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
328-
return make_binary_file(num_images * num_channels * height * width, root, name)
329-
330-
def make_label_file(num_images, root, name):
331-
return make_binary_file(num_images, root, name)
332-
333-
def make_class_names_file(root, name="class_names.txt"):
334-
with open(os.path.join(root, name), "w") as fh:
335-
for name in CLASS_NAMES:
336-
fh.write(f"{name}\n")
337-
338-
def make_fold_indices_file(root):
339-
offset = 0
340-
with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
341-
for fold in range(NUM_FOLDS):
342-
line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
343-
fh.write(f"{line}\n")
344-
offset += fold + 1
345-
346-
return tuple(range(1, NUM_FOLDS + 1))
347-
348-
def make_train_files(stack, root, num_unlabeled_images=1):
349-
num_images_in_fold = make_fold_indices_file(root)
350-
num_train_images = sum(num_images_in_fold)
351-
352-
train_list = [
353-
list(make_image_file(num_train_images, root, "train_X.bin")),
354-
list(make_label_file(num_train_images, root, "train_y.bin")),
355-
list(make_image_file(1, root, "unlabeled_X.bin"))
356-
]
357-
mock_class_attribute(stack, target=mock_target("train_list"), new=train_list)
358-
359-
return num_images_in_fold, dict(train=num_train_images, unlabeled=num_unlabeled_images)
360-
361-
def make_test_files(stack, root, num_images=2):
362-
test_list = [
363-
list(make_image_file(num_images, root, "test_X.bin")),
364-
list(make_label_file(num_images, root, "test_y.bin")),
365-
]
366-
mock_class_attribute(stack, target=mock_target("test_list"), new=test_list)
367-
368-
return dict(test=num_images)
369-
370-
def make_archive(stack, root, name):
371-
archive, md5 = make_tar(root, name, name, compression="gz")
372-
mock_class_attribute(stack, target=mock_target("tgz_md5"), new=md5)
373-
return archive
374-
375-
with contextlib.ExitStack() as stack, get_tmp_dir() as root:
376-
archive_folder = os.path.join(root, ARCHIVE_NAME)
377-
os.mkdir(archive_folder)
378-
379-
num_images_in_folds, num_images_in_split = make_train_files(stack, archive_folder)
380-
num_images_in_split.update(make_test_files(stack, archive_folder))
381-
382-
make_class_names_file(archive_folder)
383-
384-
archive = make_archive(stack, root, ARCHIVE_NAME)
385-
386-
dir_util.remove_tree(archive_folder)
387-
data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive)
388-
389-
yield root, data

test/test_datasets.py

Lines changed: 69 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torchvision
1010
from torchvision.datasets import utils
1111
from common_utils import get_tmp_dir
12-
from fakedata_generation import places365_root, widerface_root, stl10_root
12+
from fakedata_generation import places365_root
1313
import xml.etree.ElementTree as ET
1414
from urllib.request import Request, urlopen
1515
import itertools
@@ -141,76 +141,89 @@ def test_places365_repr_smoke(self):
141141
self.assertIsInstance(repr(dataset), str)
142142

143143

144-
class STL10Tester(DatasetTestcase):
145-
@contextlib.contextmanager
146-
def mocked_root(self):
147-
with stl10_root() as (root, data):
148-
yield root, data
144+
class STL10TestCase(datasets_utils.ImageDatasetTestCase):
145+
DATASET_CLASS = datasets.STL10
146+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
147+
split=("train", "test", "unlabeled", "train+unlabeled"))
149148

150-
@contextlib.contextmanager
151-
def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
152-
with self.mocked_root() as (root, data):
153-
if pre_extract:
154-
utils.extract_archive(os.path.join(root, data["archive"]))
155-
dataset = torchvision.datasets.STL10(root, download=download, **kwargs)
156-
yield dataset, data
157-
158-
def test_not_found(self):
159-
with self.assertRaises(RuntimeError):
160-
with self.mocked_dataset(download=False):
161-
pass
149+
@staticmethod
150+
def _make_binary_file(num_elements, root, name):
151+
file_name = os.path.join(root, name)
152+
np.zeros(num_elements, dtype=np.uint8).tofile(file_name)
162153

163-
def test_splits(self):
164-
for split in ('train', 'train+unlabeled', 'unlabeled', 'test'):
165-
with self.mocked_dataset(split=split) as (dataset, data):
166-
num_images = sum([data["num_images_in_split"][part] for part in split.split("+")])
167-
self.generic_classification_dataset_test(dataset, num_images=num_images)
154+
@staticmethod
155+
def _make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
156+
STL10TestCase._make_binary_file(num_images * num_channels * height * width, root, name)
168157

169-
def test_folds(self):
170-
for fold in range(10):
171-
with self.mocked_dataset(split="train", folds=fold) as (dataset, data):
172-
num_images = data["num_images_in_folds"][fold]
173-
self.assertEqual(len(dataset), num_images)
158+
@staticmethod
159+
def _make_label_file(num_images, root, name):
160+
STL10TestCase._make_binary_file(num_images, root, name)
174161

175-
def test_invalid_folds1(self):
176-
with self.assertRaises(ValueError):
177-
with self.mocked_dataset(folds=10):
178-
pass
162+
@staticmethod
163+
def _make_class_names_file(root, name="class_names.txt"):
164+
with open(os.path.join(root, name), "w") as fh:
165+
for cname in ("airplane", "bird"):
166+
fh.write(f"{cname}\n")
179167

180-
def test_invalid_folds2(self):
181-
with self.assertRaises(ValueError):
182-
with self.mocked_dataset(folds="0"):
183-
pass
168+
@staticmethod
169+
def _make_fold_indices_file(root):
170+
num_folds = 10
171+
offset = 0
172+
with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
173+
for fold in range(num_folds):
174+
line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
175+
fh.write(f"{line}\n")
176+
offset += fold + 1
184177

185-
def test_transforms(self):
186-
expected_image = "image"
187-
expected_target = "target"
178+
return tuple(range(1, num_folds + 1))
188179

189-
def transform(image):
190-
return expected_image
180+
@staticmethod
181+
def _make_train_files(root, num_unlabeled_images=1):
182+
num_images_in_fold = STL10TestCase._make_fold_indices_file(root)
183+
num_train_images = sum(num_images_in_fold)
191184

192-
def target_transform(target):
193-
return expected_target
185+
STL10TestCase._make_image_file(num_train_images, root, "train_X.bin")
186+
STL10TestCase._make_label_file(num_train_images, root, "train_y.bin")
187+
STL10TestCase._make_image_file(1, root, "unlabeled_X.bin")
194188

195-
with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _):
196-
actual_image, actual_target = dataset[0]
189+
return dict(train=num_train_images, unlabeled=num_unlabeled_images)
197190

198-
self.assertEqual(actual_image, expected_image)
199-
self.assertEqual(actual_target, expected_target)
191+
@staticmethod
192+
def _make_test_files(root, num_images=2):
193+
STL10TestCase._make_image_file(num_images, root, "test_X.bin")
194+
STL10TestCase._make_label_file(num_images, root, "test_y.bin")
195+
196+
return dict(test=num_images)
197+
198+
def inject_fake_data(self, tmpdir, config):
199+
root_folder = os.path.join(tmpdir, "stl10_binary")
200+
os.mkdir(root_folder)
201+
202+
num_images_in_split = self._make_train_files(root_folder)
203+
num_images_in_split.update(self._make_test_files(root_folder))
204+
self._make_class_names_file(root_folder)
205+
206+
return sum(num_images_in_split[part] for part in config["split"].split("+"))
207+
208+
def test_folds(self):
209+
for fold in range(10):
210+
with self.create_dataset(split="train", folds=fold) as (dataset, _):
211+
self.assertEqual(len(dataset), fold + 1)
200212

201213
def test_unlabeled(self):
202-
with self.mocked_dataset(split="unlabeled") as (dataset, _):
214+
with self.create_dataset(split="unlabeled") as (dataset, _):
203215
labels = [dataset[idx][1] for idx in range(len(dataset))]
204-
self.assertTrue(all([label == -1 for label in labels]))
216+
self.assertTrue(all(label == -1 for label in labels))
205217

206-
@unittest.mock.patch("torchvision.datasets.stl10.download_and_extract_archive")
207-
def test_download_preexisting(self, mock):
208-
with self.mocked_dataset(pre_extract=True) as (dataset, data):
209-
mock.assert_not_called()
218+
def test_invalid_folds1(self):
219+
with self.assertRaises(ValueError):
220+
with self.create_dataset(folds=10):
221+
pass
210222

211-
def test_repr_smoke(self):
212-
with self.mocked_dataset() as (dataset, _):
213-
self.assertIsInstance(repr(dataset), str)
223+
def test_invalid_folds2(self):
224+
with self.assertRaises(ValueError):
225+
with self.create_dataset(folds="0"):
226+
pass
214227

215228

216229
class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):

0 commit comments

Comments
 (0)