add load textual inversion embeddings to stable diffusion #2009
```diff
@@ -13,25 +13,38 @@
 # limitations under the License.
 import os
 from collections import defaultdict
-from typing import Callable, Dict, Union
+from typing import Callable, Dict, List, Optional, Union

 import torch

 from .models.attention_processor import LoRAAttnProcessor
-from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
+from .utils import (
+    DIFFUSERS_CACHE,
+    HF_HUB_OFFLINE,
+    _get_model_file,
+    deprecate,
+    is_safetensors_available,
+    is_transformers_available,
+    logging,
+)


 if is_safetensors_available():
     import safetensors

+if is_transformers_available():
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+

 logger = logging.get_logger(__name__)


 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

+TEXT_INVERSION_NAME = "learned_embeds.bin"
+TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
+

 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
```
```diff
@@ -123,13 +136,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
         models](https://huggingface.co/docs/hub/models-gated#gated-models).

         </Tip>
-
-        <Tip>
-
-        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
-        this method in a firewalled environment.
-
-        </Tip>
         """
```
```diff
@@ -292,5 +298,272 @@ def save_function(weights, filename):

         # Save the model
         save_function(state_dict, os.path.join(save_directory, weight_name))

         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
```

The remainder of the hunk adds the new `TextualInversionLoaderMixin`:
```python
class TextualInversionLoaderMixin:
    r"""
    Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder.
    """

    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: PreTrainedTokenizer):
        r"""
        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that
        corresponds to a multi-vector textual inversion embedding, this function will process the prompt so that the
        special token is replaced with multiple special tokens, each corresponding to one of the vectors. If the
        prompt has no textual inversion token, or a textual inversion token that is a single vector, the input prompt
        is simply returned.

        Parameters:
            prompt (`str` or list of `str`):
                The prompt or prompts to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str` or list of `str`: The converted prompt
        """
        if not isinstance(prompt, List):
            prompts = [prompt]
        else:
            prompts = prompt

        prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]

        if not isinstance(prompt, List):
            return prompts[0]

        return prompts

    def _maybe_convert_prompt(self, prompt: str, tokenizer: PreTrainedTokenizer):
        r"""
        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that
        corresponds to a multi-vector textual inversion embedding, this function will process the prompt so that the
        special token is replaced with multiple special tokens, each corresponding to one of the vectors. If the
        prompt has no textual inversion token, or a textual inversion token that is a single vector, the input prompt
        is simply returned.

        Parameters:
            prompt (`str`):
                The prompt to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str`: The converted prompt
        """
        tokens = tokenizer.tokenize(prompt)
        # deduplicate so a token that appears twice in the prompt is not expanded twice
        for token in set(tokens):
            if token in tokenizer.added_tokens_encoder:
                replacement = token
                i = 1
                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                    replacement += f" {token}_{i}"  # leading space keeps the added tokens separable
                    i += 1

                prompt = prompt.replace(token, replacement)

        return prompt
```
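To illustrate, a sketch of the conversion for a hypothetical 3-vector concept (assuming `pipe` is a pipeline that uses this mixin and that `<art-style>`, `<art-style>_1`, `<art-style>_2` were added to its tokenizer by `load_textual_inversion`; the token names are placeholders):

```python
prompt = "a house in the style of <art-style>"
prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer)
print(prompt)
# "a house in the style of <art-style> <art-style>_1 <art-style>_2"
```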
```python
    def load_textual_inversion(
        self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
    ):
        r"""
        Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
        `Automatic1111` formats are supported.

        <Tip warning={true}>

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids should have an organization name, like
                      `"sd-concepts-library/low-poly-hd-logos-icons"`.
                    - A path to a *directory* containing textual inversion weights, e.g.
                      `./my_text_inversion_directory/`.
            weight_name (`str`, *optional*):
                Name of a custom weight file. This should be used in two cases:

                    - The saved textual inversion file is in `diffusers` format, but was saved under a specific
                      weight name, such as `text_inv.bin`.
                    - The saved textual inversion file is in the `Automatic1111` format.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding
                the cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume an interrupted download. If set to `False`, incompletely received files are
                deleted and the download starts over.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
                generated when running `diffusers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
                a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo (either remote on
                huggingface.co or downloaded locally), you can specify the folder name here.
            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or
                safety of the mirror; please refer to the mirror site for more information.

        <Tip>

        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>
        """
```
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): | ||||||||||
raise ValueError( | ||||||||||
f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" | ||||||||||
f" `{self.load_textual_inversion.__name__}`" | ||||||||||
) | ||||||||||
|
||||||||||
if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): | ||||||||||
raise ValueError( | ||||||||||
f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" | ||||||||||
f" `{self.load_textual_inversion.__name__}`" | ||||||||||
) | ||||||||||
|
||||||||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) | ||||||||||
force_download = kwargs.pop("force_download", False) | ||||||||||
resume_download = kwargs.pop("resume_download", False) | ||||||||||
proxies = kwargs.pop("proxies", None) | ||||||||||
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) | ||||||||||
use_auth_token = kwargs.pop("use_auth_token", None) | ||||||||||
revision = kwargs.pop("revision", None) | ||||||||||
subfolder = kwargs.pop("subfolder", None) | ||||||||||
weight_name = kwargs.pop("weight_name", None) | ||||||||||
use_safetensors = kwargs.pop("use_safetensors", None) | ||||||||||
|
||||||||||
if use_safetensors and not is_safetensors_available(): | ||||||||||
raise ValueError( | ||||||||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" | ||||||||||
) | ||||||||||
|
||||||||||
allow_pickle = False | ||||||||||
if use_safetensors is None: | ||||||||||
use_safetensors = is_safetensors_available() | ||||||||||
allow_pickle = True | ||||||||||
|
||||||||||
user_agent = { | ||||||||||
"file_type": "text_inversion", | ||||||||||
"framework": "pytorch", | ||||||||||
} | ||||||||||
|
||||||||||
# 1. Load textual inversion file | ||||||||||
model_file = None | ||||||||||
# Let's first try to load .safetensors weights | ||||||||||
if (use_safetensors and weight_name is None) or ( | ||||||||||
weight_name is not None and weight_name.endswith(".safetensors") | ||||||||||
): | ||||||||||
try: | ||||||||||
model_file = _get_model_file( | ||||||||||
pretrained_model_name_or_path, | ||||||||||
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, | ||||||||||
cache_dir=cache_dir, | ||||||||||
force_download=force_download, | ||||||||||
resume_download=resume_download, | ||||||||||
proxies=proxies, | ||||||||||
local_files_only=local_files_only, | ||||||||||
use_auth_token=use_auth_token, | ||||||||||
revision=revision, | ||||||||||
subfolder=subfolder, | ||||||||||
user_agent=user_agent, | ||||||||||
) | ||||||||||
state_dict = safetensors.torch.load_file(model_file, device="cpu") | ||||||||||
except Exception as e: | ||||||||||
if not allow_pickle: | ||||||||||
raise e | ||||||||||
|
||||||||||
model_file = None | ||||||||||
|
||||||||||
if model_file is None: | ||||||||||
model_file = _get_model_file( | ||||||||||
pretrained_model_name_or_path, | ||||||||||
weights_name=weight_name or TEXT_INVERSION_NAME, | ||||||||||
cache_dir=cache_dir, | ||||||||||
force_download=force_download, | ||||||||||
resume_download=resume_download, | ||||||||||
proxies=proxies, | ||||||||||
local_files_only=local_files_only, | ||||||||||
use_auth_token=use_auth_token, | ||||||||||
revision=revision, | ||||||||||
subfolder=subfolder, | ||||||||||
user_agent=user_agent, | ||||||||||
) | ||||||||||
state_dict = torch.load(model_file, map_location="cpu") | ||||||||||
|
||||||||||
# 2. Load token and embedding correcly from file | ||||||||||
if isinstance(state_dict, torch.Tensor): | ||||||||||
if token is None: | ||||||||||
raise ValueError( | ||||||||||
"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`." | ||||||||||
) | ||||||||||
embedding = state_dict | ||||||||||
elif len(state_dict) == 1: | ||||||||||
# diffusers | ||||||||||
loaded_token, embedding = next(iter(state_dict.items())) | ||||||||||
elif "string_to_param" in state_dict: | ||||||||||
# A1111 | ||||||||||
loaded_token = state_dict["name"] | ||||||||||
embedding = state_dict["string_to_param"]["*"] | ||||||||||
|
||||||||||
if token is not None and loaded_token != token: | ||||||||||
logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") | ||||||||||
else: | ||||||||||
token = loaded_token | ||||||||||
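For reference, a small sketch of the two dictionary layouts distinguished above (toy tensors and a hypothetical token name; real Automatic1111 files carry extra keys, but only the ones shown are read here):

```python
import torch

dim = 768  # hidden size of the text encoder, e.g. CLIP ViT-L/14

# `diffusers` format: a single {token: embedding} entry
torch.save({"<my-token>": torch.randn(dim)}, "learned_embeds.bin")

# Automatic1111 format: the vectors live under string_to_param["*"]
torch.save(
    {"name": "<my-token>", "string_to_param": {"*": torch.randn(2, dim)}},
    "my_token.pt",
)
```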
```python
        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)

        # 3. Make sure we don't mess up the tokenizer or text encoder
        vocab = self.tokenizer.get_vocab()
        if token in vocab:
            raise ValueError(
                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
            )
        elif f"{token}_1" in vocab:
            multi_vector_tokens = [token]
            i = 1
            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
                multi_vector_tokens.append(f"{token}_{i}")
                i += 1

            raise ValueError(
                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
            )

        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1

        if is_multi_vector:
            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
            embeddings = [e for e in embedding]  # noqa: C416
        else:
            tokens = [token]
            embeddings = [embedding] if len(embedding.shape) > 1 else [embedding[0]]
```
Review comment: I was trying the latest version and I wasn't getting anything related to the embeddings; after changing this line I was able to get good results. I'm not sure how this would work with ...

Suggested change

Reply: Hmm, I cannot reproduce this one; my tests are passing just fine on this branch.

Reply: I followed:

```python
import torch
from diffusers import DiffusionPipeline

pretrained_path = 'xxx'
embedding_path = 'learned_embeds.bin'
pipe = DiffusionPipeline.from_pretrained(pretrained_path, torch_dtype=torch.float16)
pipe.load_textual_inversion(embedding_path)
```

It is not working to show any concept from my embedding token. I had to make the same modification as @GuiyeC to the file.

Reply: I created a PR with the fix, where I try to explain the problem a bit more.
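Judging from the thread, the fix in question appears to swap the two branches of the single-vector case, so that a `[1, dim]` embedding is unwrapped to its row and a 1-D embedding is kept whole (an assumption based on the discussion, not a verbatim quote of the suggestion):

```python
# assumed fix (sketch): unwrap a [1, dim] tensor to its row; keep a 1-D [dim] tensor whole
embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
```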
```python
        # add tokens and get ids
        self.tokenizer.add_tokens(tokens)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # resize token embeddings and set new embeddings
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
        for token_id, embedding in zip(token_ids, embeddings):
            self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

        logger.info(f"Loaded textual inversion embedding for {token}.")
```
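One quick sanity check after loading (a sketch; `<my-token>` stands in for whatever token was registered):

```python
# The token should resolve to a newly added id, and its row in the input
# embedding matrix should now hold the loaded vector.
token_id = pipe.tokenizer.convert_tokens_to_ids("<my-token>")
vector = pipe.text_encoder.get_input_embeddings().weight.data[token_id]
print(token_id, vector.shape)  # e.g. 49408, torch.Size([768]) for CLIP ViT-L/14
```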
Review comment: Wouldn't we want to do the opposite override? (What comes in the state_dict is what gets added.)

Reply: Interesting, I'd say what gets passed has priority! If you pass a token explicitly, I think the token should be `"<special-token>"` no matter what's in the dict; it's similar to how `from_pretrained(unet=unet)` overrides work.
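A short sketch of the override semantics argued for here (hypothetical file and token names): the explicitly passed token wins over the name stored in the file.

```python
pipe.load_textual_inversion("./learned_embeds.bin", token="<special-token>")
# regardless of the token name saved inside learned_embeds.bin, the prompt
# now uses the explicitly passed token:
image = pipe("a photo in the style of <special-token>").images[0]
```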