add load textual inversion embeddings to stable diffusion #2009

Merged 55 commits on Mar 30, 2023
Changes from 8 commits
Commits
6012e93
add load textual inversion embeddings draft
piEsposito Jan 16, 2023
a3a800b
Merge branch 'main' into main
piEsposito Jan 16, 2023
d4642c7
fix quality
piEsposito Jan 16, 2023
ca6d38d
Merge branch 'main' of github.com:piEsposito/diffusers into main
piEsposito Jan 16, 2023
c5ffdc3
fix typo
piEsposito Jan 16, 2023
32391af
Merge branch 'main' into main
piEsposito Jan 16, 2023
525428d
make fix copies
piEsposito Jan 16, 2023
912c7c3
Merge branch 'main' of github.com:piEsposito/diffusers into main
piEsposito Jan 16, 2023
15206c3
Merge branch 'huggingface:main' into main
piEsposito Jan 17, 2023
fdec2d0
move to textual inversion mixin
piEsposito Jan 17, 2023
e01a3f8
Merge branch 'main' of github.com:piEsposito/diffusers into main
piEsposito Jan 17, 2023
5ec8fea
make it accept from sd-concept library
piEsposito Jan 17, 2023
5d58240
accept list of paths to embeddings
piEsposito Jan 17, 2023
530a208
fix styling of stable diffusion pipeline
piEsposito Jan 17, 2023
8e50514
add dummy TextualInversionMixin
piEsposito Jan 17, 2023
b730987
add docstring to textualinversionmixin
piEsposito Jan 17, 2023
65b76f8
add load textual inversion embeddings draft
piEsposito Jan 16, 2023
66a7489
fix quality
piEsposito Jan 16, 2023
82dff21
fix typo
piEsposito Jan 16, 2023
bf0424b
make fix copies
piEsposito Jan 16, 2023
22e4751
move to textual inversion mixin
piEsposito Jan 17, 2023
f25292c
make it accept from sd-concept library
piEsposito Jan 17, 2023
f231854
accept list of paths to embeddings
piEsposito Jan 17, 2023
ced8e14
fix styling of stable diffusion pipeline
piEsposito Jan 17, 2023
5d2ef24
add dummy TextualInversionMixin
piEsposito Jan 17, 2023
e9284a4
add docstring to textualinversionmixin
piEsposito Jan 17, 2023
e6f6d1c
add case for parsing embedding from auto1111 UI format
piEsposito Jan 18, 2023
bd3b595
fix style after rebase
piEsposito Jan 18, 2023
22abd33
Merge branch 'main' into main
piEsposito Jan 18, 2023
6f9c186
Merge branch 'main' into main
piEsposito Jan 19, 2023
0be8c24
Merge branch 'main' into main
piEsposito Jan 20, 2023
f68a5f6
Merge branch 'main' of github.com:piEsposito/diffusers into piesposit…
EandrewJones Jan 24, 2023
baaf3df
move textual inversion mixin to loaders
EandrewJones Jan 24, 2023
314c1e2
move mixin inheritance to DiffusionPipeline from StableDiffusionPipel…
EandrewJones Jan 24, 2023
719e6a7
update dummy class name
EandrewJones Jan 24, 2023
3790d31
addressed allo comments
EandrewJones Jan 25, 2023
ef8ab03
fix old dangling import
EandrewJones Jan 25, 2023
5939c86
Merge pull request #1 from EandrewJones/main
piEsposito Jan 26, 2023
531d61a
Merge branch 'main' into main
piEsposito Jan 28, 2023
32c86b5
fix style
piEsposito Jan 28, 2023
04c91ba
merge conflicts
patrickvonplaten Mar 23, 2023
23a36ef
proposal
patrickvonplaten Mar 23, 2023
f090898
remove bogus
patrickvonplaten Mar 23, 2023
f5b6ff1
Apply suggestions from code review
patrickvonplaten Mar 28, 2023
8a040e8
finish
patrickvonplaten Mar 28, 2023
835a8d0
make style
patrickvonplaten Mar 28, 2023
08a85dc
up
patrickvonplaten Mar 28, 2023
d172099
fix code quality
piEsposito Mar 29, 2023
991d3d7
fix code quality - again
piEsposito Mar 29, 2023
28c425b
fix code quality - 3
piEsposito Mar 29, 2023
df9f579
fix alt diffusion code quality
piEsposito Mar 29, 2023
e101d9a
Merge branch 'main' into main
piEsposito Mar 29, 2023
9dd0267
fix model editing pipeline
piEsposito Mar 29, 2023
74b1e64
Apply suggestions from code review
patrickvonplaten Mar 30, 2023
b9f53cb
Finish
patrickvonplaten Mar 30, 2023
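Commit e6f6d1c above adds a case for parsing embeddings saved in the Automatic1111 UI format. As a hedged sketch of what such format dispatch typically looks like (the key names follow the common community file layouts; the PR's actual parsing code may differ):

```python
import torch


def parse_embedding(state_dict):
    """Extract the embedding tensor from a dict loaded with torch.load.

    Handles two common on-disk layouts:
    - sd-concepts-library: {"<token>": tensor}
    - Automatic1111 UI:    {"string_to_param": {"*": tensor}, ...}
    """
    if "string_to_param" in state_dict:  # Automatic1111 UI format
        return next(iter(state_dict["string_to_param"].values()))
    if len(state_dict) == 1:  # sd-concepts-library format
        return next(iter(state_dict.values()))
    raise ValueError("Unrecognized embedding format")


# Toy dicts standing in for torch.load(path) results.
a1111 = {"string_to_param": {"*": torch.randn(2, 768)}, "name": "my-concept"}
concept = {"<my-concept>": torch.randn(768)}

print(parse_embedding(a1111).shape)    # torch.Size([2, 768])
print(parse_embedding(concept).shape)  # torch.Size([768])
```

Note that an Automatic1111 embedding may carry several vectors per token (the first dimension above), which a loader would need to map to several placeholder tokens.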
28 changes: 28 additions & 0 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -181,6 +181,34 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def load_textual_inversion_embeddings(self, embeddings):
r"""
Loads textual inversion embeddings. Receives a dictionary mapping token names to embedding paths:
- key: name of the token to be added to the tokenizer's vocabulary
- value: path to the embedding of the token to be added to the text encoder's embedding matrix

Iterates through the dictionary, adding each token to the tokenizer's vocabulary and its embedding to the text
encoder's embedding matrix.
"""
for token, embedding_path in embeddings.items():
# check if token in tokenizer vocab
# if yes, raise exception
if token in self.tokenizer.get_vocab():
raise ValueError(
f"Token {token} already in tokenizer vocabulary. Please choose a different token name."
)

# load embedding from embedding path then convert it to self.text_encoder's device and dtype
embedding = torch.load(embedding_path)
embedding = embedding.to(self.text_encoder.device)
embedding = embedding.to(self.text_encoder.dtype)

self.tokenizer.add_tokens([token])

token_id = self.tokenizer.convert_tokens_to_ids(token)
# after add_tokens, len(self.tokenizer) already includes the new token
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.
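The method above grows the text encoder's embedding matrix by one row per new token and writes the learned vector into that row. A self-contained toy sketch of that mechanic (the vocabulary, hidden size, and token name below are illustrative stand-ins, not the PR's actual API):

```python
import torch

# Toy stand-ins for the tokenizer vocabulary and the text encoder's
# embedding matrix, to illustrate what load_textual_inversion_embeddings does.
vocab = {"a": 0, "photo": 1, "of": 2}
hidden_size = 8
embedding_matrix = torch.randn(len(vocab), hidden_size)


def add_token(vocab, embedding_matrix, token, embedding):
    if token in vocab:
        raise ValueError(f"Token {token} already in tokenizer vocabulary.")
    token_id = len(vocab)
    vocab[token] = token_id
    # Grow the matrix by one row and write the learned embedding into it,
    # mirroring resize_token_embeddings plus the direct weight assignment.
    embedding_matrix = torch.cat([embedding_matrix, embedding.unsqueeze(0)], dim=0)
    return vocab, embedding_matrix


learned = torch.ones(hidden_size)  # pretend this came from torch.load(path)
vocab, embedding_matrix = add_token(vocab, embedding_matrix, "<my-concept>", learned)

print(embedding_matrix.shape)  # torch.Size([4, 8])
print(torch.equal(embedding_matrix[vocab["<my-concept>"]], learned))  # True
```

After this, prompts containing `<my-concept>` would tokenize to the new id and be encoded with the learned vector.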
@@ -178,6 +178,34 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def load_textual_inversion_embeddings(self, embeddings):
r"""
Loads textual inversion embeddings. Receives a dictionary mapping token names to embedding paths:
- key: name of the token to be added to the tokenizer's vocabulary
- value: path to the embedding of the token to be added to the text encoder's embedding matrix

Iterates through the dictionary, adding each token to the tokenizer's vocabulary and its embedding to the text
encoder's embedding matrix.
"""
for token, embedding_path in embeddings.items():
# check if token in tokenizer vocab
# if yes, raise exception
if token in self.tokenizer.get_vocab():
raise ValueError(
f"Token {token} already in tokenizer vocabulary. Please choose a different token name."
)

# load embedding from embedding path then convert it to self.text_encoder's device and dtype
embedding = torch.load(embedding_path)
embedding = embedding.to(self.text_encoder.device)
embedding = embedding.to(self.text_encoder.dtype)

self.tokenizer.add_tokens([token])

token_id = self.tokenizer.convert_tokens_to_ids(token)
# after add_tokens, len(self.tokenizer) already includes the new token
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.