 import sys
 import os
 import unittest
-from unittest import mock
 import numpy as np
 import PIL
 from PIL import Image
 from torch._utils_internal import get_file_path_2
 import torchvision
 from torchvision.datasets import utils
 from common_utils import get_tmp_dir
-from fakedata_generation import svhn_root, places365_root, widerface_root, stl10_root
+from fakedata_generation import places365_root, widerface_root, stl10_root
 import xml.etree.ElementTree as ET
 from urllib.request import Request, urlopen
 import itertools
@@ -57,20 +56,6 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1):
 
 
 class Tester(DatasetTestcase):
-    @mock.patch('torchvision.datasets.SVHN._check_integrity')
-    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
-    def test_svhn(self, mock_check):
-        mock_check.return_value = True
-        with svhn_root() as root:
-            dataset = torchvision.datasets.SVHN(root, split="train")
-            self.generic_classification_dataset_test(dataset, num_images=2)
-
-            dataset = torchvision.datasets.SVHN(root, split="test")
-            self.generic_classification_dataset_test(dataset, num_images=2)
-
-            dataset = torchvision.datasets.SVHN(root, split="extra")
-            self.generic_classification_dataset_test(dataset, num_images=2)
-
     def test_places365(self):
         for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
             with places365_root(split=split, small=small) as places365:
@@ -1737,5 +1722,27 @@ def inject_fake_data(self, tmpdir, config):
         return split_to_num_examples[config["train"]]
 
 
+class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.SVHN
+    REQUIRED_PACKAGES = ("scipy",)
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "extra"))
+
+    def inject_fake_data(self, tmpdir, config):
+        import scipy.io as sio
+
+        split = config["split"]
+        num_examples = {
+            "train": 2,
+            "test": 3,
+            "extra": 4,
+        }.get(split)
+
+        file = f"{split}_32x32.mat"
+        images = np.zeros((32, 32, 3, num_examples), dtype=np.uint8)
+        targets = np.zeros((num_examples,), dtype=np.uint8)
+        sio.savemat(os.path.join(tmpdir, file), {'X': images, 'y': targets})
+        return num_examples
+
+
 if __name__ == "__main__":
     unittest.main()
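
For reference, a minimal standalone sketch (not part of this commit) of the {split}_32x32.mat layout that inject_fake_data writes, assuming numpy and scipy are installed. The real SVHN cropped-digit files use the same 'X' (images) and 'y' (labels) keys, which is what torchvision.datasets.SVHN loads back via scipy.io.loadmat:

import os
import tempfile

import numpy as np
import scipy.io as sio

with tempfile.TemporaryDirectory() as root:
    num_examples = 2
    # SVHN packs all images into a single uint8 array of shape (32, 32, 3, N).
    images = np.zeros((32, 32, 3, num_examples), dtype=np.uint8)
    targets = np.zeros((num_examples,), dtype=np.uint8)
    sio.savemat(os.path.join(root, "train_32x32.mat"), {"X": images, "y": targets})

    loaded = sio.loadmat(os.path.join(root, "train_32x32.mat"))
    print(loaded["X"].shape)  # (32, 32, 3, 2)
    print(loaded["y"].shape)  # savemat stores the 1-D label vector as a 2-D array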