4
4
import unittest .mock
5
5
from datetime import datetime
6
6
from os import path
7
+ from urllib .error import HTTPError
7
8
from urllib .parse import urlparse
8
9
from urllib .request import urlopen , Request
9
10
@@ -86,25 +87,26 @@ def retry(fn, times=1, wait=5.0):
86
87
)
87
88
88
89
89
- def assert_server_response_ok (response , url = None ):
90
- msg = f"The server returned status code { response .code } "
91
- if url is not None :
92
- msg += f"for the the URL { url } "
93
- assert 200 <= response .code < 300 , msg
90
+ @contextlib .contextmanager
91
+ def assert_server_response_ok ():
92
+ try :
93
+ yield
94
+ except HTTPError as error :
95
+ raise AssertionError (f"The server returned { error .code } : { error .reason } ." ) from error
94
96
95
97
96
98
def assert_url_is_accessible (url ):
97
99
request = Request (url , headers = dict (method = "HEAD" ))
98
- response = urlopen ( request )
99
- assert_server_response_ok ( response , url )
100
+ with assert_server_response_ok ():
101
+ urlopen ( request )
100
102
101
103
102
104
def assert_file_downloads_correctly (url , md5 ):
103
105
with get_tmp_dir () as root :
104
106
file = path .join (root , path .basename (url ))
105
- with urlopen ( url ) as response , open ( file , "wb" ) as fh :
106
- assert_server_response_ok ( response , url )
107
- fh .write (response .read ())
107
+ with assert_server_response_ok () :
108
+ with urlopen ( url ) as response , open ( file , "wb" ) as fh :
109
+ fh .write (response .read ())
108
110
109
111
assert check_integrity (file , md5 = md5 ), "The MD5 checksums mismatch"
110
112
@@ -125,6 +127,16 @@ def make_download_configs(urls_and_md5s, name=None):
125
127
]
126
128
127
129
130
+ def collect_download_configs (dataset_loader , name ):
131
+ try :
132
+ with log_download_attempts () as urls_and_md5s :
133
+ dataset_loader ()
134
+ except Exception :
135
+ pass
136
+
137
+ return make_download_configs (urls_and_md5s , name )
138
+
139
+
128
140
def places365 ():
129
141
with log_download_attempts (patch = False ) as urls_and_md5s :
130
142
for split , small in itertools .product (("train-standard" , "train-challenge" , "val" ), (False , True )):
@@ -137,23 +149,19 @@ def places365():
137
149
138
150
139
151
def caltech101 ():
140
- try :
141
- with log_download_attempts () as urls_and_md5s :
142
- datasets .Caltech101 ("." , download = True )
143
- except Exception :
144
- pass
145
-
146
- return make_download_configs (urls_and_md5s , "Caltech101" )
152
+ return collect_download_configs (lambda : datasets .Caltech101 ("." , download = True ), "Caltech101" )
147
153
148
154
149
155
def caltech256 ():
150
- try :
151
- with log_download_attempts () as urls_and_md5s :
152
- datasets .Caltech256 ("." , download = True )
153
- except Exception :
154
- pass
156
+ return collect_download_configs (lambda : datasets .Caltech256 ("." , download = True ), "Caltech256" )
157
+
158
+
159
+ def cifar10 ():
160
+ return collect_download_configs (lambda : datasets .CIFAR10 ("." , download = True ), "CIFAR10" )
161
+
155
162
156
- return make_download_configs (urls_and_md5s , "Caltech256" )
163
+ def cifar100 ():
164
+ return collect_download_configs (lambda : datasets .CIFAR10 ("." , download = True ), "CIFAR100" )
157
165
158
166
159
167
def make_parametrize_kwargs (download_configs ):
@@ -166,7 +174,9 @@ def make_parametrize_kwargs(download_configs):
166
174
return dict (argnames = ("url" , "md5" ), argvalues = argvalues , ids = ids )
167
175
168
176
169
- @pytest .mark .parametrize (** make_parametrize_kwargs (itertools .chain (places365 (), caltech101 (), caltech256 ())))
177
+ @pytest .mark .parametrize (
178
+ ** make_parametrize_kwargs (itertools .chain (places365 (), caltech101 (), caltech256 (), cifar10 (), cifar100 ()))
179
+ )
170
180
def test_url_is_accessible (url , md5 ):
171
181
retry (lambda : assert_url_is_accessible (url ))
172
182
0 commit comments