diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index a4a25fedba79..3325b72111c3 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -13,11 +13,10 @@ # limitations under the License. import os +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home -hf_cache_home = os.path.expanduser( - os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) -) -default_cache_path = os.path.join(hf_cache_home, "diffusers") + +default_cache_path = HUGGINGFACE_HUB_CACHE CONFIG_NAME = "config.json" diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 7e6bd7870de7..ff715595a3b0 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -16,6 +16,7 @@ import os import sys +import traceback from pathlib import Path from typing import Dict, Optional, Union from uuid import uuid4 @@ -24,7 +25,7 @@ from huggingface_hub.utils import is_jinja_available from .. import __version__ -from .constants import HUGGINGFACE_CO_RESOLVE_ENDPOINT +from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT from .import_utils import ( ENV_VARS_TRUE_VALUES, _flax_version, @@ -129,3 +130,72 @@ def create_model_card(args, model_name): card_path = os.path.join(args.output_dir, "README.md") model_card.save(card_path) + + +# Old default cache path, potentially to be migrated. +# This logic was more or less taken from `transformers`, with the following differences: +# - Diffusers doesn't use custom environment variables to specify the cache path. +# - There is no need to migrate the cache format, just move the files to the new location. +hf_cache_home = os.path.expanduser( + os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) +) +old_diffusers_cache = os.path.join(hf_cache_home, "diffusers") + + +def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None: + if new_cache_dir is None: + new_cache_dir = DIFFUSERS_CACHE + if old_cache_dir is None: + old_cache_dir = old_diffusers_cache + + old_cache_dir = Path(old_cache_dir).expanduser() + new_cache_dir = Path(new_cache_dir).expanduser() + for old_blob_path in old_cache_dir.glob("**/blobs/*"): # move file blob by blob + if old_blob_path.is_file() and not old_blob_path.is_symlink(): + new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir) + new_blob_path.parent.mkdir(parents=True, exist_ok=True) + os.replace(old_blob_path, new_blob_path) + try: + os.symlink(new_blob_path, old_blob_path) + except OSError: + logger.warning( + "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded." + ) + # At this point, old_cache_dir contains symlinks to the new cache (it can still be used). + + +cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt") +if not os.path.isfile(cache_version_file): + cache_version = 0 +else: + with open(cache_version_file) as f: + cache_version = int(f.read()) + +if cache_version < 1: + old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0 + if old_cache_is_not_empty: + logger.warning( + "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your " + "existing cached models. This is a one-time operation, you can interrupt it or run it " + "later by calling `diffusers.utils.hub_utils.move_cache()`." + ) + try: + move_cache() + except Exception as e: + trace = "\n".join(traceback.format_tb(e.__traceback__)) + logger.error( + f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease " + "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole " + "message and we will do our best to help." + ) + +if cache_version < 1: + try: + os.makedirs(DIFFUSERS_CACHE, exist_ok=True) + with open(cache_version_file, "w") as f: + f.write("1") + except Exception: + logger.warning( + f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure " + "the directory exists and can be written to." + )