diff --git a/.github/workflows/tests-schedule.yml b/.github/workflows/tests-schedule.yml index aede262fafe..65f805ce471 100644 --- a/.github/workflows/tests-schedule.yml +++ b/.github/workflows/tests-schedule.yml @@ -26,10 +26,11 @@ jobs: - name: Checkout repository uses: actions/checkout@v2 - - name: Install PyTorch from the nightlies - run: | - pip install numpy - pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - name: Install torch nightly build + run: pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + + - name: Install torchvision + run: pip install -e . - name: Install all optional dataset requirements run: pip install scipy pandas pycocotools lmdb requests diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index 82fce713f14..5c4fc54fd5d 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -14,7 +14,7 @@ import pytest from torchvision import datasets -from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive +from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive, USER_AGENT from common_utils import get_tmp_dir from fakedata_generation import places365_root @@ -150,7 +150,7 @@ def assert_server_response_ok(): def assert_url_is_accessible(url, timeout=5.0): - request = Request(url, headers=dict(method="HEAD")) + request = Request(url, method="HEAD", headers={"User-Agent": USER_AGENT}) with assert_server_response_ok(): urlopen(request, timeout=timeout) @@ -160,7 +160,8 @@ def assert_file_downloads_correctly(url, md5, timeout=5.0): file = path.join(root, path.basename(url)) with assert_server_response_ok(): with open(file, "wb") as fh: - response = urlopen(url, timeout=timeout) + request = Request(url, headers={"User-Agent": USER_AGENT}) + response = urlopen(request, timeout=timeout) fh.write(response.read()) assert
check_integrity(file, md5=md5), "The MD5 checksums mismatch" diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 1bd3d3c8053..a27363c533c 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -7,11 +7,28 @@ from typing import Any, Callable, List, Iterable, Optional, TypeVar from urllib.parse import urlparse import zipfile +import urllib +import urllib.request +import urllib.error import torch from torch.utils.model_zoo import tqdm +USER_AGENT = "pytorch/vision" + + +def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), b""): + if not chunk: + break + pbar.update(len(chunk)) + fh.write(chunk) + + def gen_bar_updater() -> Callable[[int, int, int], None]: pbar = tqdm(total=None) @@ -83,8 +100,6 @@ def download_url( md5 (str, optional): MD5 checksum of the download. If None, do not check max_redirect_hops (int, optional): Maximum number of redirect hops allowed """ - import urllib - root = os.path.expanduser(root) if not filename: filename = os.path.basename(url) @@ -108,19 +123,13 @@ def download_url( # download the file try: print('Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve( - url, fpath, - reporthook=gen_bar_updater() - ) + _urlretrieve(url, fpath) except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] if url[:5] == 'https': url = url.replace('https:', 'http:') print('Failed download. Trying https -> http instead.' ' Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve( - url, fpath, - reporthook=gen_bar_updater() - ) + _urlretrieve(url, fpath) else: raise e # check integrity of downloaded file