Skip to content

Commit d1e134c

Browse files
authored
add download tests for VOC (#2834)
1 parent 57b653c commit d1e134c

File tree

1 file changed

+31
-10
lines changed

1 file changed

+31
-10
lines changed

test/test_datasets_download.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,20 @@ def inner_wrapper(request, *args, **kwargs):
4646

4747

4848
@contextlib.contextmanager
49-
def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None):
49+
def log_download_attempts(
50+
urls_and_md5s=None,
51+
patch=True,
52+
download_url_target="torchvision.datasets.utils.download_url",
53+
patch_auxiliaries=None,
54+
):
5055
if urls_and_md5s is None:
5156
urls_and_md5s = set()
5257
if patch_auxiliaries is None:
5358
patch_auxiliaries = patch
5459

5560
with contextlib.ExitStack() as stack:
5661
download_url_mock = stack.enter_context(
57-
unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url)
62+
unittest.mock.patch(download_url_target, wraps=None if patch else download_url)
5863
)
5964
if patch_auxiliaries:
6065
# download_and_extract_archive
@@ -127,13 +132,9 @@ def make_download_configs(urls_and_md5s, name=None):
127132
]
128133

129134

130-
def collect_download_configs(dataset_loader, name):
131-
try:
132-
with log_download_attempts() as urls_and_md5s:
133-
dataset_loader()
134-
except Exception:
135-
pass
136-
135+
def collect_download_configs(dataset_loader, name, **kwargs):
136+
with contextlib.suppress(Exception), log_download_attempts(**kwargs) as urls_and_md5s:
137+
dataset_loader()
137138
return make_download_configs(urls_and_md5s, name)
138139

139140

@@ -164,6 +165,17 @@ def cifar100():
164165
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100")
165166

166167

168+
def voc():
169+
download_configs = []
170+
for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012"):
171+
with contextlib.suppress(Exception), log_download_attempts(
172+
download_url_target="torchvision.datasets.voc.download_url"
173+
) as urls_and_md5s:
174+
datasets.VOCSegmentation(".", year=year, download=True)
175+
download_configs.extend(make_download_configs(urls_and_md5s, f"VOC, {year}"))
176+
return download_configs
177+
178+
167179
def make_parametrize_kwargs(download_configs):
168180
argvalues = []
169181
ids = []
@@ -175,7 +187,16 @@ def make_parametrize_kwargs(download_configs):
175187

176188

177189
@pytest.mark.parametrize(
178-
**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256(), cifar10(), cifar100()))
190+
**make_parametrize_kwargs(
191+
itertools.chain(
192+
places365(),
193+
caltech101(),
194+
caltech256(),
195+
cifar10(),
196+
cifar100(),
197+
voc(),
198+
)
199+
)
179200
)
180201
def test_url_is_accessible(url, md5):
181202
retry(lambda: assert_url_is_accessible(url))

0 commit comments

Comments
 (0)