Skip to content

Commit 3c81d47

Browse files
Philip Meierfmassa
authored andcommitted
Add a generic test for the datasets (#1015)
* added a generic test for the datasets * addressed requested changes - renamed generic*() to generic_classification*() - moved function inside Tester - test class_to_idx attribute outside of generic_classification*()
1 parent 250bac8 commit 3c81d47

File tree

2 files changed

+23
-36
lines changed

2 files changed

+23
-36
lines changed

test/fakedata_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _make_image_file(filename, num_images):
2929
f.write(img.numpy().tobytes())
3030

3131
def _make_label_file(filename, num_images):
32-
labels = torch.randint(0, 10, size=(num_images,), dtype=torch.uint8)
32+
labels = torch.zeros((num_images,), dtype=torch.uint8)
3333
with open(filename, "wb") as f:
3434
f.write(_encode(2049)) # magic header
3535
f.write(_encode(num_images))

test/test_datasets.py

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010

1111

1212
class Tester(unittest.TestCase):
13+
def generic_classification_dataset_test(self, dataset, num_images=1):
14+
self.assertEqual(len(dataset), num_images)
15+
img, target = dataset[0]
16+
self.assertTrue(isinstance(img, PIL.Image.Image))
17+
self.assertTrue(isinstance(target, int))
18+
1319
def test_imagefolder(self):
1420
# TODO: create the fake data on-the-fly
1521
FAKEDATA_DIR = get_file_path_2(
@@ -64,47 +70,36 @@ def test_mnist(self, mock_download_extract):
6470
num_examples = 30
6571
with mnist_root(num_examples, "MNIST") as root:
6672
dataset = torchvision.datasets.MNIST(root, download=True)
67-
self.assertEqual(len(dataset), num_examples)
73+
self.generic_classification_dataset_test(dataset, num_images=num_examples)
6874
img, target = dataset[0]
69-
self.assertTrue(isinstance(img, PIL.Image.Image))
70-
self.assertTrue(isinstance(target, int))
75+
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
7176

7277
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
7378
def test_kmnist(self, mock_download_extract):
7479
num_examples = 30
7580
with mnist_root(num_examples, "KMNIST") as root:
7681
dataset = torchvision.datasets.KMNIST(root, download=True)
82+
self.generic_classification_dataset_test(dataset, num_images=num_examples)
7783
img, target = dataset[0]
78-
self.assertEqual(len(dataset), num_examples)
79-
self.assertTrue(isinstance(img, PIL.Image.Image))
80-
self.assertTrue(isinstance(target, int))
84+
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
8185

8286
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
8387
def test_fashionmnist(self, mock_download_extract):
8488
num_examples = 30
8589
with mnist_root(num_examples, "FashionMNIST") as root:
8690
dataset = torchvision.datasets.FashionMNIST(root, download=True)
91+
self.generic_classification_dataset_test(dataset, num_images=num_examples)
8792
img, target = dataset[0]
88-
self.assertEqual(len(dataset), num_examples)
89-
self.assertTrue(isinstance(img, PIL.Image.Image))
90-
self.assertTrue(isinstance(target, int))
93+
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
9194

9295
@mock.patch('torchvision.datasets.utils.download_url')
9396
def test_imagenet(self, mock_download):
9497
with imagenet_root() as root:
9598
dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
96-
self.assertEqual(len(dataset), 1)
97-
img, target = dataset[0]
98-
self.assertTrue(isinstance(img, PIL.Image.Image))
99-
self.assertTrue(isinstance(target, int))
100-
self.assertEqual(dataset.class_to_idx['fakedata'], target)
99+
self.generic_classification_dataset_test(dataset)
101100

102101
dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
103-
self.assertEqual(len(dataset), 1)
104-
img, target = dataset[0]
105-
self.assertTrue(isinstance(img, PIL.Image.Image))
106-
self.assertTrue(isinstance(target, int))
107-
self.assertEqual(dataset.class_to_idx['fakedata'], target)
102+
self.generic_classification_dataset_test(dataset)
108103

109104
@mock.patch('torchvision.datasets.cifar.check_integrity')
110105
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
@@ -113,18 +108,14 @@ def test_cifar10(self, mock_ext_check, mock_int_check):
113108
mock_int_check.return_value = True
114109
with cifar_root('CIFAR10') as root:
115110
dataset = torchvision.datasets.CIFAR10(root, train=True, download=True)
116-
self.assertEqual(len(dataset), 5)
111+
self.generic_classification_dataset_test(dataset, num_images=5)
117112
img, target = dataset[0]
118-
self.assertTrue(isinstance(img, PIL.Image.Image))
119-
self.assertTrue(isinstance(target, int))
120-
self.assertEqual(dataset.class_to_idx['fakedata'], target)
113+
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
121114

122115
dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
123-
self.assertEqual(len(dataset), 1)
116+
self.generic_classification_dataset_test(dataset)
124117
img, target = dataset[0]
125-
self.assertTrue(isinstance(img, PIL.Image.Image))
126-
self.assertTrue(isinstance(target, int))
127-
self.assertEqual(dataset.class_to_idx['fakedata'], target)
118+
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
128119

129120
@mock.patch('torchvision.datasets.cifar.check_integrity')
130121
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
@@ -133,18 +124,14 @@ def test_cifar100(self, mock_ext_check, mock_int_check):
133124
mock_int_check.return_value = True
134125
with cifar_root('CIFAR100') as root:
135126
dataset = torchvision.datasets.CIFAR100(root, train=True, download=True)
136-
self.assertEqual(len(dataset), 1)
127+
self.generic_classification_dataset_test(dataset)
137128
img, target = dataset[0]
138-
self.assertTrue(isinstance(img, PIL.Image.Image))
139-
self.assertTrue(isinstance(target, int))
140-
self.assertEqual(dataset.class_to_idx['fakedata'], target)
129+
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
141130

142131
dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
143-
self.assertEqual(len(dataset), 1)
132+
self.generic_classification_dataset_test(dataset)
144133
img, target = dataset[0]
145-
self.assertTrue(isinstance(img, PIL.Image.Image))
146-
self.assertTrue(isinstance(target, int))
147-
self.assertEqual(dataset.class_to_idx['fakedata'], target)
134+
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
148135

149136

150137
if __name__ == '__main__':

0 commit comments

Comments
 (0)