diff --git a/.github/workflows/python-quality.yml b/.github/workflows/python-quality.yml new file mode 100644 index 0000000000..0c9b2d634a --- /dev/null +++ b/.github/workflows/python-quality.yml @@ -0,0 +1,24 @@ +name: Python quality + +on: + push: + branches: + - "*" + +jobs: + check_code_quality: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Install dependencies + run: | + pip install --upgrade pip + pip install .[dev] + - run: black --check tests src + - run: isort --check-only tests src + - run: flake8 tests src diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml new file mode 100644 index 0000000000..4f51aa93a0 --- /dev/null +++ b/.github/workflows/python-tests.yml @@ -0,0 +1,31 @@ +name: Python tests + +on: + push: + branches: + - "*" + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.7", "3.8", "3.9"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - run: | + git config --global user.email "ci@dummy.com" + git config --global user.name "ci" + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install .[testing] + - name: pytest + run: RUN_GIT_LFS_TESTS=1 pytest -sv ./tests/ diff --git a/.gitignore b/.gitignore index b6e47617de..c85b28c30f 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,4 @@ dmypy.json # Pyre type checker .pyre/ +.vscode/ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..99edd4c282 --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +.PHONY: quality style test + + +check_dirs := tests src + + +quality: + black --check $(check_dirs) + isort --check-only $(check_dirs) + flake8 $(check_dirs) + +style: + black $(check_dirs) + isort $(check_dirs) + +test: + pytest -sv ./tests/ + diff 
--git a/README.md b/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000..1d9738d95d --- /dev/null +++ b/setup.cfg @@ -0,0 +1,50 @@ +[isort] +default_section = FIRSTPARTY +ensure_newline_before_comments = True +force_grid_wrap = 0 +include_trailing_comma = True +known_first_party = huggingface_hub +known_third_party = + absl + conllu + datasets + elasticsearch + fairseq + faiss-cpu + fastprogress + fire + fugashi + git + h5py + matplotlib + nltk + numpy + packaging + pandas + PIL + psutil + pytest + pytorch_lightning + rouge_score + sacrebleu + seqeval + sklearn + streamlit + tensorboardX + tensorflow + tensorflow_datasets + timeout_decorator + torch + torchtext + torchvision + torch_xla + tqdm + +line_length = 88 +lines_after_imports = 2 +multi_line_output = 3 +use_parentheses = True + +[flake8] +ignore = E203, E501, E741, W503, W605 +max-line-length = 88 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000..f7ef2b0df2 --- /dev/null +++ b/setup.py @@ -0,0 +1,67 @@ +from setuptools import find_packages, setup + + +def get_version() -> str: + rel_path = "src/huggingface_hub/__init__.py" + with open(rel_path, "r") as fp: + for line in fp.read().splitlines(): + if line.startswith("__version__"): + delim = '"' if '"' in line else "'" + return line.split(delim)[1] + raise RuntimeError("Unable to find version string.") + + +install_requires = [ + "filelock", + "requests", + "tqdm", +] + +extras = {} + +extras["testing"] = [ + "pytest", +] + +extras["quality"] = [ + "black>=20.8b1", + "isort>=5.5.4", + "flake8>=3.8.3", +] + +extras["all"] = extras["testing"] + extras["quality"] + +extras["dev"] = extras["all"] + + +setup( + name="huggingface_hub", + version=get_version(), + author="Hugging Face, Inc.", + author_email="julien@huggingface.co", + description="Client library to download and publish models on the huggingface.co hub", + long_description=open("README.md", "r", 
encoding="utf-8").read(), + long_description_content_type="text/markdown", + keywords="model-hub machine-learning models natural-language-processing deep-learning pytorch pretrained-models", + license="Apache", + url="https://github.com/huggingface/huggingface_hub", + package_dir={"": "src"}, + packages=find_packages("src"), + extras_require=extras, + entry_points={ + "console_scripts": [ + "huggingface-cli=huggingface_hub.commands.huggingface_cli:main" + ] + }, + python_requires=">=3.6.0", + install_requires=install_requires, + classifiers=[ + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], +) diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py new file mode 100644 index 0000000000..9daf268ef2 --- /dev/null +++ b/src/huggingface_hub/__init__.py @@ -0,0 +1,22 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__version__ = "0.1.0" + +from .file_download import HUGGINGFACE_CO_URL_TEMPLATE, cached_download, hf_hub_url +from .hf_api import HfApi, HfFolder diff --git a/src/huggingface_hub/commands/__init__.py b/src/huggingface_hub/commands/__init__.py new file mode 100644 index 0000000000..5ecef032b9 --- /dev/null +++ b/src/huggingface_hub/commands/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from argparse import ArgumentParser + + +class BaseHuggingfaceCLICommand(ABC): + @staticmethod + @abstractmethod + def register_subcommand(parser: ArgumentParser): + raise NotImplementedError() + + @abstractmethod + def run(self): + raise NotImplementedError() diff --git a/src/huggingface_hub/commands/huggingface_cli.py b/src/huggingface_hub/commands/huggingface_cli.py new file mode 100644 index 0000000000..ca140ad8bb --- /dev/null +++ b/src/huggingface_hub/commands/huggingface_cli.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser + +from huggingface_hub.commands.lfs import LfsCommands +from huggingface_hub.commands.user import UserCommands + + +def main(): + parser = ArgumentParser( + "huggingface-cli", usage="huggingface-cli <command> [<args>]" + ) + commands_parser = parser.add_subparsers(help="huggingface-cli command helpers") + + # Register commands + UserCommands.register_subcommand(commands_parser) + LfsCommands.register_subcommand(commands_parser) + + # Let's go + args = parser.parse_args() + + if not hasattr(args, "func"): + parser.print_help() + exit(1) + + # Run + service = args.func(args) + service.run() + + +if __name__ == "__main__": + main() diff --git a/src/huggingface_hub/commands/lfs.py b/src/huggingface_hub/commands/lfs.py new file mode 100644 index 0000000000..866c787b90 --- /dev/null +++ b/src/huggingface_hub/commands/lfs.py @@ -0,0 +1,228 @@ +""" +Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs. 
+ +Inspired by: github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py + +Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md + + +To launch debugger while developing: + +``` [lfs "customtransfer.multipart"] + +path = /path/to/huggingface_hub/.env/bin/python + +args = -m debugpy --listen 5678 --wait-for-client /path/to/huggingface_hub/src/huggingface_hub/commands/huggingface_cli.py +lfs-multipart-upload ``` +""" + +import json +import logging +import os +import subprocess +import sys +from argparse import ArgumentParser +from contextlib import AbstractContextManager +from typing import Dict, List, Optional + +import requests +from huggingface_hub.commands import BaseHuggingfaceCLICommand + + +logger = logging.getLogger(__name__) + + +LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload" + + +class LfsCommands(BaseHuggingfaceCLICommand): + """ + Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs. This lets users upload + large files >5GB 🔥. Spec for LFS custom transfer agent is: + https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md + + This introduces two commands to the CLI: + + 1. $ huggingface-cli lfs-enable-largefiles + + This should be executed once for each model repo that contains a model file >5GB. It's documented in the error + message you get if you just try to git push a 5GB file without having enabled it before. + + 2. $ huggingface-cli lfs-multipart-upload + + This command is called by lfs directly and is not meant to be called by the user. + """ + + @staticmethod + def register_subcommand(parser: ArgumentParser): + enable_parser = parser.add_parser( + "lfs-enable-largefiles", + help="Configure your repository to enable upload of files > 5GB.", + ) + enable_parser.add_argument( + "path", type=str, help="Local path to repository you want to configure." 
+ ) + enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args)) + + upload_parser = parser.add_parser( + LFS_MULTIPART_UPLOAD_COMMAND, + help="Command will get called by git-lfs, do not call it directly.", + ) + upload_parser.set_defaults(func=lambda args: LfsUploadCommand(args)) + + +class LfsEnableCommand: + def __init__(self, args): + self.args = args + + def run(self): + local_path = os.path.abspath(self.args.path) + if not os.path.isdir(local_path): + print("This does not look like a valid git repo.") + exit(1) + subprocess.run( + "git config lfs.customtransfer.multipart.path huggingface-cli".split(), + check=True, + cwd=local_path, + ) + subprocess.run( + f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(), + check=True, + cwd=local_path, + ) + print("Local repo set up for largefiles") + + +def write_msg(msg: Dict): + """Write out the message in Line delimited JSON.""" + msg = json.dumps(msg) + "\n" + sys.stdout.write(msg) + sys.stdout.flush() + + +def read_msg() -> Optional[Dict]: + """Read Line delimited JSON from stdin. 
""" + msg = json.loads(sys.stdin.readline().strip()) + + if "terminate" in (msg.get("type"), msg.get("event")): + # terminate message received + return None + + if msg.get("event") not in ("download", "upload"): + logger.critical("Received unexpected message") + sys.exit(1) + + return msg + + +class FileSlice(AbstractContextManager): + """ + File-like object that only reads a slice of a file + + Inspired by stackoverflow.com/a/29838711/593036 + """ + + def __init__(self, filepath: str, seek_from: int, read_limit: int): + self.filepath = filepath + self.seek_from = seek_from + self.read_limit = read_limit + self.n_seen = 0 + + def __enter__(self): + self.f = open(self.filepath, "rb") + self.f.seek(self.seek_from) + return self + + def __len__(self): + total_length = os.fstat(self.f.fileno()).st_size + return min(self.read_limit, total_length - self.seek_from) + + def read(self, n=-1): + if self.n_seen >= self.read_limit: + return b"" + remaining_amount = self.read_limit - self.n_seen + data = self.f.read(remaining_amount if n < 0 else min(n, remaining_amount)) + self.n_seen += len(data) + return data + + def __iter__(self): + yield self.read(n=4 * 1024 * 1024) + + def __exit__(self, *args): + self.f.close() + + +class LfsUploadCommand: + def __init__(self, args): + self.args = args + + def run(self): + # Immediately after invoking a custom transfer process, git-lfs + # sends initiation data to the process over stdin. + # This tells the process useful information about the configuration. + init_msg = json.loads(sys.stdin.readline().strip()) + if not ( + init_msg.get("event") == "init" and init_msg.get("operation") == "upload" + ): + write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}}) + sys.exit(1) + + # The transfer process should use the information it needs from the + # initiation structure, and also perform any one-off setup tasks it + # needs to do. 
It should then respond on stdout with a simple empty + # confirmation structure, as follows: + write_msg({}) + + # After the initiation exchange, git-lfs will send any number of + # transfer requests to the stdin of the transfer process, in a serial sequence. + while True: + msg = read_msg() + if msg is None: + # When all transfers have been processed, git-lfs will send + # a terminate event to the stdin of the transfer process. + # On receiving this message the transfer process should + # clean up and terminate. No response is expected. + sys.exit(0) + + oid = msg["oid"] + filepath = msg["path"] + completion_url = msg["action"]["href"] + header = msg["action"]["header"] + chunk_size = int(header.pop("chunk_size")) + presigned_urls: List[str] = list(header.values()) + + parts = [] + for i, presigned_url in enumerate(presigned_urls): + with FileSlice( + filepath, seek_from=i * chunk_size, read_limit=chunk_size + ) as data: + r = requests.put(presigned_url, data=data) + r.raise_for_status() + parts.append( + { + "etag": r.headers.get("etag"), + "partNumber": i + 1, + } + ) + # In order to support progress reporting while data is uploading / downloading, + # the transfer process should post messages to stdout + write_msg( + { + "event": "progress", + "oid": oid, + "bytesSoFar": (i + 1) * chunk_size, + "bytesSinceLast": chunk_size, + } + ) + # Not precise but that's ok. + + r = requests.post( + completion_url, + json={ + "oid": oid, + "parts": parts, + }, + ) + r.raise_for_status() + + write_msg({"event": "complete", "oid": oid}) diff --git a/src/huggingface_hub/commands/user.py b/src/huggingface_hub/commands/user.py new file mode 100644 index 0000000000..00075e507f --- /dev/null +++ b/src/huggingface_hub/commands/user.py @@ -0,0 +1,252 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +from argparse import ArgumentParser +from getpass import getpass +from typing import List, Union + +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.hf_api import HfApi, HfFolder +from requests.exceptions import HTTPError + + +class UserCommands(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + login_parser = parser.add_parser( + "login", help="Log in using the same credentials as on huggingface.co" + ) + login_parser.set_defaults(func=lambda args: LoginCommand(args)) + whoami_parser = parser.add_parser( + "whoami", help="Find out which huggingface.co account you are logged in as." + ) + whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args)) + logout_parser = parser.add_parser("logout", help="Log out") + logout_parser.set_defaults(func=lambda args: LogoutCommand(args)) + + # new system: git-based repo system + repo_parser = parser.add_parser( + "repo", + help="{create, ls-files} Commands to interact with your huggingface.co repos.", + ) + repo_subparsers = repo_parser.add_subparsers( + help="huggingface.co repos related commands" + ) + ls_parser = repo_subparsers.add_parser( + "ls-files", help="List all your files on huggingface.co" + ) + ls_parser.add_argument( + "--organization", type=str, help="Optional: organization namespace." 
+ ) + ls_parser.set_defaults(func=lambda args: ListReposObjsCommand(args)) + repo_create_parser = repo_subparsers.add_parser( + "create", help="Create a new repo on huggingface.co" + ) + repo_create_parser.add_argument( + "name", + type=str, + help="Name for your model's repo. Will be namespaced under your username to build the model id.", + ) + repo_create_parser.add_argument( + "--organization", type=str, help="Optional: organization namespace." + ) + repo_create_parser.add_argument( + "-y", + "--yes", + action="store_true", + help="Optional: answer Yes to the prompt", + ) + repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args)) + + +class ANSI: + """ + Helper for en.wikipedia.org/wiki/ANSI_escape_code + """ + + _bold = "\u001b[1m" + _red = "\u001b[31m" + _gray = "\u001b[90m" + _reset = "\u001b[0m" + + @classmethod + def bold(cls, s): + return "{}{}{}".format(cls._bold, s, cls._reset) + + @classmethod + def red(cls, s): + return "{}{}{}".format(cls._bold + cls._red, s, cls._reset) + + @classmethod + def gray(cls, s): + return "{}{}{}".format(cls._gray, s, cls._reset) + + +def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: + """ + Inspired by: + + - stackoverflow.com/a/8356620/593036 + - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data + """ + col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] + row_format = ("{{:{}}} " * len(headers)).format(*col_widths) + lines = [] + lines.append(row_format.format(*headers)) + lines.append(row_format.format(*["-" * w for w in col_widths])) + for row in rows: + lines.append(row_format.format(*row)) + return "\n".join(lines) + + +class BaseUserCommand: + def __init__(self, args): + self.args = args + self._api = HfApi() + + +class LoginCommand(BaseUserCommand): + def run(self): + print( # docstyle-ignore + """ + _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_| + _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| + 
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_| + _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _| + _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_| + + """ + ) + username = input("Username: ") + password = getpass() + try: + token = self._api.login(username, password) + except HTTPError as e: + # probably invalid credentials, display error message. + print(e) + print(ANSI.red(e.response.text)) + exit(1) + HfFolder.save_token(token) + print("Login successful") + print("Your token:", token, "\n") + print("Your token has been saved to", HfFolder.path_token) + + +class WhoamiCommand(BaseUserCommand): + def run(self): + token = HfFolder.get_token() + if token is None: + print("Not logged in") + exit() + try: + user, orgs = self._api.whoami(token) + print(user) + if orgs: + print(ANSI.bold("orgs: "), ",".join(orgs)) + except HTTPError as e: + print(e) + print(ANSI.red(e.response.text)) + exit(1) + + +class LogoutCommand(BaseUserCommand): + def run(self): + token = HfFolder.get_token() + if token is None: + print("Not logged in") + exit() + HfFolder.delete_token() + self._api.logout(token) + print("Successfully logged out.") + + +class ListReposObjsCommand(BaseUserCommand): + def run(self): + token = HfFolder.get_token() + if token is None: + print("Not logged in") + exit(1) + try: + objs = self._api.list_repos_objs(token, organization=self.args.organization) + except HTTPError as e: + print(e) + print(ANSI.red(e.response.text)) + exit(1) + if len(objs) == 0: + print("No shared file yet") + exit() + rows = [[obj.filename, obj.lastModified, obj.commit, obj.size] for obj in objs] + print( + tabulate(rows, headers=["Filename", "LastModified", "Commit-Sha", "Size"]) + ) + + +class RepoCreateCommand(BaseUserCommand): + def run(self): + token = HfFolder.get_token() + if token is None: + print("Not logged in") + exit(1) + try: + stdout = subprocess.check_output(["git", "--version"]).decode("utf-8") + 
print(ANSI.gray(stdout.strip())) + except FileNotFoundError: + print("Looks like you do not have git installed, please install.") + + try: + stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8") + print(ANSI.gray(stdout.strip())) + except FileNotFoundError: + print( + ANSI.red( + "Looks like you do not have git-lfs installed, please install." + " You can install from https://git-lfs.github.com/." + " Then run `git lfs install` (you only have to do this once)." + ) + ) + print("") + + user, _ = self._api.whoami(token) + namespace = ( + self.args.organization if self.args.organization is not None else user + ) + + print( + "You are about to create {}".format( + ANSI.bold(namespace + "/" + self.args.name) + ) + ) + + if not self.args.yes: + choice = input("Proceed? [Y/n] ").lower() + if not (choice == "" or choice == "y" or choice == "yes"): + print("Abort") + exit() + try: + url = self._api.create_repo( + token, name=self.args.name, organization=self.args.organization + ) + except HTTPError as e: + print(e) + print(ANSI.red(e.response.text)) + exit(1) + print("\nYour repo now lives at:") + print(" {}".format(ANSI.bold(url))) + print( + "\nYou can clone it locally with the command below," + " and commit/push as usual." + ) + print(f"\n git clone {url}") + print("") diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py new file mode 100644 index 0000000000..bf7e5b3141 --- /dev/null +++ b/src/huggingface_hub/file_download.py @@ -0,0 +1,375 @@ +import copy +import fnmatch +import io +import json +import logging +import os +import sys +import tempfile +from contextlib import contextmanager +from functools import partial +from hashlib import sha256 +from pathlib import Path +from typing import BinaryIO, Dict, Optional, Tuple, Union + +from tqdm.auto import tqdm + +import requests +from filelock import FileLock + +from . 
import __version__ +from .hf_api import HfFolder + + +logger = logging.getLogger(__name__) + + +try: + import torch + + _torch_available = True +except ImportError: + _torch_available = False + +try: + import tensorflow as tf + + _tf_available = True +except ImportError: + _tf_available = False + + +def is_torch_available(): + return _torch_available + + +def is_tf_available(): + return _tf_available + + +# Constants for file downloads + +PYTORCH_WEIGHTS_NAME = "pytorch_model.bin" +TF2_WEIGHTS_NAME = "tf_model.h5" +TF_WEIGHTS_NAME = "model.ckpt" +FLAX_WEIGHTS_NAME = "flax_model.msgpack" +CONFIG_NAME = "config.json" + +HUGGINGFACE_CO_URL_TEMPLATE = ( + "https://huggingface.co/{model_id}/resolve/{revision}/{filename}" +) + + +# default cache +hf_cache_home = os.path.expanduser( + os.getenv( + "HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface") + ) +) +default_cache_path = os.path.join(hf_cache_home, "hub") + +HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path) + + +def hf_hub_url( + model_id: str, + filename: str, + subfolder: Optional[str] = None, + revision: Optional[str] = None, +) -> str: + """ + Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting + to Cloudfront (a Content Delivery Network, or CDN) for large files (more than a few MBs). + + Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our + bandwidth costs). + + Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here + because we implement a git-based versioning system on huggingface.co, which means that we store the files on S3/Cloudfront + in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache + can't ever be stale. + + In terms of client-side caching from this library, we base our caching on the objects' ETag. 
An object's ETag is: + its git-sha1 if stored in git, or its sha256 if stored in git-lfs. + """ + if subfolder is not None: + filename = f"{subfolder}/{filename}" + + if revision is None: + revision = "main" + return HUGGINGFACE_CO_URL_TEMPLATE.format( + model_id=model_id, revision=revision, filename=filename + ) + + +def url_to_filename(url: str, etag: Optional[str] = None) -> str: + """ + Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's, + delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can + identify it as a HDF5 file (see + https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) + """ + url_bytes = url.encode("utf-8") + filename = sha256(url_bytes).hexdigest() + + if etag: + etag_bytes = etag.encode("utf-8") + filename += "." + sha256(etag_bytes).hexdigest() + + if url.endswith(".h5"): + filename += ".h5" + + return filename + + +def filename_to_url(filename, cache_dir=None) -> Tuple[str, str]: + """ + Return the url and etag (which may be ``None``) stored for `filename`. Raise ``EnvironmentError`` if `filename` or + its stored metadata do not exist. 
+ """ + if cache_dir is None: + cache_dir = HUGGINGFACE_HUB_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + ".json" + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata["url"] + etag = metadata["etag"] + + return url, etag + + +def http_user_agent( + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, +) -> str: + """ + Formats a user-agent string with basic info about a request. + """ + if library_name is not None: + ua = "{}/{}".format(library_name, library_version) + else: + ua = "unknown/None" + ua += "; hf_hub/{}".format(__version__) + ua += "; python/{}".format(sys.version.split()[0]) + if is_torch_available(): + ua += "; torch/{}".format(torch.__version__) + if is_tf_available(): + ua += "; tensorflow/{}".format(tf.__version__) + if isinstance(user_agent, dict): + ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + return ua + + +def http_get( + url: str, + temp_file: BinaryIO, + proxies=None, + resume_size=0, + headers: Optional[Dict[str, str]] = None, +): + """ + Download remote file. Do not gobble up errors. 
+ """ + headers = copy.deepcopy(headers) + if resume_size > 0: + headers["Range"] = "bytes=%d-" % (resume_size,) + r = requests.get(url, stream=True, proxies=proxies, headers=headers) + r.raise_for_status() + content_length = r.headers.get("Content-Length") + total = resume_size + int(content_length) if content_length is not None else None + progress = tqdm( + unit="B", + unit_scale=True, + total=total, + initial=resume_size, + desc="Downloading", + disable=bool(logger.getEffectiveLevel() == logging.NOTSET), + ) + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def cached_download( + url: str, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + user_agent: Union[Dict, str, None] = None, + force_download=False, + proxies=None, + etag_timeout=10, + resume_download=False, + use_auth_token: Union[str, None] = None, + local_files_only=False, +) -> Optional[str]: # pragma: no cover + """ + Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the + path to the cached file. + + Return: + Local path (string) of file or if networking is off, last version of file cached on disk. + + Raises: + In case of non-recoverable file (non-existent or inaccessible url + no cache on disk). 
+ """ + if cache_dir is None: + cache_dir = HUGGINGFACE_HUB_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + os.makedirs(cache_dir, exist_ok=True) + + headers = { + "user-agent": http_user_agent( + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + ) + } + if isinstance(use_auth_token, str): + headers["authorization"] = "Bearer {}".format(use_auth_token) + elif use_auth_token: + token = HfFolder.get_token() + if token is None: + raise EnvironmentError( + "You specified use_auth_token=True, but a huggingface token was not found." + ) + headers["authorization"] = "Bearer {}".format(token) + + url_to_download = url + etag = None + if not local_files_only: + try: + r = requests.head( + url, + headers=headers, + allow_redirects=False, + proxies=proxies, + timeout=etag_timeout, + ) + r.raise_for_status() + etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag") + # We favor a custom header indicating the etag of the linked resource, and + # we fallback to the regular etag header. + # If we don't have any of those, raise an error. + if etag is None: + raise OSError( + "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." + ) + # In case of a redirect, + # save an extra redirect on the request.get call, + # and ensure we download the exact atomic version even if it changed + # between the HEAD and the GET (unlikely, but hey). + if 300 <= r.status_code <= 399: + url_to_download = r.headers["Location"] + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + # etag is already None + pass + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # etag is None == we don't have a connection or we passed local_files_only. 
+ # try to get the last downloaded one + if etag is None: + if os.path.exists(cache_path): + return cache_path + else: + matching_files = [ + file + for file in fnmatch.filter( + os.listdir(cache_dir), filename.split(".")[0] + ".*" + ) + if not file.endswith(".json") and not file.endswith(".lock") + ] + if len(matching_files) > 0: + return os.path.join(cache_dir, matching_files[-1]) + else: + # If files cannot be found and local_files_only=True, + # the models might've been found if local_files_only=False + # Notify the user about that + if local_files_only: + raise ValueError( + "Cannot find the requested files in the cached path and outgoing traffic has been" + " disabled. To enable model look-ups and downloads online, set 'local_files_only'" + " to False." + ) + else: + raise ValueError( + "Connection error, and we cannot find the requested files in the cached path." + " Please try again or make sure your Internet connection is on." + ) + + # From now on, etag is not None. + if os.path.exists(cache_path) and not force_download: + return cache_path + + # Prevent parallel downloads of the same file with a lock. + lock_path = cache_path + ".lock" + with FileLock(lock_path): + + # If the download just completed while the lock was activated. + if os.path.exists(cache_path) and not force_download: + # Even if returning early like here, the lock will be released. + return cache_path + + if resume_download: + incomplete_path = cache_path + ".incomplete" + + @contextmanager + def _resumable_file_manager() -> "io.BufferedWriter": + with open(incomplete_path, "ab") as f: + yield f + + temp_file_manager = _resumable_file_manager + if os.path.exists(incomplete_path): + resume_size = os.stat(incomplete_path).st_size + else: + resume_size = 0 + else: + temp_file_manager = partial( + tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False + ) + resume_size = 0 + + # Download to temporary file, then copy to cache dir once finished. 
+ # Otherwise you get corrupt cache entries if the download gets interrupted. + with temp_file_manager() as temp_file: + logger.info("downloading %s to %s", url, temp_file.name) + + http_get( + url_to_download, + temp_file, + proxies=proxies, + resume_size=resume_size, + headers=headers, + ) + + logger.info("storing %s in cache at %s", url, cache_path) + os.replace(temp_file.name, cache_path) + + logger.info("creating metadata file for %s", cache_path) + meta = {"url": url, "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + json.dump(meta, meta_file) + + return cache_path diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py new file mode 100644 index 0000000000..4ff269a2b3 --- /dev/null +++ b/src/huggingface_hub/hf_api.py @@ -0,0 +1,220 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from os.path import expanduser +from typing import Dict, List, Optional, Tuple + +import requests + + +ENDPOINT = "https://huggingface.co" + + +class RepoObj: + """ + HuggingFace git-based system, data structure that represents a file belonging to the current user. 
+ """ + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class ModelSibling: + """ + Data structure that represents a public file inside a model, accessible from huggingface.co + """ + + def __init__(self, rfilename: str, **kwargs): + self.rfilename = rfilename # filename relative to the model root + for k, v in kwargs.items(): + setattr(self, k, v) + + +class ModelInfo: + """ + Info about a public model accessible from huggingface.co + """ + + def __init__( + self, + modelId: Optional[str] = None, # id of model + tags: List[str] = [], + pipeline_tag: Optional[str] = None, + siblings: Optional[ + List[Dict] + ] = None, # list of files that constitute the model + **kwargs + ): + self.modelId = modelId + self.tags = tags + self.pipeline_tag = pipeline_tag + self.siblings = ( + [ModelSibling(**x) for x in siblings] if siblings is not None else None + ) + for k, v in kwargs.items(): + setattr(self, k, v) + + +class HfApi: + def __init__(self, endpoint=None): + self.endpoint = endpoint if endpoint is not None else ENDPOINT + + def login(self, username: str, password: str) -> str: + """ + Call HF API to sign in a user and get a token if credentials are valid. + + Outputs: token if credentials are valid + + Throws: requests.exceptions.HTTPError if credentials are invalid + """ + path = "{}/api/login".format(self.endpoint) + r = requests.post(path, json={"username": username, "password": password}) + r.raise_for_status() + d = r.json() + return d["token"] + + def whoami(self, token: str) -> Tuple[str, List[str]]: + """ + Call HF API to know "whoami" + """ + path = "{}/api/whoami".format(self.endpoint) + r = requests.get(path, headers={"authorization": "Bearer {}".format(token)}) + r.raise_for_status() + d = r.json() + return d["user"], d["orgs"] + + def logout(self, token: str) -> None: + """ + Call HF API to log out. 
+ """ + path = "{}/api/logout".format(self.endpoint) + r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}) + r.raise_for_status() + + def model_list(self) -> List[ModelInfo]: + """ + Get the public list of all the models on huggingface.co + """ + path = "{}/api/models".format(self.endpoint) + r = requests.get(path) + r.raise_for_status() + d = r.json() + return [ModelInfo(**x) for x in d] + + def list_repos_objs( + self, token: str, organization: Optional[str] = None + ) -> List[RepoObj]: + """ + HuggingFace git-based system, used for models. + + Call HF API to list all stored files for user (or one of their organizations). + """ + path = "{}/api/repos/ls".format(self.endpoint) + params = {"organization": organization} if organization is not None else None + r = requests.get( + path, params=params, headers={"authorization": "Bearer {}".format(token)} + ) + r.raise_for_status() + d = r.json() + return [RepoObj(**x) for x in d] + + def create_repo( + self, + token: str, + name: str, + organization: Optional[str] = None, + private: Optional[bool] = None, + exist_ok=False, + lfsmultipartthresh: Optional[int] = None, + ) -> str: + """ + HuggingFace git-based system, used for models. + + Call HF API to create a whole repo. + + Params: + private: Whether the model repo should be private (requires a paid huggingface.co account) + + exist_ok: Do not raise an error if repo already exists + + lfsmultipartthresh: Optional: internal param for testing purposes. 
+ """ + path = "{}/api/repos/create".format(self.endpoint) + json = {"name": name, "organization": organization, "private": private} + if lfsmultipartthresh is not None: + json["lfsmultipartthresh"] = lfsmultipartthresh + r = requests.post( + path, + headers={"authorization": "Bearer {}".format(token)}, + json=json, + ) + if exist_ok and r.status_code == 409: + return "" + r.raise_for_status() + d = r.json() + return d["url"] + + def delete_repo(self, token: str, name: str, organization: Optional[str] = None): + """ + HuggingFace git-based system, used for models. + + Call HF API to delete a whole repo. + + CAUTION(this is irreversible). + """ + path = "{}/api/repos/delete".format(self.endpoint) + r = requests.delete( + path, + headers={"authorization": "Bearer {}".format(token)}, + json={"name": name, "organization": organization}, + ) + r.raise_for_status() + + +class HfFolder: + path_token = expanduser("~/.huggingface/token") + + @classmethod + def save_token(cls, token): + """ + Save token, creating folder as needed. + """ + os.makedirs(os.path.dirname(cls.path_token), exist_ok=True) + with open(cls.path_token, "w+") as f: + f.write(token) + + @classmethod + def get_token(cls): + """ + Get token or None if not existent. + """ + try: + with open(cls.path_token, "r") as f: + return f.read() + except FileNotFoundError: + pass + + @classmethod + def delete_token(cls): + """ + Delete token. Do not fail if token does not exist. + """ + try: + os.remove(cls.path_token) + except FileNotFoundError: + pass diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fixtures/empty.txt b/tests/fixtures/empty.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_file_download.py b/tests/test_file_download.py new file mode 100644 index 0000000000..f489f55908 --- /dev/null +++ b/tests/test_file_download.py @@ -0,0 +1,88 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import requests +from huggingface_hub.file_download import ( + CONFIG_NAME, + PYTORCH_WEIGHTS_NAME, + cached_download, + filename_to_url, + hf_hub_url, +) + +from .testing_utils import DUMMY_UNKWOWN_IDENTIFIER + + +MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER +# An actual model hosted on huggingface.co + +REVISION_ID_DEFAULT = "main" +# Default branch name +REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2" +# One particular commit (not the top of `main`) +REVISION_ID_INVALID = "aaaaaaa" +# This commit does not exist, so we should 404. + +PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684" +# Sha-1 of config.json on the top of `main`, for checking purposes +PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3" +# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes + + +class CachedDownloadTests(unittest.TestCase): + def test_bogus_url(self): + # This lets us simulate no connection + # as the error raised is the same + # `ConnectionError` + url = "https://bogus" + with self.assertRaisesRegex(ValueError, "Connection error"): + _ = cached_download(url) + + def test_file_not_found(self): + # Valid revision (None) but missing file. 
+ url = hf_hub_url(MODEL_ID, filename="missing.bin") + with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"): + _ = cached_download(url) + + def test_revision_not_found(self): + # Valid file but missing revision + url = hf_hub_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID) + with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"): + _ = cached_download(url) + + def test_standard_object(self): + url = hf_hub_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT) + filepath = cached_download(url, force_download=True) + metadata = filename_to_url(filepath) + self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"')) + + def test_standard_object_rev(self): + # Same object, but different revision + url = hf_hub_url( + MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT + ) + filepath = cached_download(url, force_download=True) + metadata = filename_to_url(filepath) + self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"') + # Caution: check that the etag is *not* equal to the one from `test_standard_object` + + def test_lfs_object(self): + url = hf_hub_url( + MODEL_ID, filename=PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT + ) + filepath = cached_download(url, force_download=True) + metadata = filename_to_url(filepath) + self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"')) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py new file mode 100644 index 0000000000..912b00db96 --- /dev/null +++ b/tests/test_hf_api.py @@ -0,0 +1,223 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shutil +import subprocess +import time +import unittest + +from huggingface_hub.hf_api import HfApi, HfFolder, ModelInfo, RepoObj +from requests.exceptions import HTTPError + +from .testing_utils import require_git_lfs + + +USER = "__DUMMY_TRANSFORMERS_USER__" +PASS = "__DUMMY_TRANSFORMERS_PASS__" + +ENDPOINT_STAGING = "https://moon-staging.huggingface.co" +ENDPOINT_STAGING_BASIC_AUTH = f"https://{USER}:{PASS}@moon-staging.huggingface.co" + +REPO_NAME = "my-model-{}".format(int(time.time() * 10e3)) +REPO_NAME_LARGE_FILE = "my-model-largefiles-{}".format(int(time.time() * 10e3)) +WORKING_REPO_DIR = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "fixtures/working_repo" +) +LARGE_FILE_14MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.epub" +LARGE_FILE_18MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.pdf" + + +class HfApiCommonTest(unittest.TestCase): + _api = HfApi(endpoint=ENDPOINT_STAGING) + + +class HfApiLoginTest(HfApiCommonTest): + def test_login_invalid(self): + with self.assertRaises(HTTPError): + self._api.login(username=USER, password="fake") + + def test_login_valid(self): + token = self._api.login(username=USER, password=PASS) + self.assertIsInstance(token, str) + + +class HfApiEndpointsTest(HfApiCommonTest): + @classmethod + def setUpClass(cls): + """ + Share this valid token in all tests below. 
+ """ + cls._token = cls._api.login(username=USER, password=PASS) + + def test_whoami(self): + user, orgs = self._api.whoami(token=self._token) + self.assertEqual(user, USER) + self.assertIsInstance(orgs, list) + + def test_list_repos_objs(self): + objs = self._api.list_repos_objs(token=self._token) + self.assertIsInstance(objs, list) + if len(objs) > 0: + o = objs[-1] + self.assertIsInstance(o, RepoObj) + + def test_create_and_delete_repo(self): + self._api.create_repo(token=self._token, name=REPO_NAME) + self._api.delete_repo(token=self._token, name=REPO_NAME) + + +class HfApiPublicTest(unittest.TestCase): + def test_staging_model_list(self): + _api = HfApi(endpoint=ENDPOINT_STAGING) + _ = _api.model_list() + + def test_model_list(self): + _api = HfApi() + models = _api.model_list() + self.assertGreater(len(models), 100) + self.assertIsInstance(models[0], ModelInfo) + + +class HfFolderTest(unittest.TestCase): + def test_token_workflow(self): + """ + Test the whole token save/get/delete workflow, + with the desired behavior with respect to non-existent tokens. + """ + token = "token-{}".format(int(time.time())) + HfFolder.save_token(token) + self.assertEqual(HfFolder.get_token(), token) + HfFolder.delete_token() + HfFolder.delete_token() + # ^^ not an error, we test that the + # second call does not fail. + self.assertEqual(HfFolder.get_token(), None) + + +@require_git_lfs +class HfLargefilesTest(HfApiCommonTest): + @classmethod + def setUpClass(cls): + """ + Share this valid token in all tests below. 
+ """ + cls._token = cls._api.login(username=USER, password=PASS) + + def setUp(self): + try: + shutil.rmtree(WORKING_REPO_DIR) + except FileNotFoundError: + pass + + def tearDown(self): + self._api.delete_repo(token=self._token, name=REPO_NAME_LARGE_FILE) + + def setup_local_clone(self, REMOTE_URL): + REMOTE_URL_AUTH = REMOTE_URL.replace( + ENDPOINT_STAGING, ENDPOINT_STAGING_BASIC_AUTH + ) + subprocess.run( + ["git", "clone", REMOTE_URL_AUTH, WORKING_REPO_DIR], + check=True, + capture_output=True, + ) + subprocess.run( + ["git", "lfs", "track", "*.pdf"], check=True, cwd=WORKING_REPO_DIR + ) + subprocess.run( + ["git", "lfs", "track", "*.epub"], check=True, cwd=WORKING_REPO_DIR + ) + + def test_end_to_end_thresh_6M(self): + REMOTE_URL = self._api.create_repo( + token=self._token, name=REPO_NAME_LARGE_FILE, lfsmultipartthresh=6 * 10 ** 6 + ) + self.setup_local_clone(REMOTE_URL) + + subprocess.run( + ["wget", LARGE_FILE_18MB], + check=True, + capture_output=True, + cwd=WORKING_REPO_DIR, + ) + subprocess.run(["git", "add", "*"], check=True, cwd=WORKING_REPO_DIR) + subprocess.run( + ["git", "commit", "-m", "commit message"], check=True, cwd=WORKING_REPO_DIR + ) + + # This will fail as we haven't set up our custom transfer agent yet. + failed_process = subprocess.run( + ["git", "push"], capture_output=True, cwd=WORKING_REPO_DIR + ) + self.assertEqual(failed_process.returncode, 1) + self.assertIn("cli lfs-enable-largefiles", failed_process.stderr.decode()) + # ^ Instructions on how to fix this are included in the error message. 
+ + subprocess.run( + ["huggingface-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True + ) + + start_time = time.time() + subprocess.run(["git", "push"], check=True, cwd=WORKING_REPO_DIR) + print("took", time.time() - start_time) + + # To be 100% sure, let's download the resolved file + pdf_url = f"{REMOTE_URL}/resolve/main/progit.pdf" + DEST_FILENAME = "uploaded.pdf" + subprocess.run( + ["wget", pdf_url, "-O", DEST_FILENAME], + check=True, + capture_output=True, + cwd=WORKING_REPO_DIR, + ) + dest_filesize = os.stat(os.path.join(WORKING_REPO_DIR, DEST_FILENAME)).st_size + self.assertEqual(dest_filesize, 18685041) + + def test_end_to_end_thresh_16M(self): + # Here we'll push one multipart and one non-multipart file in the same commit, and see what happens + REMOTE_URL = self._api.create_repo( + token=self._token, + name=REPO_NAME_LARGE_FILE, + lfsmultipartthresh=16 * 10 ** 6, + ) + self.setup_local_clone(REMOTE_URL) + + subprocess.run( + ["wget", LARGE_FILE_18MB], + check=True, + capture_output=True, + cwd=WORKING_REPO_DIR, + ) + subprocess.run( + ["wget", LARGE_FILE_14MB], + check=True, + capture_output=True, + cwd=WORKING_REPO_DIR, + ) + subprocess.run(["git", "add", "*"], check=True, cwd=WORKING_REPO_DIR) + subprocess.run( + ["git", "commit", "-m", "both files in same commit"], + check=True, + cwd=WORKING_REPO_DIR, + ) + + subprocess.run( + ["huggingface-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True + ) + + start_time = time.time() + subprocess.run(["git", "push"], check=True, cwd=WORKING_REPO_DIR) + print("took", time.time() - start_time) diff --git a/tests/testing_utils.py b/tests/testing_utils.py new file mode 100644 index 0000000000..fc765815cb --- /dev/null +++ b/tests/testing_utils.py @@ -0,0 +1,54 @@ +import os +import unittest +from distutils.util import strtobool + + +SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" +DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown" +DUMMY_DIFF_TOKENIZER_IDENTIFIER = 
"julien-c/dummy-diff-tokenizer"
# Example model ids


# Truthy / falsy spellings accepted by the env-flag parser. These mirror
# `distutils.util.strtobool` exactly (deprecated by PEP 632, removed in
# Python 3.12), so behavior is unchanged while dropping the distutils call.
_TRUTHY = {"y", "yes", "t", "true", "on", "1"}
_FALSY = {"n", "no", "f", "false", "off", "0"}


def _env_strtobool(value):
    """Convert a yes/no-style string to 1 or 0 (distutils.strtobool contract).

    Raises:
        ValueError: if `value` is not one of the recognized spellings.
    """
    value = value.lower()
    if value in _TRUTHY:
        return 1
    if value in _FALSY:
        return 0
    raise ValueError("invalid truth value {!r}".format(value))


def parse_flag_from_env(key, default=False):
    """Read the boolean environment variable `key`.

    Returns `default` when the variable is unset; otherwise 1/0 parsed from
    the usual yes/no spellings (y, yes, t, true, on, 1 / n, no, f, false,
    off, 0 — case-insensitive).

    Raises:
        ValueError: if the variable is set to an unrecognized value.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        _value = default
    else:
        # KEY is set, convert it to True or False.
        try:
            _value = _env_strtobool(value)
        except ValueError:
            # More values are supported, but let's keep the message simple.
            raise ValueError("If set, {} must be yes or no.".format(key))
    return _value


def parse_int_from_env(key, default=None):
    """Read the integer environment variable `key`, or `default` when unset.

    Raises:
        ValueError: if the variable is set but is not a valid integer.
    """
    try:
        value = os.environ[key]
    except KeyError:
        _value = default
    else:
        try:
            _value = int(value)
        except ValueError:
            # Message grammar fixed ("a int" -> "an int").
            raise ValueError("If set, {} must be an int.".format(key))
    return _value


# Evaluated once at import time; controls whether git-lfs tests run.
_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)


def require_git_lfs(test_case):
    """
    Decorator marking a test that requires git-lfs.

    git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment
    variable to a truthy value to run them.
    """
    if not _run_git_lfs_tests:
        return unittest.skip("test of git lfs workflow")(test_case)
    else:
        return test_case