Skip to content

Commit 512ea29

Browse files
authored
Ported SVHN dataset to new test framework (#3661)
* Ported SVHN dataset to new test framework * Fixed flake8 error and added REQUIRED_PACKAGES=scipy
1 parent 44460c9 commit 512ea29

File tree

2 files changed

+23
-33
lines changed

2 files changed

+23
-33
lines changed

test/fakedata_generation.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -210,23 +210,6 @@ def _make_annotations_archive(root):
210210
yield root
211211

212212

213-
@contextlib.contextmanager
214-
def svhn_root():
215-
import scipy.io as sio
216-
217-
def _make_mat(file):
218-
images = np.zeros((32, 32, 3, 2), dtype=np.uint8)
219-
targets = np.zeros((2,), dtype=np.uint8)
220-
sio.savemat(file, {'X': images, 'y': targets})
221-
222-
with get_tmp_dir() as root:
223-
_make_mat(os.path.join(root, "train_32x32.mat"))
224-
_make_mat(os.path.join(root, "test_32x32.mat"))
225-
_make_mat(os.path.join(root, "extra_32x32.mat"))
226-
227-
yield root
228-
229-
230213
@contextlib.contextmanager
231214
def places365_root(split="train-standard", small=False):
232215
VARIANTS = {

test/test_datasets.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
import sys
33
import os
44
import unittest
5-
from unittest import mock
65
import numpy as np
76
import PIL
87
from PIL import Image
98
from torch._utils_internal import get_file_path_2
109
import torchvision
1110
from torchvision.datasets import utils
1211
from common_utils import get_tmp_dir
13-
from fakedata_generation import svhn_root, places365_root, widerface_root, stl10_root
12+
from fakedata_generation import places365_root, widerface_root, stl10_root
1413
import xml.etree.ElementTree as ET
1514
from urllib.request import Request, urlopen
1615
import itertools
@@ -57,20 +56,6 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1):
5756

5857

5958
class Tester(DatasetTestcase):
60-
@mock.patch('torchvision.datasets.SVHN._check_integrity')
61-
@unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
62-
def test_svhn(self, mock_check):
63-
mock_check.return_value = True
64-
with svhn_root() as root:
65-
dataset = torchvision.datasets.SVHN(root, split="train")
66-
self.generic_classification_dataset_test(dataset, num_images=2)
67-
68-
dataset = torchvision.datasets.SVHN(root, split="test")
69-
self.generic_classification_dataset_test(dataset, num_images=2)
70-
71-
dataset = torchvision.datasets.SVHN(root, split="extra")
72-
self.generic_classification_dataset_test(dataset, num_images=2)
73-
7459
def test_places365(self):
7560
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
7661
with places365_root(split=split, small=small) as places365:
@@ -1737,5 +1722,27 @@ def inject_fake_data(self, tmpdir, config):
17371722
return split_to_num_examples[config["train"]]
17381723

17391724

1725+
class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
1726+
DATASET_CLASS = datasets.SVHN
1727+
REQUIRED_PACKAGES = ("scipy",)
1728+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "extra"))
1729+
1730+
def inject_fake_data(self, tmpdir, config):
1731+
import scipy.io as sio
1732+
1733+
split = config["split"]
1734+
num_examples = {
1735+
"train": 2,
1736+
"test": 3,
1737+
"extra": 4,
1738+
}.get(split)
1739+
1740+
file = f"{split}_32x32.mat"
1741+
images = np.zeros((32, 32, 3, num_examples), dtype=np.uint8)
1742+
targets = np.zeros((num_examples,), dtype=np.uint8)
1743+
sio.savemat(os.path.join(tmpdir, file), {'X': images, 'y': targets})
1744+
return num_examples
1745+
1746+
17401747
if __name__ == "__main__":
17411748
unittest.main()

0 commit comments

Comments
 (0)