
add load textual inversion embeddings to stable diffusion #2009


Merged: 55 commits, Mar 30, 2023
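Note: a minimal usage sketch of the API this PR adds, assuming a Stable Diffusion checkpoint and a concept repo from the sd-concepts-library (the model id, concept name, and prompt token are illustrative):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Adds the embedding's token to the tokenizer and its vector(s) to the
# text encoder's input embeddings.
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")

image = pipe("a house in the style of <low-poly-hd-logos-icons>").images[0]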
Commits (55)
6012e93
add load textual inversion embeddings draft
piEsposito Jan 16, 2023
a3a800b
Merge branch 'main' into main
piEsposito Jan 16, 2023
d4642c7
fix quality
piEsposito Jan 16, 2023
ca6d38d
Merge branch 'main' of github.com:piEsposito/diffusers into main
piEsposito Jan 16, 2023
c5ffdc3
fix typo
piEsposito Jan 16, 2023
32391af
Merge branch 'main' into main
piEsposito Jan 16, 2023
525428d
make fix copies
piEsposito Jan 16, 2023
912c7c3
Merge branch 'main' of github.com:piEsposito/diffusers into main
piEsposito Jan 16, 2023
15206c3
Merge branch 'huggingface:main' into main
piEsposito Jan 17, 2023
fdec2d0
move to textual inversion mixin
piEsposito Jan 17, 2023
e01a3f8
Merge branch 'main' of github.com:piEsposito/diffusers into main
piEsposito Jan 17, 2023
5ec8fea
make it accept from sd-concept library
piEsposito Jan 17, 2023
5d58240
accept list of paths to embeddings
piEsposito Jan 17, 2023
530a208
fix styling of stable diffusion pipeline
piEsposito Jan 17, 2023
8e50514
add dummy TextualInversionMixin
piEsposito Jan 17, 2023
b730987
add docstring to textualinversionmixin
piEsposito Jan 17, 2023
65b76f8
add load textual inversion embeddings draft
piEsposito Jan 16, 2023
66a7489
fix quality
piEsposito Jan 16, 2023
82dff21
fix typo
piEsposito Jan 16, 2023
bf0424b
make fix copies
piEsposito Jan 16, 2023
22e4751
move to textual inversion mixin
piEsposito Jan 17, 2023
f25292c
make it accept from sd-concept library
piEsposito Jan 17, 2023
f231854
accept list of paths to embeddings
piEsposito Jan 17, 2023
ced8e14
fix styling of stable diffusion pipeline
piEsposito Jan 17, 2023
5d2ef24
add dummy TextualInversionMixin
piEsposito Jan 17, 2023
e9284a4
add docstring to textualinversionmixin
piEsposito Jan 17, 2023
e6f6d1c
add case for parsing embedding from auto1111 UI format
piEsposito Jan 18, 2023
bd3b595
fix style after rebase
piEsposito Jan 18, 2023
22abd33
Merge branch 'main' into main
piEsposito Jan 18, 2023
6f9c186
Merge branch 'main' into main
piEsposito Jan 19, 2023
0be8c24
Merge branch 'main' into main
piEsposito Jan 20, 2023
f68a5f6
Merge branch 'main' of github.com:piEsposito/diffusers into piesposit…
EandrewJones Jan 24, 2023
baaf3df
move textual inversion mixin to loaders
EandrewJones Jan 24, 2023
314c1e2
move mixin inheritance to DiffusionPipeline from StableDiffusionPipel…
EandrewJones Jan 24, 2023
719e6a7
update dummy class name
EandrewJones Jan 24, 2023
3790d31
addressed all comments
EandrewJones Jan 25, 2023
ef8ab03
fix old dangling import
EandrewJones Jan 25, 2023
5939c86
Merge pull request #1 from EandrewJones/main
piEsposito Jan 26, 2023
531d61a
Merge branch 'main' into main
piEsposito Jan 28, 2023
32c86b5
fix style
piEsposito Jan 28, 2023
04c91ba
merge conflicts
patrickvonplaten Mar 23, 2023
23a36ef
proposal
patrickvonplaten Mar 23, 2023
f090898
remove bogus
patrickvonplaten Mar 23, 2023
f5b6ff1
Apply suggestions from code review
patrickvonplaten Mar 28, 2023
8a040e8
finish
patrickvonplaten Mar 28, 2023
835a8d0
make style
patrickvonplaten Mar 28, 2023
08a85dc
up
patrickvonplaten Mar 28, 2023
d172099
fix code quality
piEsposito Mar 29, 2023
991d3d7
fix code quality - again
piEsposito Mar 29, 2023
28c425b
fix code quality - 3
piEsposito Mar 29, 2023
df9f579
fix alt diffusion code quality
piEsposito Mar 29, 2023
e101d9a
Merge branch 'main' into main
piEsposito Mar 29, 2023
9dd0267
fix model editing pipeline
piEsposito Mar 29, 2023
74b1e64
Apply suggestions from code review
patrickvonplaten Mar 30, 2023
b9f53cb
Finish
patrickvonplaten Mar 30, 2023
Files changed
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -109,6 +109,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .loaders import TextualInversionLoaderMixin
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
295 changes: 284 additions & 11 deletions src/diffusers/loaders.py
@@ -13,25 +13,38 @@
# limitations under the License.
import os
from collections import defaultdict
from typing import Callable, Dict, Union
from typing import Callable, Dict, List, Optional, Union

import torch

from .models.attention_processor import LoRAAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
deprecate,
is_safetensors_available,
is_transformers_available,
logging,
)


if is_safetensors_available():
import safetensors

if is_transformers_available():
from transformers import PreTrainedModel, PreTrainedTokenizer


logger = logging.get_logger(__name__)


LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"


class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -123,13 +136,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).

</Tip>

<Tip>

Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.

</Tip>
"""

@@ -292,5 +298,272 @@ def save_function(weights, filename):

# Save the model
save_function(state_dict, os.path.join(save_directory, weight_name))

logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")


class TextualInversionLoaderMixin:
r"""
Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder.
"""

def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: PreTrainedTokenizer):
r"""
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.

Parameters:
prompt (`str` or list of `str`):
The prompt or prompts to guide the image generation.
tokenizer (`PreTrainedTokenizer`):
The tokenizer responsible for encoding the prompt into input tokens.

Returns:
`str` or list of `str`: The converted prompt
"""
if not isinstance(prompt, List):
prompts = [prompt]
else:
prompts = prompt

prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]

if not isinstance(prompt, List):
return prompts[0]

return prompts

def _maybe_convert_prompt(self, prompt: str, tokenizer: PreTrainedTokenizer):
r"""
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.

Parameters:
prompt (`str`):
The prompt to guide the image generation.
tokenizer (`PreTrainedTokenizer`):
The tokenizer responsible for encoding the prompt into input tokens.

Returns:
`str`: The converted prompt
"""
tokens = tokenizer.tokenize(prompt)
for token in tokens:
if token in tokenizer.added_tokens_encoder:
replacement = token
i = 1
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
replacement += f"{token}_{i}"
i += 1

prompt = prompt.replace(token, replacement)

return prompt
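Note: a brief illustration of the conversion above, assuming a pipeline `pipe` that inherits this mixin and whose tokenizer already contains the added tokens <cat-toy>, <cat-toy>_1 and <cat-toy>_2 (the token names are hypothetical):

# With the loop above, "a photo of <cat-toy>" is rewritten to
# "a photo of <cat-toy><cat-toy>_1<cat-toy>_2"; added tokens are matched
# by the tokenizer even without surrounding whitespace.
prompt = pipe.maybe_convert_prompt("a photo of <cat-toy>", pipe.tokenizer)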

def load_textual_inversion(
self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
):
r"""
Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
`Automatic1111` formats are supported.

<Tip warning={true}>

This function is experimental and might change in the future.

</Tip>

Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:

- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids should have an organization name, like
`"sd-concepts-library/low-poly-hd-logos-icons"`.
- A path to a *directory* containing textual inversion weights, e.g.
`./my_text_inversion_directory/`.
weight_name (`str`, *optional*):
Name of a custom weight file. This should be used in two cases:

- The saved textual inversion file is in `diffusers` format, but was saved under a specific weight
name, such as `text_inv.bin`.
- The saved textual inversion file is in the "Automatic1111" format.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `diffusers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.

mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.

<Tip>

It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).

</Tip>
"""
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
raise ValueError(
f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
f" `{self.load_textual_inversion.__name__}`"
)

if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
raise ValueError(
f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
f" `{self.load_textual_inversion.__name__}`"
)

cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)

if use_safetensors and not is_safetensors_available():
raise ValueError(
    "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
)

allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True

user_agent = {
"file_type": "text_inversion",
"framework": "pytorch",
}

# 1. Load textual inversion file
model_file = None
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except Exception as e:
if not allow_pickle:
raise e

model_file = None

if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")

# 2. Load token and embedding correctly from file
if isinstance(state_dict, torch.Tensor):
if token is None:
raise ValueError(
"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
)
embedding = state_dict
elif len(state_dict) == 1:
# diffusers
loaded_token, embedding = next(iter(state_dict.items()))
elif "string_to_param" in state_dict:
# A1111
loaded_token = state_dict["name"]
embedding = state_dict["string_to_param"]["*"]

if token is not None and loaded_token != token:
logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
[Review comment - Member]
Wouldn't we want to do the opposite override? (What comes in the state_dict is what gets added.)

[Review comment - Contributor]
Interesting, I'd say what gets passed has priority! If you do:

load_textual_inversion("./textual_inversion", token="<special-token>")

I think the token should be "<special-token>" no matter what's in the dict - it's similar to how we do from_pretrained(unet=unet) overrides.

else:
token = loaded_token

embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)

# 3. Make sure we don't mess up the tokenizer or text encoder
vocab = self.tokenizer.get_vocab()
if token in vocab:
raise ValueError(
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
)
elif f"{token}_1" in vocab:
multi_vector_tokens = [token]
i = 1
while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
multi_vector_tokens.append(f"{token}_{i}")
i += 1

raise ValueError(
f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
)

is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1

if is_multi_vector:
tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
embeddings = [e for e in embedding] # noqa: C416
else:
tokens = [token]
embeddings = [embedding] if len(embedding.shape) > 1 else [embedding[0]]
[Review comment - Contributor]
I was trying the latest version and I wasn't getting anything related to the embeddings; after changing this I was able to get good results. I'm not sure how this would work with len(embedding.shape) greater than 1, but at least when the shape has only one dimension this seems to fix it.

Suggested change:
- embeddings = [embedding] if len(embedding.shape) > 1 else [embedding[0]]
+ embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]

Suggested change:
- embeddings = [embedding] if len(embedding.shape) > 1 else [embedding[0]]
+ embeddings = [embedding] if len(embedding.shape) <= 1 else [embedding[0]]

[Review comment - Contributor]
Hmm, cannot reproduce this one - my tests are passing just fine on this branch.

[Review comment - @JarvusChen, Mar 31, 2023]
I followed diffusers/examples/textual_inversion to train my own embedding and got a learned_embeds.bin file at the end, then used the example here:

import torch
from diffusers import DiffusionPipeline

pretrained_path = 'xxx'
embedding_path = 'learned_embeds.bin'
pipe = DiffusionPipeline.from_pretrained(pretrained_path, torch_dtype=torch.float16)
pipe.load_textual_inversion(embedding_path)

The pipeline did not show any concept from my embedding token. I had to apply the same change as @GuiyeC to src/diffusers/loaders.py to make it work correctly.

[Review comment - Contributor]
I created a PR with the fix for this where I try to explain the problem a bit more.
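Note: a minimal sketch of the shape handling discussed in this thread, assuming the loaded embedding is either a 1-D tensor of shape [hidden_dim] (reportedly what the diffusers textual_inversion training example produced at the time) or a 2-D tensor of shape [n, hidden_dim]:

# embedding: torch.Tensor loaded from the embedding file
if len(embedding.shape) == 1:
    # bare vector of shape [hidden_dim]; embedding[0] here would be a
    # single scalar, which matches the symptom reported above
    embeddings = [embedding]
else:
    # shape [n, hidden_dim]: one entry per vector (n == 1 for a
    # single-token diffusers file)
    embeddings = [e for e in embedding]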


# add tokens and get ids
self.tokenizer.add_tokens(tokens)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)

# resize token embeddings and set new embeddings
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
for token_id, embedding in zip(token_ids, embeddings):
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

logger.info(f"Loaded textual inversion embedding for {token}.")
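Note: per the `weight_name` documentation above, a sketch of loading an Automatic1111-format file and overriding its stored token (directory, file, and token names are illustrative):

# A1111 files store the token under state_dict["name"]; a passed token=...
# takes priority, per the review discussion above.
pipe.load_textual_inversion(
    "./embeddings",
    weight_name="my-style.pt",
    token="<my-style>",
)
image = pipe("a castle, <my-style> style").images[0]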