Skip to content

Commit 6b41eb0

Browse files
authored
add download tests for CIFAR (#2747)
* add download tests for CIFAR * fix tests in case of bad request
1 parent fdca307 commit 6b41eb0

File tree

2 files changed

+35
-25
lines changed

2 files changed

+35
-25
lines changed

.github/workflows/tests-schedule.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
run: pip install pytest
3636

3737
- name: Run tests
38-
run: pytest --durations=20 -ra test/test_datasets_download.py
38+
run: pytest -ra -v test/test_datasets_download.py
3939

4040
- uses: JasonEtco/[email protected]
4141
name: Create issue if download tests failed

test/test_datasets_download.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import unittest.mock
55
from datetime import datetime
66
from os import path
7+
from urllib.error import HTTPError
78
from urllib.parse import urlparse
89
from urllib.request import urlopen, Request
910

@@ -86,25 +87,26 @@ def retry(fn, times=1, wait=5.0):
8687
)
8788

8889

89-
def assert_server_response_ok(response, url=None):
90-
msg = f"The server returned status code {response.code}"
91-
if url is not None:
92-
msg += f"for the the URL {url}"
93-
assert 200 <= response.code < 300, msg
90+
@contextlib.contextmanager
91+
def assert_server_response_ok():
92+
try:
93+
yield
94+
except HTTPError as error:
95+
raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
9496

9597

9698
def assert_url_is_accessible(url):
9799
request = Request(url, headers=dict(method="HEAD"))
98-
response = urlopen(request)
99-
assert_server_response_ok(response, url)
100+
with assert_server_response_ok():
101+
urlopen(request)
100102

101103

102104
def assert_file_downloads_correctly(url, md5):
103105
with get_tmp_dir() as root:
104106
file = path.join(root, path.basename(url))
105-
with urlopen(url) as response, open(file, "wb") as fh:
106-
assert_server_response_ok(response, url)
107-
fh.write(response.read())
107+
with assert_server_response_ok():
108+
with urlopen(url) as response, open(file, "wb") as fh:
109+
fh.write(response.read())
108110

109111
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
110112

@@ -125,6 +127,16 @@ def make_download_configs(urls_and_md5s, name=None):
125127
]
126128

127129

130+
def collect_download_configs(dataset_loader, name):
131+
try:
132+
with log_download_attempts() as urls_and_md5s:
133+
dataset_loader()
134+
except Exception:
135+
pass
136+
137+
return make_download_configs(urls_and_md5s, name)
138+
139+
128140
def places365():
129141
with log_download_attempts(patch=False) as urls_and_md5s:
130142
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
@@ -137,23 +149,19 @@ def places365():
137149

138150

139151
def caltech101():
140-
try:
141-
with log_download_attempts() as urls_and_md5s:
142-
datasets.Caltech101(".", download=True)
143-
except Exception:
144-
pass
145-
146-
return make_download_configs(urls_and_md5s, "Caltech101")
152+
return collect_download_configs(lambda: datasets.Caltech101(".", download=True), "Caltech101")
147153

148154

149155
def caltech256():
150-
try:
151-
with log_download_attempts() as urls_and_md5s:
152-
datasets.Caltech256(".", download=True)
153-
except Exception:
154-
pass
156+
return collect_download_configs(lambda: datasets.Caltech256(".", download=True), "Caltech256")
157+
158+
159+
def cifar10():
160+
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR10")
161+
155162

156-
return make_download_configs(urls_and_md5s, "Caltech256")
163+
def cifar100():
164+
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100")
157165

158166

159167
def make_parametrize_kwargs(download_configs):
@@ -166,7 +174,9 @@ def make_parametrize_kwargs(download_configs):
166174
return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)
167175

168176

169-
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256())))
177+
@pytest.mark.parametrize(
178+
**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256(), cifar10(), cifar100()))
179+
)
170180
def test_url_is_accessible(url, md5):
171181
retry(lambda: assert_url_is_accessible(url))
172182

0 commit comments

Comments
 (0)