Skip to content

Commit be3e678

Browse files
committed
Add download from huggingface_hub functionality
1 parent b7d462d commit be3e678

1 file changed

Lines changed: 26 additions & 5 deletions

File tree

monai/bundle/scripts.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from collections.abc import Mapping, Sequence
2020
from pathlib import Path
2121
from pydoc import locate
22-
from shutil import copyfile
22+
from shutil import copyfile, copytree, rmtree
2323
from textwrap import dedent
2424
from typing import Any, Callable
2525

@@ -193,6 +193,15 @@ def _download_from_ngc(
193193
extractall(filepath=filepath, output_dir=extract_path, has_base=True)
194194

195195

196+
def _download_from_huggingface_hub(repo: str, download_path: str, filename: str) -> None:
197+
if len(repo.split("/")) != 2:
198+
raise ValueError("if source is `hf_hub`, repo should be in the form `repo_owner/repo_name`")
199+
snapshot_folder = huggingface_hub.snapshot_download(repo_id=repo, cache_dir=download_path)
200+
download_dir = os.path.join(download_path, filename)
201+
copytree(snapshot_folder, download_dir, dirs_exist_ok=True)
202+
rmtree(snapshot_folder)
203+
204+
196205
def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None:
197206
if source == "ngc":
198207
name = _add_ngc_prefix(name)
@@ -248,6 +257,9 @@ def download(
248257
# Execute this module as a CLI entry, and download bundle from ngc with latest version:
249258
python -m monai.bundle download --name <bundle_name> --source "ngc" --bundle_dir "./"
250259
260+
# Execute this module as a CLI entry, and download bundle from Hugging Face Hub:
261+
python -m monai.bundle download --name "bundle_name" --source "huggingface_hub" --repo "repo_owner/repo_name"
262+
251263
# Execute this module as a CLI entry, and download bundle via URL:
252264
python -m monai.bundle download --name <bundle_name> --url <url>
253265
@@ -271,9 +283,10 @@ def download(
271283
Default is `bundle` subfolder under `torch.hub.get_dir()`.
272284
source: storage location name. This argument is used when `url` is `None`.
273285
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
274-
it should be "ngc" or "github".
275-
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
276-
If used, it should be in the form of "repo_owner/repo_name/release_tag".
286+
it should be "ngc", "github", or "huggingface_hub".
287+
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
288+
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
289+
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
277290
url: url to download the data. If not `None`, data will be downloaded directly
278291
and `source` will not be checked.
279292
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
@@ -333,9 +346,17 @@ def download(
333346
remove_prefix=remove_prefix_,
334347
progress=progress_,
335348
)
349+
elif source_ == "huggingface_hub":
350+
if name_ is None:
351+
raise ValueError(f"To download from source: 'huggingface_hub', `name` must be provided, got {name_}.")
352+
_download_from_huggingface_hub(
353+
repo=repo_,
354+
download_path=bundle_dir_,
355+
filename=name_
356+
)
336357
else:
337358
raise NotImplementedError(
338-
f"Currently only download from `url`, source 'github' or 'ngc' are implemented, got source: {source_}."
359+
f"Currently only download from `url`, source 'github', 'ngc', or 'huggingface_hub' are implemented, got source: {source_}."
339360
)
340361

341362

0 commit comments

Comments
 (0)