diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index a6c272c82e..66c18ea89c 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -1,3 +1,4 @@
+import contextlib
 import copy
 import errno
 import fnmatch
@@ -487,9 +488,8 @@ def http_get(
     )
 
     # Stream file to buffer
-    progress = _tqdm_bar
-    if progress is None:
-        progress = tqdm(
+    progress_cm: tqdm = (
+        tqdm(  # type: ignore[assignment]
             unit="B",
             unit_scale=True,
             total=total,
@@ -500,71 +500,76 @@
             # see https://github.com/huggingface/huggingface_hub/pull/2000
             name="huggingface_hub.http_get",
         )
+        if _tqdm_bar is None
+        else contextlib.nullcontext(_tqdm_bar)
+        # ^ `contextlib.nullcontext` mimics a context manager that does nothing.
+        # It makes it easier to use the same code path for both cases, but in the latter
+        # case the progress bar is not closed when exiting the context manager.
+    )
 
-    if hf_transfer and total is not None and total > 5 * DOWNLOAD_CHUNK_SIZE:
-        supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
-        if not supports_callback:
-            warnings.warn(
-                "You are using an outdated version of `hf_transfer`. "
-                "Consider upgrading to latest version to enable progress bars "
-                "using `pip install -U hf_transfer`."
-            )
+    with progress_cm as progress:
+        if hf_transfer and total is not None and total > 5 * DOWNLOAD_CHUNK_SIZE:
+            supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
+            if not supports_callback:
+                warnings.warn(
+                    "You are using an outdated version of `hf_transfer`. "
+                    "Consider upgrading to latest version to enable progress bars "
+                    "using `pip install -U hf_transfer`."
+                )
+            try:
+                hf_transfer.download(
+                    url=url,
+                    filename=temp_file.name,
+                    max_files=HF_TRANSFER_CONCURRENCY,
+                    chunk_size=DOWNLOAD_CHUNK_SIZE,
+                    headers=headers,
+                    parallel_failures=3,
+                    max_retries=5,
+                    **({"callback": progress.update} if supports_callback else {}),
+                )
+            except Exception as e:
+                raise RuntimeError(
+                    "An error occurred while downloading using `hf_transfer`. Consider"
+                    " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
+                ) from e
+            if not supports_callback:
+                progress.update(total)
+            if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
+                raise EnvironmentError(
+                    consistency_error_message.format(
+                        actual_size=os.path.getsize(temp_file.name),
+                    )
+                )
+            return
+        new_resume_size = resume_size
         try:
-            hf_transfer.download(
+            for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
+                if chunk:  # filter out keep-alive new chunks
+                    progress.update(len(chunk))
+                    temp_file.write(chunk)
+                    new_resume_size += len(chunk)
+                    # Some data has been downloaded from the server so we reset the number of retries.
+                    _nb_retries = 5
+        except (requests.ConnectionError, requests.ReadTimeout) as e:
+            # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
+            # a transient error (network outage?). We log a warning message and try to resume the download a few times
+            # before giving up. The retry mechanism is basic but should be enough in most cases.
+            if _nb_retries <= 0:
+                logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e))
+                raise
+            logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e))
+            time.sleep(1)
+            reset_sessions()  # In case of SSLError it's best to reset the shared requests.Session objects
+            return http_get(
                 url=url,
-                filename=temp_file.name,
-                max_files=HF_TRANSFER_CONCURRENCY,
-                chunk_size=DOWNLOAD_CHUNK_SIZE,
-                headers=headers,
-                parallel_failures=3,
-                max_retries=5,
-                **({"callback": progress.update} if supports_callback else {}),
-            )
-        except Exception as e:
-            raise RuntimeError(
-                "An error occurred while downloading using `hf_transfer`. Consider"
-                " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
-            ) from e
-        if not supports_callback:
-            progress.update(total)
-        if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
-            raise EnvironmentError(
-                consistency_error_message.format(
-                    actual_size=os.path.getsize(temp_file.name),
-                )
+                temp_file=temp_file,
+                proxies=proxies,
+                resume_size=new_resume_size,
+                headers=initial_headers,
+                expected_size=expected_size,
+                _nb_retries=_nb_retries - 1,
+                _tqdm_bar=_tqdm_bar,
             )
-        return
-    new_resume_size = resume_size
-    try:
-        for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
-            if chunk:  # filter out keep-alive new chunks
-                progress.update(len(chunk))
-                temp_file.write(chunk)
-                new_resume_size += len(chunk)
-                # Some data has been downloaded from the server so we reset the number of retries.
-                _nb_retries = 5
-    except (requests.ConnectionError, requests.ReadTimeout) as e:
-        # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
-        # a transient error (network outage?). We log a warning message and try to resume the download a few times
-        # before giving up. The retry mechanism is basic but should be enough in most cases.
-        if _nb_retries <= 0:
-            logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e))
-            raise
-        logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e))
-        time.sleep(1)
-        reset_sessions()  # In case of SSLError it's best to reset the shared requests.Session objects
-        return http_get(
-            url=url,
-            temp_file=temp_file,
-            proxies=proxies,
-            resume_size=new_resume_size,
-            headers=initial_headers,
-            expected_size=expected_size,
-            _nb_retries=_nb_retries - 1,
-            _tqdm_bar=_tqdm_bar,
-        )
-
-    progress.close()
     if expected_size is not None and expected_size != temp_file.tell():
         raise EnvironmentError(
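A note on the `contextlib.nullcontext` pattern the diff introduces: it lets a single `with` block serve both cases, a bar created locally (closed on exit) and a caller-owned `_tqdm_bar` (left open so the caller can keep updating it across resumed calls). Here is a minimal, self-contained sketch of that pattern; `download_with_progress` is a made-up function, not the library's code:

```python
import contextlib
from typing import Optional

from tqdm import tqdm


def download_with_progress(total: int, _tqdm_bar: Optional[tqdm] = None) -> None:
    # If no bar was passed in, create one and own it: the `with` block closes it.
    # If the caller supplied a bar, `nullcontext` yields it as-is, and exiting
    # the `with` block leaves it open for the caller to keep updating.
    progress_cm = (
        tqdm(total=total, unit="B", unit_scale=True)
        if _tqdm_bar is None
        else contextlib.nullcontext(_tqdm_bar)
    )
    with progress_cm as progress:
        for _ in range(4):  # pretend four chunks arrive
            progress.update(total // 4)


download_with_progress(1 << 20)  # creates and closes its own bar

shared = tqdm(total=2 << 20, unit="B", unit_scale=True)
download_with_progress(1 << 20, shared)  # `shared` survives the call...
download_with_progress(1 << 20, shared)  # ...so a second call keeps filling it
shared.close()
```

This is exactly why the old `progress.close()` disappears from the diff: closing is now owned by the context manager, and only in the locally created case.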
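The diff also feature-detects the optional `callback` parameter with `inspect.signature` instead of pinning a minimum `hf_transfer` version. A rough sketch of the same technique against a hypothetical `transfer` function (the real `hf_transfer.download` signature is only assumed here):

```python
import inspect


def transfer(url: str, filename: str, callback=None) -> None:
    """Hypothetical downloader; pretend older releases lacked `callback`."""
    for done in (256, 256, 512):  # pretend three chunks of this many bytes arrive
        if callback is not None:
            callback(done)


# Pass `callback` only if the installed version accepts it; otherwise the
# caller bumps the bar once at the end, as the diff does with
# `progress.update(total)` when `supports_callback` is False.
supports_callback = "callback" in inspect.signature(transfer).parameters
extra = {"callback": print} if supports_callback else {}
transfer("https://example.com/file", "/tmp/file", **extra)
```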
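Finally, the resume path: on `ConnectionError`/`ReadTimeout` the function recurses into `http_get` with `resume_size` advanced past the bytes already written, and `_nb_retries` is reset to 5 whenever data actually flows. Below is a simplified, hedged sketch of that resume-and-retry loop; the diff itself delegates the `Range` header to the caller, so the header construction here, the `fetch` name, and the constants are illustrative assumptions, not the library's implementation:

```python
import requests

CHUNK_SIZE = 10 * 1024 * 1024  # stand-in for DOWNLOAD_CHUNK_SIZE


def fetch(url: str, path: str, resume_size: int = 0, nb_retries: int = 5) -> None:
    # Ask the server for the remaining bytes only, and append them to the
    # partially written file (assumes the server honors Range requests).
    headers = {"Range": f"bytes={resume_size}-"} if resume_size > 0 else {}
    with requests.get(url, headers=headers, stream=True, timeout=10.0) as r:
        r.raise_for_status()
        with open(path, "ab") as f:
            try:
                for chunk in r.iter_content(chunk_size=CHUNK_SIZE):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)
                        resume_size += len(chunk)
                        nb_retries = 5  # data flowed, so reset the retry budget
            except (requests.ConnectionError, requests.ReadTimeout):
                if nb_retries <= 0:
                    raise
                # Recurse with the new offset, exactly one retry cheaper,
                # mirroring the diff's recursive call into http_get.
                return fetch(url, path, resume_size, nb_retries - 1)
```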