@@ -56,18 +56,14 @@ def test_info(self, name):
56
56
57
57
@parametrize_dataset_mocks (DATASET_MOCKS )
58
58
def test_smoke (self , dataset_mock , config ):
59
- dataset_mock .prepare (config )
60
-
61
- dataset = datasets .load (dataset_mock .name , ** config )
59
+ dataset , _ = dataset_mock .load (config )
62
60
63
61
if not isinstance (dataset , datasets .utils .Dataset ):
64
62
raise AssertionError (f"Loading the dataset should return an Dataset, but got { type (dataset )} instead." )
65
63
66
64
@parametrize_dataset_mocks (DATASET_MOCKS )
67
65
def test_sample (self , dataset_mock , config ):
68
- dataset_mock .prepare (config )
69
-
70
- dataset = datasets .load (dataset_mock .name , ** config )
66
+ dataset , _ = dataset_mock .load (config )
71
67
72
68
try :
73
69
sample = next (iter (dataset ))
@@ -84,17 +80,13 @@ def test_sample(self, dataset_mock, config):
84
80
85
81
@parametrize_dataset_mocks (DATASET_MOCKS )
86
82
def test_num_samples (self , dataset_mock , config ):
87
- mock_info = dataset_mock .prepare (config )
88
-
89
- dataset = datasets .load (dataset_mock .name , ** config )
83
+ dataset , mock_info = dataset_mock .load (config )
90
84
91
85
assert len (list (dataset )) == mock_info ["num_samples" ]
92
86
93
87
@parametrize_dataset_mocks (DATASET_MOCKS )
94
88
def test_no_vanilla_tensors (self , dataset_mock , config ):
95
- dataset_mock .prepare (config )
96
-
97
- dataset = datasets .load (dataset_mock .name , ** config )
89
+ dataset , _ = dataset_mock .load (config )
98
90
99
91
vanilla_tensors = {key for key , value in next (iter (dataset )).items () if type (value ) is torch .Tensor }
100
92
if vanilla_tensors :
@@ -105,24 +97,20 @@ def test_no_vanilla_tensors(self, dataset_mock, config):
105
97
106
98
@parametrize_dataset_mocks (DATASET_MOCKS )
107
99
def test_transformable (self , dataset_mock , config ):
108
- dataset_mock .prepare (config )
109
-
110
- dataset = datasets .load (dataset_mock .name , ** config )
100
+ dataset , _ = dataset_mock .load (config )
111
101
112
102
next (iter (dataset .map (transforms .Identity ())))
113
103
114
104
@pytest .mark .parametrize ("only_datapipe" , [False , True ])
115
105
@parametrize_dataset_mocks (DATASET_MOCKS )
116
106
def test_traversable (self , dataset_mock , config , only_datapipe ):
117
- dataset_mock .prepare (config )
118
- dataset = datasets .load (dataset_mock .name , ** config )
107
+ dataset , _ = dataset_mock .load (config )
119
108
120
109
traverse (dataset , only_datapipe = only_datapipe )
121
110
122
111
@parametrize_dataset_mocks (DATASET_MOCKS )
123
112
def test_serializable (self , dataset_mock , config ):
124
- dataset_mock .prepare (config )
125
- dataset = datasets .load (dataset_mock .name , ** config )
113
+ dataset , _ = dataset_mock .load (config )
126
114
127
115
pickle .dumps (dataset )
128
116
@@ -135,8 +123,7 @@ def _collate_fn(self, batch):
135
123
@pytest .mark .parametrize ("num_workers" , [0 , 1 ])
136
124
@parametrize_dataset_mocks (DATASET_MOCKS )
137
125
def test_data_loader (self , dataset_mock , config , num_workers ):
138
- dataset_mock .prepare (config )
139
- dataset = datasets .load (dataset_mock .name , ** config )
126
+ dataset , _ = dataset_mock .load (config )
140
127
141
128
dl = DataLoader (
142
129
dataset ,
@@ -153,17 +140,15 @@ def test_data_loader(self, dataset_mock, config, num_workers):
153
140
@parametrize_dataset_mocks (DATASET_MOCKS )
154
141
@pytest .mark .parametrize ("annotation_dp_type" , (Shuffler , ShardingFilter ))
155
142
def test_has_annotations (self , dataset_mock , config , annotation_dp_type ):
156
-
157
- dataset_mock .prepare (config )
158
- dataset = datasets .load (dataset_mock .name , ** config )
143
+ dataset , _ = dataset_mock .load (config )
159
144
160
145
if not any (isinstance (dp , annotation_dp_type ) for dp in extract_datapipes (dataset )):
161
146
raise AssertionError (f"The dataset doesn't contain a { annotation_dp_type .__name__ } () datapipe." )
162
147
163
148
@parametrize_dataset_mocks (DATASET_MOCKS )
164
149
def test_save_load (self , dataset_mock , config ):
165
- dataset_mock .prepare (config )
166
- dataset = datasets . load ( dataset_mock . name , ** config )
150
+ dataset , _ = dataset_mock .load (config )
151
+
167
152
sample = next (iter (dataset ))
168
153
169
154
with io .BytesIO () as buffer :
@@ -173,8 +158,7 @@ def test_save_load(self, dataset_mock, config):
173
158
174
159
@parametrize_dataset_mocks (DATASET_MOCKS )
175
160
def test_infinite_buffer_size (self , dataset_mock , config ):
176
- dataset_mock .prepare (config )
177
- dataset = datasets .load (dataset_mock .name , ** config )
161
+ dataset , _ = dataset_mock .load (config )
178
162
179
163
for dp in extract_datapipes (dataset ):
180
164
if hasattr (dp , "buffer_size" ):
@@ -184,18 +168,15 @@ def test_infinite_buffer_size(self, dataset_mock, config):
184
168
185
169
@parametrize_dataset_mocks (DATASET_MOCKS )
186
170
def test_has_length (self , dataset_mock , config ):
187
- dataset_mock .prepare (config )
188
- dataset = datasets .load (dataset_mock .name , ** config )
171
+ dataset , _ = dataset_mock .load (config )
189
172
190
173
assert len (dataset ) > 0
191
174
192
175
193
176
@parametrize_dataset_mocks (DATASET_MOCKS ["qmnist" ])
194
177
class TestQMNIST :
195
178
def test_extra_label (self , dataset_mock , config ):
196
- dataset_mock .prepare (config )
197
-
198
- dataset = datasets .load (dataset_mock .name , ** config )
179
+ dataset , _ = dataset_mock .load (config )
199
180
200
181
sample = next (iter (dataset ))
201
182
for key , type in (
@@ -218,9 +199,7 @@ def test_label_matches_path(self, dataset_mock, config):
218
199
if config ["split" ] != "train" :
219
200
return
220
201
221
- dataset_mock .prepare (config )
222
-
223
- dataset = datasets .load (dataset_mock .name , ** config )
202
+ dataset , _ = dataset_mock .load (config )
224
203
225
204
for sample in dataset :
226
205
label_from_path = int (Path (sample ["path" ]).parent .name )
@@ -230,9 +209,7 @@ def test_label_matches_path(self, dataset_mock, config):
230
209
@parametrize_dataset_mocks (DATASET_MOCKS ["usps" ])
231
210
class TestUSPS :
232
211
def test_sample_content (self , dataset_mock , config ):
233
- dataset_mock .prepare (config )
234
-
235
- dataset = datasets .load (dataset_mock .name , ** config )
212
+ dataset , _ = dataset_mock .load (config )
236
213
237
214
for sample in dataset :
238
215
assert "image" in sample
0 commit comments