|
13 | 13 | # limitations under the License.
|
14 | 14 | import os
|
15 | 15 | from collections import defaultdict
|
16 |
| -from typing import Callable, Dict, Union |
| 16 | +from typing import Callable, Dict, List, Optional, Union |
17 | 17 |
|
18 | 18 | import torch
|
19 | 19 |
|
20 | 20 | from .models.attention_processor import LoRAAttnProcessor
|
21 |
| -from .models.modeling_utils import _get_model_file |
22 |
| -from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging |
| 21 | +from .utils import ( |
| 22 | + DIFFUSERS_CACHE, |
| 23 | + HF_HUB_OFFLINE, |
| 24 | + _get_model_file, |
| 25 | + deprecate, |
| 26 | + is_safetensors_available, |
| 27 | + is_transformers_available, |
| 28 | + logging, |
| 29 | +) |
23 | 30 |
|
24 | 31 |
|
25 | 32 | if is_safetensors_available():
|
26 | 33 | import safetensors
|
27 | 34 |
|
| 35 | +if is_transformers_available(): |
| 36 | + from transformers import PreTrainedModel, PreTrainedTokenizer |
| 37 | + |
28 | 38 |
|
29 | 39 | logger = logging.get_logger(__name__)
|
30 | 40 |
|
31 | 41 |
|
32 | 42 | LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
33 | 43 | LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
34 | 44 |
|
| 45 | +TEXT_INVERSION_NAME = "learned_embeds.bin" |
| 46 | +TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" |
| 47 | + |
35 | 48 |
|
36 | 49 | class AttnProcsLayers(torch.nn.Module):
|
37 | 50 | def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
@@ -123,13 +136,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
|
123 | 136 | It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
124 | 137 | models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
125 | 138 |
|
126 |
| - </Tip> |
127 |
| -
|
128 |
| - <Tip> |
129 |
| -
|
130 |
| - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use |
131 |
| - this method in a firewalled environment. |
132 |
| -
|
133 | 139 | </Tip>
|
134 | 140 | """
|
135 | 141 |
|
@@ -292,5 +298,272 @@ def save_function(weights, filename):
|
292 | 298 |
|
293 | 299 | # Save the model
|
294 | 300 | save_function(state_dict, os.path.join(save_directory, weight_name))
|
295 |
| - |
296 | 301 | logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
| 302 | + |
| 303 | + |
| 304 | +class TextualInversionLoaderMixin: |
| 305 | + r""" |
| 306 | + Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder. |
| 307 | + """ |
| 308 | + |
| 309 | + def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: PreTrainedTokenizer): |
| 310 | + r""" |
| 311 | + Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds |
| 312 | + to a multi-vector textual inversion embedding, this function will process the prompt so that the special token |
| 313 | + is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual |
| 314 | + inversion token or a textual inversion token that is a single vector, the input prompt is simply returned. |
| 315 | +
|
| 316 | + Parameters: |
| 317 | + prompt (`str` or list of `str`): |
| 318 | + The prompt or prompts to guide the image generation. |
| 319 | + tokenizer (`PreTrainedTokenizer`): |
| 320 | + The tokenizer responsible for encoding the prompt into input tokens. |
| 321 | +
|
| 322 | + Returns: |
| 323 | + `str` or list of `str`: The converted prompt |
| 324 | + """ |
| 325 | + if not isinstance(prompt, List): |
| 326 | + prompts = [prompt] |
| 327 | + else: |
| 328 | + prompts = prompt |
| 329 | + |
| 330 | + prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts] |
| 331 | + |
| 332 | + if not isinstance(prompt, List): |
| 333 | + return prompts[0] |
| 334 | + |
| 335 | + return prompts |
| 336 | + |
| 337 | + def _maybe_convert_prompt(self, prompt: str, tokenizer: PreTrainedTokenizer): |
| 338 | + r""" |
| 339 | + Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds |
| 340 | + to a multi-vector textual inversion embedding, this function will process the prompt so that the special token |
| 341 | + is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual |
| 342 | + inversion token or a textual inversion token that is a single vector, the input prompt is simply returned. |
| 343 | +
|
| 344 | + Parameters: |
| 345 | + prompt (`str`): |
| 346 | + The prompt to guide the image generation. |
| 347 | + tokenizer (`PreTrainedTokenizer`): |
| 348 | + The tokenizer responsible for encoding the prompt into input tokens. |
| 349 | +
|
| 350 | + Returns: |
| 351 | + `str`: The converted prompt |
| 352 | + """ |
| 353 | + tokens = tokenizer.tokenize(prompt) |
| 354 | + for token in tokens: |
| 355 | + if token in tokenizer.added_tokens_encoder: |
| 356 | + replacement = token |
| 357 | + i = 1 |
| 358 | + while f"{token}_{i}" in tokenizer.added_tokens_encoder: |
| 359 | + replacement += f"{token}_{i}" |
| 360 | + i += 1 |
| 361 | + |
| 362 | + prompt = prompt.replace(token, replacement) |
| 363 | + |
| 364 | + return prompt |
| 365 | + |
| 366 | + def load_textual_inversion( |
| 367 | + self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs |
| 368 | + ): |
| 369 | + r""" |
| 370 | + Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and |
| 371 | + `Automatic1111` formats are supported. |
| 372 | +
|
| 373 | + <Tip warning={true}> |
| 374 | +
|
| 375 | + This function is experimental and might change in the future. |
| 376 | +
|
| 377 | + </Tip> |
| 378 | +
|
| 379 | + Parameters: |
| 380 | + pretrained_model_name_or_path (`str` or `os.PathLike`): |
| 381 | + Can be either: |
| 382 | +
|
| 383 | + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. |
| 384 | + Valid model ids should have an organization name, like |
| 385 | + `"sd-concepts-library/low-poly-hd-logos-icons"`. |
| 386 | + - A path to a *directory* containing textual inversion weights, e.g. |
| 387 | + `./my_text_inversion_directory/`. |
| 388 | + weight_name (`str`, *optional*): |
| 389 | + Name of a custom weight file. This should be used in two cases: |
| 390 | +
|
| 391 | + - The saved textual inversion file is in `diffusers` format, but was saved under a specific weight |
| 392 | + name, such as `text_inv.bin`. |
| 393 | + - The saved textual inversion file is in the "Automatic1111" form. |
| 394 | + cache_dir (`Union[str, os.PathLike]`, *optional*): |
| 395 | + Path to a directory in which a downloaded pretrained model configuration should be cached if the |
| 396 | + standard cache should not be used. |
| 397 | + force_download (`bool`, *optional*, defaults to `False`): |
| 398 | + Whether or not to force the (re-)download of the model weights and configuration files, overriding the |
| 399 | + cached versions if they exist. |
| 400 | + resume_download (`bool`, *optional*, defaults to `False`): |
| 401 | + Whether or not to delete incompletely received files. Will attempt to resume the download if such a |
| 402 | + file exists. |
| 403 | + proxies (`Dict[str, str]`, *optional*): |
| 404 | + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', |
| 405 | + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. |
| 406 | + local_files_only(`bool`, *optional*, defaults to `False`): |
| 407 | + Whether or not to only look at local files (i.e., do not try to download the model). |
| 408 | + use_auth_token (`str` or *bool*, *optional*): |
| 409 | + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated |
| 410 | + when running `diffusers-cli login` (stored in `~/.huggingface`). |
| 411 | + revision (`str`, *optional*, defaults to `"main"`): |
| 412 | + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a |
| 413 | + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any |
| 414 | + identifier allowed by git. |
| 415 | + subfolder (`str`, *optional*, defaults to `""`): |
| 416 | + In case the relevant files are located inside a subfolder of the model repo (either remote in |
| 417 | + huggingface.co or downloaded locally), you can specify the folder name here. |
| 418 | +
|
| 419 | + mirror (`str`, *optional*): |
| 420 | + Mirror source to accelerate downloads in China. If you are from China and have an accessibility |
| 421 | + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. |
| 422 | + Please refer to the mirror site for more information. |
| 423 | +
|
| 424 | + <Tip> |
| 425 | +
|
| 426 | + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated |
| 427 | + models](https://huggingface.co/docs/hub/models-gated#gated-models). |
| 428 | +
|
| 429 | + </Tip> |
| 430 | + """ |
| 431 | + if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): |
| 432 | + raise ValueError( |
| 433 | + f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" |
| 434 | + f" `{self.load_textual_inversion.__name__}`" |
| 435 | + ) |
| 436 | + |
| 437 | + if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): |
| 438 | + raise ValueError( |
| 439 | + f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" |
| 440 | + f" `{self.load_textual_inversion.__name__}`" |
| 441 | + ) |
| 442 | + |
| 443 | + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) |
| 444 | + force_download = kwargs.pop("force_download", False) |
| 445 | + resume_download = kwargs.pop("resume_download", False) |
| 446 | + proxies = kwargs.pop("proxies", None) |
| 447 | + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) |
| 448 | + use_auth_token = kwargs.pop("use_auth_token", None) |
| 449 | + revision = kwargs.pop("revision", None) |
| 450 | + subfolder = kwargs.pop("subfolder", None) |
| 451 | + weight_name = kwargs.pop("weight_name", None) |
| 452 | + use_safetensors = kwargs.pop("use_safetensors", None) |
| 453 | + |
| 454 | + if use_safetensors and not is_safetensors_available(): |
| 455 | + raise ValueError( |
| 456 | + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" |
| 457 | + ) |
| 458 | + |
| 459 | + allow_pickle = False |
| 460 | + if use_safetensors is None: |
| 461 | + use_safetensors = is_safetensors_available() |
| 462 | + allow_pickle = True |
| 463 | + |
| 464 | + user_agent = { |
| 465 | + "file_type": "text_inversion", |
| 466 | + "framework": "pytorch", |
| 467 | + } |
| 468 | + |
| 469 | + # 1. Load textual inversion file |
| 470 | + model_file = None |
| 471 | + # Let's first try to load .safetensors weights |
| 472 | + if (use_safetensors and weight_name is None) or ( |
| 473 | + weight_name is not None and weight_name.endswith(".safetensors") |
| 474 | + ): |
| 475 | + try: |
| 476 | + model_file = _get_model_file( |
| 477 | + pretrained_model_name_or_path, |
| 478 | + weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, |
| 479 | + cache_dir=cache_dir, |
| 480 | + force_download=force_download, |
| 481 | + resume_download=resume_download, |
| 482 | + proxies=proxies, |
| 483 | + local_files_only=local_files_only, |
| 484 | + use_auth_token=use_auth_token, |
| 485 | + revision=revision, |
| 486 | + subfolder=subfolder, |
| 487 | + user_agent=user_agent, |
| 488 | + ) |
| 489 | + state_dict = safetensors.torch.load_file(model_file, device="cpu") |
| 490 | + except Exception as e: |
| 491 | + if not allow_pickle: |
| 492 | + raise e |
| 493 | + |
| 494 | + model_file = None |
| 495 | + |
| 496 | + if model_file is None: |
| 497 | + model_file = _get_model_file( |
| 498 | + pretrained_model_name_or_path, |
| 499 | + weights_name=weight_name or TEXT_INVERSION_NAME, |
| 500 | + cache_dir=cache_dir, |
| 501 | + force_download=force_download, |
| 502 | + resume_download=resume_download, |
| 503 | + proxies=proxies, |
| 504 | + local_files_only=local_files_only, |
| 505 | + use_auth_token=use_auth_token, |
| 506 | + revision=revision, |
| 507 | + subfolder=subfolder, |
| 508 | + user_agent=user_agent, |
| 509 | + ) |
| 510 | + state_dict = torch.load(model_file, map_location="cpu") |
| 511 | + |
| 512 | + # 2. Load token and embedding correcly from file |
| 513 | + if isinstance(state_dict, torch.Tensor): |
| 514 | + if token is None: |
| 515 | + raise ValueError( |
| 516 | + "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`." |
| 517 | + ) |
| 518 | + embedding = state_dict |
| 519 | + elif len(state_dict) == 1: |
| 520 | + # diffusers |
| 521 | + loaded_token, embedding = next(iter(state_dict.items())) |
| 522 | + elif "string_to_param" in state_dict: |
| 523 | + # A1111 |
| 524 | + loaded_token = state_dict["name"] |
| 525 | + embedding = state_dict["string_to_param"]["*"] |
| 526 | + |
| 527 | + if token is not None and loaded_token != token: |
| 528 | + logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") |
| 529 | + else: |
| 530 | + token = loaded_token |
| 531 | + |
| 532 | + embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device) |
| 533 | + |
| 534 | + # 3. Make sure we don't mess up the tokenizer or text encoder |
| 535 | + vocab = self.tokenizer.get_vocab() |
| 536 | + if token in vocab: |
| 537 | + raise ValueError( |
| 538 | + f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." |
| 539 | + ) |
| 540 | + elif f"{token}_1" in vocab: |
| 541 | + multi_vector_tokens = [token] |
| 542 | + i = 1 |
| 543 | + while f"{token}_{i}" in self.tokenizer.added_tokens_encoder: |
| 544 | + multi_vector_tokens.append(f"{token}_{i}") |
| 545 | + i += 1 |
| 546 | + |
| 547 | + raise ValueError( |
| 548 | + f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." |
| 549 | + ) |
| 550 | + |
| 551 | + is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1 |
| 552 | + |
| 553 | + if is_multi_vector: |
| 554 | + tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])] |
| 555 | + embeddings = [e for e in embedding] # noqa: C416 |
| 556 | + else: |
| 557 | + tokens = [token] |
| 558 | + embeddings = [embedding] if len(embedding.shape) > 1 else [embedding[0]] |
| 559 | + |
| 560 | + # add tokens and get ids |
| 561 | + self.tokenizer.add_tokens(tokens) |
| 562 | + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) |
| 563 | + |
| 564 | + # resize token embeddings and set new embeddings |
| 565 | + self.text_encoder.resize_token_embeddings(len(self.tokenizer)) |
| 566 | + for token_id, embedding in zip(token_ids, embeddings): |
| 567 | + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding |
| 568 | + |
| 569 | + logger.info("Loaded textual inversion embedding for {token}.") |
0 commit comments