@@ -25,9 +25,10 @@ def extract_datapipes(dp):
25
25
return get_all_graph_pipes (traverse (dp , only_datapipe = True ))
26
26
27
27
28
- @pytest .fixture
28
+ @pytest .fixture ( autouse = True )
29
29
def test_home (mocker , tmp_path ):
30
30
mocker .patch ("torchvision.prototype.datasets._api.home" , return_value = str (tmp_path ))
31
+ mocker .patch ("torchvision.prototype.datasets.home" , return_value = str (tmp_path ))
31
32
yield tmp_path
32
33
33
34
@@ -54,17 +55,17 @@ def test_info(self, name):
54
55
raise AssertionError ("Info should be a dictionary with string keys." )
55
56
56
57
@parametrize_dataset_mocks (DATASET_MOCKS )
57
- def test_smoke (self , test_home , dataset_mock , config ):
58
- dataset_mock .prepare (test_home , config )
58
+ def test_smoke (self , dataset_mock , config ):
59
+ dataset_mock .prepare (config )
59
60
60
61
dataset = datasets .load (dataset_mock .name , ** config )
61
62
62
63
if not isinstance (dataset , datasets .utils .Dataset ):
63
64
raise AssertionError (f"Loading the dataset should return an Dataset, but got { type (dataset )} instead." )
64
65
65
66
@parametrize_dataset_mocks (DATASET_MOCKS )
66
- def test_sample (self , test_home , dataset_mock , config ):
67
- dataset_mock .prepare (test_home , config )
67
+ def test_sample (self , dataset_mock , config ):
68
+ dataset_mock .prepare (config )
68
69
69
70
dataset = datasets .load (dataset_mock .name , ** config )
70
71
@@ -82,16 +83,16 @@ def test_sample(self, test_home, dataset_mock, config):
82
83
raise AssertionError ("Sample dictionary is empty." )
83
84
84
85
@parametrize_dataset_mocks (DATASET_MOCKS )
85
- def test_num_samples (self , test_home , dataset_mock , config ):
86
- mock_info = dataset_mock .prepare (test_home , config )
86
+ def test_num_samples (self , dataset_mock , config ):
87
+ mock_info = dataset_mock .prepare (config )
87
88
88
89
dataset = datasets .load (dataset_mock .name , ** config )
89
90
90
91
assert len (list (dataset )) == mock_info ["num_samples" ]
91
92
92
93
@parametrize_dataset_mocks (DATASET_MOCKS )
93
- def test_no_vanilla_tensors (self , test_home , dataset_mock , config ):
94
- dataset_mock .prepare (test_home , config )
94
+ def test_no_vanilla_tensors (self , dataset_mock , config ):
95
+ dataset_mock .prepare (config )
95
96
96
97
dataset = datasets .load (dataset_mock .name , ** config )
97
98
@@ -103,24 +104,24 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
103
104
)
104
105
105
106
@parametrize_dataset_mocks (DATASET_MOCKS )
106
- def test_transformable (self , test_home , dataset_mock , config ):
107
- dataset_mock .prepare (test_home , config )
107
+ def test_transformable (self , dataset_mock , config ):
108
+ dataset_mock .prepare (config )
108
109
109
110
dataset = datasets .load (dataset_mock .name , ** config )
110
111
111
112
next (iter (dataset .map (transforms .Identity ())))
112
113
113
114
@pytest .mark .parametrize ("only_datapipe" , [False , True ])
114
115
@parametrize_dataset_mocks (DATASET_MOCKS )
115
- def test_traversable (self , test_home , dataset_mock , config , only_datapipe ):
116
- dataset_mock .prepare (test_home , config )
116
+ def test_traversable (self , dataset_mock , config , only_datapipe ):
117
+ dataset_mock .prepare (config )
117
118
dataset = datasets .load (dataset_mock .name , ** config )
118
119
119
120
traverse (dataset , only_datapipe = only_datapipe )
120
121
121
122
@parametrize_dataset_mocks (DATASET_MOCKS )
122
- def test_serializable (self , test_home , dataset_mock , config ):
123
- dataset_mock .prepare (test_home , config )
123
+ def test_serializable (self , dataset_mock , config ):
124
+ dataset_mock .prepare (config )
124
125
dataset = datasets .load (dataset_mock .name , ** config )
125
126
126
127
pickle .dumps (dataset )
@@ -133,8 +134,8 @@ def _collate_fn(self, batch):
133
134
134
135
@pytest .mark .parametrize ("num_workers" , [0 , 1 ])
135
136
@parametrize_dataset_mocks (DATASET_MOCKS )
136
- def test_data_loader (self , test_home , dataset_mock , config , num_workers ):
137
- dataset_mock .prepare (test_home , config )
137
+ def test_data_loader (self , dataset_mock , config , num_workers ):
138
+ dataset_mock .prepare (config )
138
139
dataset = datasets .load (dataset_mock .name , ** config )
139
140
140
141
dl = DataLoader (
@@ -151,17 +152,17 @@ def test_data_loader(self, test_home, dataset_mock, config, num_workers):
151
152
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
152
153
@parametrize_dataset_mocks (DATASET_MOCKS )
153
154
@pytest .mark .parametrize ("annotation_dp_type" , (Shuffler , ShardingFilter ))
154
- def test_has_annotations (self , test_home , dataset_mock , config , annotation_dp_type ):
155
+ def test_has_annotations (self , dataset_mock , config , annotation_dp_type ):
155
156
156
- dataset_mock .prepare (test_home , config )
157
+ dataset_mock .prepare (config )
157
158
dataset = datasets .load (dataset_mock .name , ** config )
158
159
159
160
if not any (isinstance (dp , annotation_dp_type ) for dp in extract_datapipes (dataset )):
160
161
raise AssertionError (f"The dataset doesn't contain a { annotation_dp_type .__name__ } () datapipe." )
161
162
162
163
@parametrize_dataset_mocks (DATASET_MOCKS )
163
- def test_save_load (self , test_home , dataset_mock , config ):
164
- dataset_mock .prepare (test_home , config )
164
+ def test_save_load (self , dataset_mock , config ):
165
+ dataset_mock .prepare (config )
165
166
dataset = datasets .load (dataset_mock .name , ** config )
166
167
sample = next (iter (dataset ))
167
168
@@ -171,8 +172,8 @@ def test_save_load(self, test_home, dataset_mock, config):
171
172
assert_samples_equal (torch .load (buffer ), sample )
172
173
173
174
@parametrize_dataset_mocks (DATASET_MOCKS )
174
- def test_infinite_buffer_size (self , test_home , dataset_mock , config ):
175
- dataset_mock .prepare (test_home , config )
175
+ def test_infinite_buffer_size (self , dataset_mock , config ):
176
+ dataset_mock .prepare (config )
176
177
dataset = datasets .load (dataset_mock .name , ** config )
177
178
178
179
for dp in extract_datapipes (dataset ):
@@ -182,17 +183,17 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config):
182
183
assert dp .buffer_size == INFINITE_BUFFER_SIZE
183
184
184
185
@parametrize_dataset_mocks (DATASET_MOCKS )
185
- def test_has_length (self , test_home , dataset_mock , config ):
186
- dataset_mock .prepare (test_home , config )
186
+ def test_has_length (self , dataset_mock , config ):
187
+ dataset_mock .prepare (config )
187
188
dataset = datasets .load (dataset_mock .name , ** config )
188
189
189
190
assert len (dataset ) > 0
190
191
191
192
192
193
@parametrize_dataset_mocks (DATASET_MOCKS ["qmnist" ])
193
194
class TestQMNIST :
194
- def test_extra_label (self , test_home , dataset_mock , config ):
195
- dataset_mock .prepare (test_home , config )
195
+ def test_extra_label (self , dataset_mock , config ):
196
+ dataset_mock .prepare (config )
196
197
197
198
dataset = datasets .load (dataset_mock .name , ** config )
198
199
@@ -211,13 +212,13 @@ def test_extra_label(self, test_home, dataset_mock, config):
211
212
212
213
@parametrize_dataset_mocks (DATASET_MOCKS ["gtsrb" ])
213
214
class TestGTSRB :
214
- def test_label_matches_path (self , test_home , dataset_mock , config ):
215
+ def test_label_matches_path (self , dataset_mock , config ):
215
216
# We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
216
217
# This test makes sure that they're both the same
217
218
if config ["split" ] != "train" :
218
219
return
219
220
220
- dataset_mock .prepare (test_home , config )
221
+ dataset_mock .prepare (config )
221
222
222
223
dataset = datasets .load (dataset_mock .name , ** config )
223
224
@@ -228,8 +229,8 @@ def test_label_matches_path(self, test_home, dataset_mock, config):
228
229
229
230
@parametrize_dataset_mocks (DATASET_MOCKS ["usps" ])
230
231
class TestUSPS :
231
- def test_sample_content (self , test_home , dataset_mock , config ):
232
- dataset_mock .prepare (test_home , config )
232
+ def test_sample_content (self , dataset_mock , config ):
233
+ dataset_mock .prepare (config )
233
234
234
235
dataset = datasets .load (dataset_mock .name , ** config )
235
236
0 commit comments