diff --git a/.github/workflows/tests-schedule.yml b/.github/workflows/tests-schedule.yml index f0c54b6a40b..aede262fafe 100644 --- a/.github/workflows/tests-schedule.yml +++ b/.github/workflows/tests-schedule.yml @@ -31,6 +31,9 @@ jobs: pip install numpy pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - name: Install all optional dataset requirements + run: pip install scipy pandas pycocotools lmdb requests + - name: Install tests requirements run: pip install pytest diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index c6e95ffe064..0a49752fe12 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -49,17 +49,22 @@ def inner_wrapper(request, *args, **kwargs): def log_download_attempts( urls_and_md5s=None, patch=True, - download_url_target="torchvision.datasets.utils.download_url", + download_url_location=".utils", patch_auxiliaries=None, ): if urls_and_md5s is None: urls_and_md5s = set() + if download_url_location.startswith("."): + download_url_location = f"torchvision.datasets{download_url_location}" if patch_auxiliaries is None: patch_auxiliaries = patch with contextlib.ExitStack() as stack: download_url_mock = stack.enter_context( - unittest.mock.patch(download_url_target, wraps=None if patch else download_url) + unittest.mock.patch( + f"{download_url_location}.download_url", + wraps=None if patch else download_url, + ) ) if patch_auxiliaries: # download_and_extract_archive @@ -132,9 +137,17 @@ def make_download_configs(urls_and_md5s, name=None): ] -def collect_download_configs(dataset_loader, name, **kwargs): - with contextlib.suppress(Exception), log_download_attempts(**kwargs) as urls_and_md5s: - dataset_loader() +def collect_download_configs(dataset_loader, name=None, **kwargs): + urls_and_md5s = set() + try: + with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs): + dataset = dataset_loader() + except Exception: + dataset = None + + if name is None and dataset is not None: + name = type(dataset).__name__ + return make_download_configs(urls_and_md5s, name) @@ -146,34 +159,40 @@ def places365(): datasets.Places365(root, split=split, small=small, download=True) - return make_download_configs(urls_and_md5s, "Places365") + return make_download_configs(urls_and_md5s, name="Places365") def caltech101(): - return collect_download_configs(lambda: datasets.Caltech101(".", download=True), "Caltech101") + return collect_download_configs(lambda: datasets.Caltech101(".", download=True), name="Caltech101") def caltech256(): - return collect_download_configs(lambda: datasets.Caltech256(".", download=True), "Caltech256") + return collect_download_configs(lambda: datasets.Caltech256(".", download=True), name="Caltech256") def cifar10(): - return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR10") + return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR10") def cifar100(): - return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100") + return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100") def voc(): - download_configs = [] - for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012"): - with contextlib.suppress(Exception), log_download_attempts( - download_url_target="torchvision.datasets.voc.download_url" - ) as urls_and_md5s: - datasets.VOCSegmentation(".", year=year, download=True) - download_configs.extend(make_download_configs(urls_and_md5s, f"VOC, {year}")) - return download_configs + return itertools.chain( + *[ + collect_download_configs( + lambda: datasets.VOCSegmentation(".", year=year, download=True), + name=f"VOC, {year}", + download_url_location=".voc", + ) + for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012") + ] + ) + + +def mnist(): + return collect_download_configs(lambda: datasets.MNIST(".", download=True), name="MNIST") def make_parametrize_kwargs(download_configs): @@ -196,6 +215,7 @@ def make_parametrize_kwargs(download_configs): cifar100(), # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details. # voc(), + mnist(), ) ) )