10
10
11
11
12
12
class Tester (unittest .TestCase ):
13
+ def generic_classification_dataset_test (self , dataset , num_images = 1 ):
14
+ self .assertEqual (len (dataset ), num_images )
15
+ img , target = dataset [0 ]
16
+ self .assertTrue (isinstance (img , PIL .Image .Image ))
17
+ self .assertTrue (isinstance (target , int ))
18
+
13
19
def test_imagefolder (self ):
14
20
# TODO: create the fake data on-the-fly
15
21
FAKEDATA_DIR = get_file_path_2 (
@@ -64,47 +70,36 @@ def test_mnist(self, mock_download_extract):
64
70
num_examples = 30
65
71
with mnist_root (num_examples , "MNIST" ) as root :
66
72
dataset = torchvision .datasets .MNIST (root , download = True )
67
- self .assertEqual ( len ( dataset ), num_examples )
73
+ self .generic_classification_dataset_test ( dataset , num_images = num_examples )
68
74
img , target = dataset [0 ]
69
- self .assertTrue (isinstance (img , PIL .Image .Image ))
70
- self .assertTrue (isinstance (target , int ))
75
+ self .assertEqual (dataset .class_to_idx [dataset .classes [0 ]], target )
71
76
72
77
@mock .patch ('torchvision.datasets.mnist.download_and_extract_archive' )
73
78
def test_kmnist (self , mock_download_extract ):
74
79
num_examples = 30
75
80
with mnist_root (num_examples , "KMNIST" ) as root :
76
81
dataset = torchvision .datasets .KMNIST (root , download = True )
82
+ self .generic_classification_dataset_test (dataset , num_images = num_examples )
77
83
img , target = dataset [0 ]
78
- self .assertEqual (len (dataset ), num_examples )
79
- self .assertTrue (isinstance (img , PIL .Image .Image ))
80
- self .assertTrue (isinstance (target , int ))
84
+ self .assertEqual (dataset .class_to_idx [dataset .classes [0 ]], target )
81
85
82
86
@mock .patch ('torchvision.datasets.mnist.download_and_extract_archive' )
83
87
def test_fashionmnist (self , mock_download_extract ):
84
88
num_examples = 30
85
89
with mnist_root (num_examples , "FashionMNIST" ) as root :
86
90
dataset = torchvision .datasets .FashionMNIST (root , download = True )
91
+ self .generic_classification_dataset_test (dataset , num_images = num_examples )
87
92
img , target = dataset [0 ]
88
- self .assertEqual (len (dataset ), num_examples )
89
- self .assertTrue (isinstance (img , PIL .Image .Image ))
90
- self .assertTrue (isinstance (target , int ))
93
+ self .assertEqual (dataset .class_to_idx [dataset .classes [0 ]], target )
91
94
92
95
@mock .patch ('torchvision.datasets.utils.download_url' )
93
96
def test_imagenet (self , mock_download ):
94
97
with imagenet_root () as root :
95
98
dataset = torchvision .datasets .ImageNet (root , split = 'train' , download = True )
96
- self .assertEqual (len (dataset ), 1 )
97
- img , target = dataset [0 ]
98
- self .assertTrue (isinstance (img , PIL .Image .Image ))
99
- self .assertTrue (isinstance (target , int ))
100
- self .assertEqual (dataset .class_to_idx ['fakedata' ], target )
99
+ self .generic_classification_dataset_test (dataset )
101
100
102
101
dataset = torchvision .datasets .ImageNet (root , split = 'val' , download = True )
103
- self .assertEqual (len (dataset ), 1 )
104
- img , target = dataset [0 ]
105
- self .assertTrue (isinstance (img , PIL .Image .Image ))
106
- self .assertTrue (isinstance (target , int ))
107
- self .assertEqual (dataset .class_to_idx ['fakedata' ], target )
102
+ self .generic_classification_dataset_test (dataset )
108
103
109
104
@mock .patch ('torchvision.datasets.cifar.check_integrity' )
110
105
@mock .patch ('torchvision.datasets.cifar.CIFAR10._check_integrity' )
@@ -113,18 +108,14 @@ def test_cifar10(self, mock_ext_check, mock_int_check):
113
108
mock_int_check .return_value = True
114
109
with cifar_root ('CIFAR10' ) as root :
115
110
dataset = torchvision .datasets .CIFAR10 (root , train = True , download = True )
116
- self .assertEqual ( len ( dataset ), 5 )
111
+ self .generic_classification_dataset_test ( dataset , num_images = 5 )
117
112
img , target = dataset [0 ]
118
- self .assertTrue (isinstance (img , PIL .Image .Image ))
119
- self .assertTrue (isinstance (target , int ))
120
- self .assertEqual (dataset .class_to_idx ['fakedata' ], target )
113
+ self .assertEqual (dataset .class_to_idx [dataset .classes [0 ]], target )
121
114
122
115
dataset = torchvision .datasets .CIFAR10 (root , train = False , download = True )
123
- self .assertEqual ( len ( dataset ), 1 )
116
+ self .generic_classification_dataset_test ( dataset )
124
117
img , target = dataset [0 ]
125
- self .assertTrue (isinstance (img , PIL .Image .Image ))
126
- self .assertTrue (isinstance (target , int ))
127
- self .assertEqual (dataset .class_to_idx ['fakedata' ], target )
118
+ self .assertEqual (dataset .class_to_idx [dataset .classes [0 ]], target )
128
119
129
120
@mock .patch ('torchvision.datasets.cifar.check_integrity' )
130
121
@mock .patch ('torchvision.datasets.cifar.CIFAR10._check_integrity' )
@@ -133,18 +124,14 @@ def test_cifar100(self, mock_ext_check, mock_int_check):
133
124
mock_int_check .return_value = True
134
125
with cifar_root ('CIFAR100' ) as root :
135
126
dataset = torchvision .datasets .CIFAR100 (root , train = True , download = True )
136
- self .assertEqual ( len ( dataset ), 1 )
127
+ self .generic_classification_dataset_test ( dataset )
137
128
img , target = dataset [0 ]
138
- self .assertTrue (isinstance (img , PIL .Image .Image ))
139
- self .assertTrue (isinstance (target , int ))
140
- self .assertEqual (dataset .class_to_idx ['fakedata' ], target )
129
+ self .assertEqual (dataset .class_to_idx [dataset .classes [0 ]], target )
141
130
142
131
dataset = torchvision .datasets .CIFAR100 (root , train = False , download = True )
143
- self .assertEqual ( len ( dataset ), 1 )
132
+ self .generic_classification_dataset_test ( dataset )
144
133
img , target = dataset [0 ]
145
- self .assertTrue (isinstance (img , PIL .Image .Image ))
146
- self .assertTrue (isinstance (target , int ))
147
- self .assertEqual (dataset .class_to_idx ['fakedata' ], target )
134
+ self .assertEqual (dataset .class_to_idx [dataset .classes [0 ]], target )
148
135
149
136
150
137
if __name__ == '__main__' :
0 commit comments