1919from collections .abc import Mapping , Sequence
2020from pathlib import Path
2121from pydoc import locate
22- from shutil import copyfile
22+ from shutil import copyfile , copytree , rmtree
2323from textwrap import dedent
2424from typing import Any , Callable
2525
@@ -193,6 +193,15 @@ def _download_from_ngc(
193193 extractall (filepath = filepath , output_dir = extract_path , has_base = True )
194194
195195
196+ def _download_from_huggingface_hub (repo : str , download_path : str , filename : str ) -> None :
197+ if len (repo .split ("/" )) != 2 :
198+ raise ValueError ("if source is `hf_hub`, repo should be in the form `repo_owner/repo_name`" )
199+ snapshot_folder = huggingface_hub .snapshot_download (repo_id = repo , cache_dir = download_path )
200+ download_dir = os .path .join (download_path , filename )
201+ copytree (snapshot_folder , download_dir , dirs_exist_ok = True )
202+ rmtree (snapshot_folder )
203+
204+
196205def _get_latest_bundle_version (source : str , name : str , repo : str ) -> dict [str , list [str ] | str ] | Any | None :
197206 if source == "ngc" :
198207 name = _add_ngc_prefix (name )
@@ -248,6 +257,9 @@ def download(
248257 # Execute this module as a CLI entry, and download bundle from ngc with latest version:
249258 python -m monai.bundle download --name <bundle_name> --source "ngc" --bundle_dir "./"
250259
260+ # Execute this module as a CLI entry, and download bundle from Hugging Face Hub:
261+ python -m monai.bundle download --name "bundle_name" --source "huggingface_hub" --repo "repo_owner/repo_name"
262+
251263 # Execute this module as a CLI entry, and download bundle via URL:
252264 python -m monai.bundle download --name <bundle_name> --url <url>
253265
@@ -271,9 +283,10 @@ def download(
271283 Default is `bundle` subfolder under `torch.hub.get_dir()`.
272284 source: storage location name. This argument is used when `url` is `None`.
273285 By default, the value is read from the environment variable BUNDLE_DOWNLOAD_SRC, and
274- it should be "ngc" or "github".
275- repo: repo name. This argument is used when `url` is `None` and `source` is "github".
276- If used, it should be in the form of "repo_owner/repo_name/release_tag".
286+ it should be "ngc", "github", or "huggingface_hub".
287+ repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
288+ If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
289+ If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
277290 url: url to download the data. If not `None`, data will be downloaded directly
278291 and `source` will not be checked.
279292 If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
@@ -333,9 +346,17 @@ def download(
333346 remove_prefix = remove_prefix_ ,
334347 progress = progress_ ,
335348 )
349+ elif source_ == "huggingface_hub" :
350+ if name_ is None :
351+ raise ValueError (f"To download from source: 'huggingface_hub', `name` must be provided, got { name_ } ." )
352+ _download_from_huggingface_hub (
353+ repo = repo_ ,
354+ download_path = bundle_dir_ ,
355+ filename = name_
356+ )
336357 else :
337358 raise NotImplementedError (
338- f"Currently only download from `url`, source 'github' or 'ngc ' are implemented, got source: { source_ } ."
359+ f"Currently only download from `url`, source 'github', 'ngc', or 'huggingface_hub ' are implemented, got source: { source_ } ."
339360 )
340361
341362
0 commit comments