Skip to content

Use "hub" directory for cache instead of "diffusers" #2005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 28, 2023
7 changes: 3 additions & 4 deletions src/diffusers/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# limitations under the License.
import os

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home

hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "diffusers")

default_cache_path = HUGGINGFACE_HUB_CACHE


CONFIG_NAME = "config.json"
Expand Down
72 changes: 71 additions & 1 deletion src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import os
import sys
import traceback
from pathlib import Path
from typing import Dict, Optional, Union
from uuid import uuid4
Expand All @@ -24,7 +25,7 @@
from huggingface_hub.utils import is_jinja_available

from .. import __version__
from .constants import HUGGINGFACE_CO_RESOLVE_ENDPOINT
from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
from .import_utils import (
ENV_VARS_TRUE_VALUES,
_flax_version,
Expand Down Expand Up @@ -129,3 +130,72 @@ def create_model_card(args, model_name):

card_path = os.path.join(args.output_dir, "README.md")
model_card.save(card_path)


# Old default cache path, potentially to be migrated.
# This logic was more or less taken from `transformers`, with the following differences:
# - Diffusers doesn't use custom environment variables to specify the cache path.
# - There is no need to migrate the cache format, just move the files to the new location.
hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")


def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
if new_cache_dir is None:
new_cache_dir = DIFFUSERS_CACHE
if old_cache_dir is None:
old_cache_dir = old_diffusers_cache

old_cache_dir = Path(old_cache_dir).expanduser()
new_cache_dir = Path(new_cache_dir).expanduser()
for old_blob_path in old_cache_dir.glob("**/blobs/*"): # move file blob by blob
if old_blob_path.is_file() and not old_blob_path.is_symlink():
new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
new_blob_path.parent.mkdir(parents=True, exist_ok=True)
os.replace(old_blob_path, new_blob_path)
try:
os.symlink(new_blob_path, old_blob_path)
except OSError:
logger.warning(
"Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
)
# At this point, old_cache_dir contains symlinks to the new cache (it can still be used).


cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt")
if not os.path.isfile(cache_version_file):
cache_version = 0
else:
with open(cache_version_file) as f:
cache_version = int(f.read())

if cache_version < 1:
old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0
if old_cache_is_not_empty:
logger.warning(
"The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your "
"existing cached models. This is a one-time operation, you can interrupt it or run it "
"later by calling `diffusers.utils.hub_utils.move_cache()`."
Comment on lines +178 to +180
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a need to also spit out the actual location in this message for ultra transparency?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, perhaps we can include that too.

)
try:
move_cache()
except Exception as e:
trace = "\n".join(traceback.format_tb(e.__traceback__))
logger.error(
f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
"file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole "
"message and we will do our best to help."
)

if cache_version < 1:
try:
os.makedirs(DIFFUSERS_CACHE, exist_ok=True)
with open(cache_version_file, "w") as f:
f.write("1")
except Exception:
logger.warning(
f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
"the directory exists and can be written to."
)