@@ -46,15 +46,20 @@ def inner_wrapper(request, *args, **kwargs):
46
46
47
47
48
48
@contextlib .contextmanager
49
- def log_download_attempts (urls_and_md5s = None , patch = True , patch_auxiliaries = None ):
49
+ def log_download_attempts (
50
+ urls_and_md5s = None ,
51
+ patch = True ,
52
+ download_url_target = "torchvision.datasets.utils.download_url" ,
53
+ patch_auxiliaries = None ,
54
+ ):
50
55
if urls_and_md5s is None :
51
56
urls_and_md5s = set ()
52
57
if patch_auxiliaries is None :
53
58
patch_auxiliaries = patch
54
59
55
60
with contextlib .ExitStack () as stack :
56
61
download_url_mock = stack .enter_context (
57
- unittest .mock .patch ("torchvision.datasets.utils.download_url" , wraps = None if patch else download_url )
62
+ unittest .mock .patch (download_url_target , wraps = None if patch else download_url )
58
63
)
59
64
if patch_auxiliaries :
60
65
# download_and_extract_archive
@@ -127,13 +132,9 @@ def make_download_configs(urls_and_md5s, name=None):
127
132
]
128
133
129
134
130
- def collect_download_configs (dataset_loader , name ):
131
- try :
132
- with log_download_attempts () as urls_and_md5s :
133
- dataset_loader ()
134
- except Exception :
135
- pass
136
-
135
+ def collect_download_configs (dataset_loader , name , ** kwargs ):
136
+ with contextlib .suppress (Exception ), log_download_attempts (** kwargs ) as urls_and_md5s :
137
+ dataset_loader ()
137
138
return make_download_configs (urls_and_md5s , name )
138
139
139
140
@@ -164,6 +165,17 @@ def cifar100():
164
165
return collect_download_configs (lambda : datasets .CIFAR10 ("." , download = True ), "CIFAR100" )
165
166
166
167
168
+ def voc ():
169
+ download_configs = []
170
+ for year in ("2007" , "2007-test" , "2008" , "2009" , "2010" , "2011" , "2012" ):
171
+ with contextlib .suppress (Exception ), log_download_attempts (
172
+ download_url_target = "torchvision.datasets.voc.download_url"
173
+ ) as urls_and_md5s :
174
+ datasets .VOCSegmentation ("." , year = year , download = True )
175
+ download_configs .extend (make_download_configs (urls_and_md5s , f"VOC, { year } " ))
176
+ return download_configs
177
+
178
+
167
179
def make_parametrize_kwargs (download_configs ):
168
180
argvalues = []
169
181
ids = []
@@ -175,7 +187,16 @@ def make_parametrize_kwargs(download_configs):
175
187
176
188
177
189
@pytest .mark .parametrize (
178
- ** make_parametrize_kwargs (itertools .chain (places365 (), caltech101 (), caltech256 (), cifar10 (), cifar100 ()))
190
+ ** make_parametrize_kwargs (
191
+ itertools .chain (
192
+ places365 (),
193
+ caltech101 (),
194
+ caltech256 (),
195
+ cifar10 (),
196
+ cifar100 (),
197
+ voc (),
198
+ )
199
+ )
179
200
)
180
201
def test_url_is_accessible (url , md5 ):
181
202
retry (lambda : assert_url_is_accessible (url ))
0 commit comments