|
9 | 9 | import torchvision
|
10 | 10 | from torchvision.datasets import utils
|
11 | 11 | from common_utils import get_tmp_dir
|
12 |
| -from fakedata_generation import places365_root, widerface_root, stl10_root |
| 12 | +from fakedata_generation import places365_root |
13 | 13 | import xml.etree.ElementTree as ET
|
14 | 14 | from urllib.request import Request, urlopen
|
15 | 15 | import itertools
|
@@ -141,76 +141,89 @@ def test_places365_repr_smoke(self):
|
141 | 141 | self.assertIsInstance(repr(dataset), str)
|
142 | 142 |
|
143 | 143 |
|
144 |
| -class STL10Tester(DatasetTestcase): |
145 |
| - @contextlib.contextmanager |
146 |
| - def mocked_root(self): |
147 |
| - with stl10_root() as (root, data): |
148 |
| - yield root, data |
class STL10TestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``datasets.STL10`` that run against small generated fake data.

    The fake data consists of zero-filled ``uint8`` binary files plus the two
    text files written by the helpers below. Every configuration in
    ``ADDITIONAL_CONFIGS`` selects one of the supported ``split`` values.
    """

    DATASET_CLASS = datasets.STL10
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train", "test", "unlabeled", "train+unlabeled"))

    @staticmethod
    def _make_binary_file(num_elements, root, name):
        """Write ``num_elements`` zero bytes (``uint8``) to ``root/name``."""
        file_name = os.path.join(root, name)
        np.zeros(num_elements, dtype=np.uint8).tofile(file_name)

    @staticmethod
    def _make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
        """Write a zero-filled image file holding ``num_images`` images.

        The file contains ``num_images * num_channels * height * width`` bytes.
        """
        STL10TestCase._make_binary_file(num_images * num_channels * height * width, root, name)

    @staticmethod
    def _make_label_file(num_images, root, name):
        """Write a zero-filled label file with one byte per image."""
        STL10TestCase._make_binary_file(num_images, root, name)

    @staticmethod
    def _make_class_names_file(root, name="class_names.txt"):
        """Write a two-entry class names file, one class name per line."""
        with open(os.path.join(root, name), "w") as fh:
            for cname in ("airplane", "bird"):
                fh.write(f"{cname}\n")

    @staticmethod
    def _make_fold_indices_file(root):
        """Write ``fold_indices.txt`` with ten folds of increasing size.

        Fold ``k`` (0-based) occupies one line and lists ``k + 1`` consecutive
        indices, continuing where the previous fold stopped, so no index is
        shared between folds.

        Returns:
            Tuple with the number of indices in each fold, i.e. ``(1, ..., 10)``.
        """
        num_folds = 10
        offset = 0
        with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
            for fold in range(num_folds):
                line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
                fh.write(f"{line}\n")
                offset += fold + 1

        return tuple(range(1, num_folds + 1))

    @staticmethod
    def _make_train_files(root, num_unlabeled_images=1):
        """Write the fold indices, train images/labels and unlabeled images.

        The number of train images equals the total number of indices written
        to the fold indices file, so every fold index refers to an image.

        Returns:
            Dict mapping ``"train"`` and ``"unlabeled"`` to their image counts.
        """
        num_images_in_fold = STL10TestCase._make_fold_indices_file(root)
        num_train_images = sum(num_images_in_fold)

        STL10TestCase._make_image_file(num_train_images, root, "train_X.bin")
        STL10TestCase._make_label_file(num_train_images, root, "train_y.bin")
        # Honour the requested count so that the file on disk always agrees
        # with the number reported in the returned dict.
        STL10TestCase._make_image_file(num_unlabeled_images, root, "unlabeled_X.bin")

        return dict(train=num_train_images, unlabeled=num_unlabeled_images)

    @staticmethod
    def _make_test_files(root, num_images=2):
        """Write the test images and labels.

        Returns:
            Dict mapping ``"test"`` to the number of images written.
        """
        STL10TestCase._make_image_file(num_images, root, "test_X.bin")
        STL10TestCase._make_label_file(num_images, root, "test_y.bin")

        return dict(test=num_images)

    def inject_fake_data(self, tmpdir, config):
        """Create the fake ``stl10_binary`` folder inside ``tmpdir``.

        Returns:
            The summed image count of all parts named in ``config["split"]``
            (parts are separated by ``"+"``).
        """
        root_folder = os.path.join(tmpdir, "stl10_binary")
        os.mkdir(root_folder)

        num_images_in_split = self._make_train_files(root_folder)
        num_images_in_split.update(self._make_test_files(root_folder))
        self._make_class_names_file(root_folder)

        return sum(num_images_in_split[part] for part in config["split"].split("+"))

    def test_folds(self):
        """Each fold ``k`` of the train split must contain ``k + 1`` images."""
        for fold in range(10):
            with self.create_dataset(split="train", folds=fold) as (dataset, _):
                # Mirrors the layout written by _make_fold_indices_file.
                self.assertEqual(len(dataset), fold + 1)

    def test_unlabeled(self):
        """Every sample of the unlabeled split must carry the label ``-1``."""
        with self.create_dataset(split="unlabeled") as (dataset, _):
            labels = [dataset[idx][1] for idx in range(len(dataset))]
            self.assertTrue(all(label == -1 for label in labels))

    def test_invalid_folds1(self):
        """An out-of-range fold number must raise ``ValueError``."""
        with self.assertRaises(ValueError):
            with self.create_dataset(folds=10):
                pass

    def test_invalid_folds2(self):
        """A fold given as a string must raise ``ValueError``."""
        with self.assertRaises(ValueError):
            with self.create_dataset(folds="0"):
                pass
214 | 227 |
|
215 | 228 |
|
216 | 229 | class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
|
|
0 commit comments