Support Kohya-ss style LoRA file format (in a limited capacity) #3437
Conversation
The documentation is not available anymore as the PR was closed or merged.
@sayakpaul I'm trying to convert the text_encoder part, but the generated images are getting corrupted, and I'm currently investigating the cause. I'm not sure whether it's the direct cause, but there's something in the current Diffusers implementation that I can't immediately understand, and I would appreciate your advice. Below is part of the […]

In the Kohya-ss style LoRA, there are only the following equivalent keys […]. Initially, I tried to convert these 1:1, but when I put them into […]
src/diffusers/loaders.py
Outdated
if any("alpha" in k for k in state_dict.keys()): | ||
state_dict = self._convert_kohya_lora_to_diffusers(state_dict) |
For my own understanding, this is the primary criterion deciding that this is a non-diffusers checkpoint, yeah?
Yes, you are correct. I'm currently using what seems to be the simplest criterion, but it might be better to perform a more precise check using another function or something similar.
I think it's a fairly decent starting point. We can tackle other criteria as more such conditions come up. WDYT?
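For reference, a stricter check could key off the `lora_unet_`/`lora_te_` prefixes instead of just the presence of `alpha`. A minimal sketch with a hypothetical helper (the PR later adopts a prefix check along these lines, see the `load_lora_weights` diff further down):

```python
def is_kohya_lora_checkpoint(state_dict) -> bool:
    # Kohya-ss checkpoints prefix every key with the target model
    # ("lora_unet_..." / "lora_te_...") and ship per-module alphas.
    return len(state_dict) > 0 and all(
        k.startswith(("lora_unet_", "lora_te_")) for k in state_dict
    )
```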
tests/test_kohya_loras_scaffold.py
Outdated
```python
def forward(self, x):
    scale = self.alpha / self.lora_dim
    return self.multiplier * scale * self.lora_up(self.lora_down(x))
```
Questions:

- If we check here:

```python
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
```

we can notice that the computation of `scale` is different. It's a runtime argument, rather. So, if I am reading it correctly, to be able to get to `scale=1` (which is the default in Diffusers LoRA), it's `self.alpha` and `self.lora_dim` that need to be set accordingly.
- Why is there an additional `multiplier`, though?
Yes, that's correct. The variable name `alpha` might be confusing, but it originates from the original code. Its value is set during training, equivalent to `--network_alpha` in this Japanese document. In this PR, I'm thinking it would be good to set the value of `self.alpha / self.lora_dim` in `LoRALinearLayer`. `self.multiplier` corresponds to the `scale` value in Diffusers, which is set by the user at inference time.
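To make the mapping concrete, a small numeric sketch with illustrative values:

```python
# Kohya-ss forward:  multiplier * (alpha / lora_dim) * lora_up(lora_down(x))
# Diffusers forward: base(x) + scale * lora(x), where scale defaults to 1.0

alpha, lora_dim = 4.0, 8    # illustrative values, fixed at training time
multiplier = 1.0            # user-chosen at inference; plays the role of `scale`

network_scaling = alpha / lora_dim   # 0.5 -- must be folded into the layer itself
effective = multiplier * network_scaling
print(effective)  # 0.5: what the Diffusers layer must reproduce with scale=1
```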
Hmm cool! Fair enough. Thanks for explaining!
tests/test_kohya_loras_scaffold.py
Outdated
```python
def forward(self, x):
    if len(self.lora_modules) == 0:
        return self.orig_forward(x)
    lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
```
Beautiful!
Seems like it should be able to handle multiple LoRAs, no?
Yes, this hook version supports multiple LoRAs. Please refer to the comments in this code for usage. I would also like to see Diffusers support multiple LoRAs. I'm looking forward to future PRs for this.
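For illustration, here is a self-contained sketch of the summation pattern above; `LoRAModule` is a hypothetical stand-in for the hook version's module class, not the PR's actual code:

```python
import torch
import torch.nn as nn

# Hypothetical minimal LoRA module mirroring the hook version's forward().
class LoRAModule(nn.Module):
    def __init__(self, in_features, out_features, rank=4, alpha=4.0, multiplier=1.0):
        super().__init__()
        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)
        self.scale = alpha / rank
        self.multiplier = multiplier

    def forward(self, x):
        return self.multiplier * self.scale * self.lora_up(self.lora_down(x))

linear = nn.Linear(16, 16)
loras = [LoRAModule(16, 16), LoRAModule(16, 16)]  # e.g. a style LoRA and a character LoRA

x = torch.randn(1, 16)
# Each LoRA's contribution is computed independently, stacked, and summed --
# the same pattern as in the hook's forward() above.
out = linear(x) + torch.sum(torch.stack([lora(x) for lora in loras]), dim=0)
```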
Definitely. We will keep you posted on this.
tests/test_kohya_loras_scaffold.py
Outdated
```diff
@@ -0,0 +1,347 @@
+#
+#
+# TODO: REMOVE THIS FILE
```
Whether or not we add this approach to Diffusers, I would like to thank you for including it in this PR. I learned TONS. The design, the thought process, the execution -- everything is very clean and readable. Especially on readability: I understand this is a complex problem to solve, yet the code is fairly readable and simple to follow.

Really, thanks! I learned many things from it.
@takuma104 I don't fully understand the concern here. I consolidated how we handle the initialization, serialization, and loading of the LoRA modules related to the text encoder in this Colab Notebook. On the surface, nothing seems off to me, i.e., all the keys are used at that point. Could you elaborate on what I am missing?
if "lora_down" in key: | ||
lora_name = key.split(".")[0] | ||
value.size()[0] | ||
lora_name_up = lora_name + ".lora_up.weight" |
Clever and clean!
I am also trying to use your conversion method to load the final state dict into a pipeline using […].

Retrieve the checkpoint: […]

Load it:

```python
import safetensors.torch

a1111_state_dict = safetensors.torch.load_file("a1111_lora.safetensors")
```

Convert:

```python
unet_state_dict, te_state_dict = {}, {}
for key, value in a1111_state_dict.items():
    if "lora_down" in key:
        lora_name = key.split(".")[0]
        lora_dim = value.size()[0]  # rank of the LoRA down-projection
        lora_name_up = lora_name + ".lora_up.weight"
        lora_name_alpha = lora_name + ".alpha"
        if lora_name_alpha in a1111_state_dict:
            print(f"Alpha found: {a1111_state_dict[lora_name_alpha].item()}")
            # print(lora_name_alpha, alpha, lora_dim, alpha / lora_dim)

        if lora_name.startswith("lora_unet_"):
            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
            diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
            diffusers_name = diffusers_name.replace("mid.block", "mid_block")
            diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
            diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
            diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
            diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
            if "transformer_blocks" in diffusers_name:
                if "attn1" in diffusers_name or "attn2" in diffusers_name:
                    diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                    diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                    unet_state_dict[diffusers_name] = value
                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = a1111_state_dict[lora_name_up]
        elif lora_name.startswith("lora_te_"):
            diffusers_name = key.replace("lora_te_", "").replace("_", ".")
            diffusers_name = diffusers_name.replace("text.model", "text_model")
            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
            if "self_attn" in diffusers_name:
                prefix = ".".join(diffusers_name.split(".")[:-3])  # e.g.: text_model.encoder.layers.0.self_attn
                suffix = ".".join(diffusers_name.split(".")[-3:])  # e.g.: to_k_lora.down.weight
                # TEXT_ENCODER_TARGET_MODULES: diffusers' list of CLIP projection names (q/k/v/out proj).
                for module_name in TEXT_ENCODER_TARGET_MODULES:
                    diffusers_name = f"{prefix}.{module_name}.{suffix}"
                    te_state_dict[diffusers_name] = value
                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = a1111_state_dict[lora_name_up]
```

Prepare:

```python
unet_state_dict = {f"unet.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
a1111_lora_state_dict_for_diffusers = {**unet_state_dict, **te_state_dict}
```

Save:

```python
import os

a1111_diffusers_path = "a1111-diffusers-lora"
os.makedirs(a1111_diffusers_path, exist_ok=True)
filename = "pytorch_lora_weights.safetensors"
safetensors.torch.save_file(
    a1111_lora_state_dict_for_diffusers, os.path.join(a1111_diffusers_path, filename), metadata={"format": "pt"}
)
```

But here, I am meeting with: […]

But I am able to load the state dict directly like so:

```python
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.load_lora_weights(a1111_lora_state_dict_for_diffusers)
```
Based on the discussions in #3437 (comment), I am able to generate the following. Without LoRA: (image omitted). I don't see anything immediately off. Here's my Colab Notebook. Could you comment?
Rather than about serialization, I had doubts about the original design of applying LoRA to the `TEXT_ENCODER_TARGET_MODULES` layers:

```python
for name, module in text_encoder.named_modules():
    # if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
    if name.endswith("self_attn"):
        print(name)
        text_lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=module.out_proj.out_features, cross_attention_dim=None
        )
```
Ah, this issue seems to be caused by the inflation part, in this for loop:

```python
for module_name in TEXT_ENCODER_TARGET_MODULES:
    diffusers_name = f"{prefix}.{module_name}.{suffix}"
    te_state_dict[diffusers_name] = value
    te_state_dict[diffusers_name.replace(".down.", ".up.")] = a1111_state_dict[lora_name_up]
```

Using .copy() might work, but I'm not sure if it's the best solution. This is also related to the discussion in the previous comment.
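For what it's worth, a hedged sketch of that fix, using `.clone()` (torch tensors have no `.copy()` method) so each generated key owns its own tensor; note the commit history shows this loop was ultimately removed instead:

```python
for module_name in TEXT_ENCODER_TARGET_MODULES:
    diffusers_name = f"{prefix}.{module_name}.{suffix}"
    # Clone so the generated keys don't alias the same tensor object.
    te_state_dict[diffusers_name] = value.clone()
    te_state_dict[diffusers_name.replace(".down.", ".up.")] = a1111_state_dict[lora_name_up].clone()
```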
It appears that LoRA isn't being applied at all. I've reproduced this in my local environment, which I've tried to keep as clean as possible. The code itself isn't the problem; there seems to be an issue with the environment. Once I switched to the latest dev version of Diffusers from the main branch, LoRA was applied (although only partially). I think this also needs to be investigated, but for now, here's the revised version. However, there are still a few issues.
While the terrible images are no longer produced, the results do not match those from the hook version. This may be due to not setting the network_alpha value.
src/diffusers/loaders.py
Outdated
Before:

```python
lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
old_forward = module.forward

def new_forward(x):
    return old_forward(x) + lora_layer(x)

# Monkey-patch.
module.forward = new_forward
```

After:

```python
if name in attn_processors:
    module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
    module.old_forward = module.forward

    def new_forward(self, x):
        return self.old_forward(x) + self.lora_layer(x)

    # Monkey-patch.
    module.forward = new_forward.__get__(module)
```
Could I get a brief explanation summarizing the changes?
This change, which I believe should not be necessary in principle, is one that I would like to revert once I understand the underlying cause and confirm that it's not a problem.

When I investigated the cause of the severe output distortion when enabling LoRA for the text_encoder, I found that the output was still distorted even when I removed the `+ lora_layer(x)` part. Therefore, I concluded that the monkey-patch itself was the cause. As a countermeasure, I tried making it an instance method, and for some reason that worked, although I'm not entirely sure why it fixed things.

In simplified test code, I compared an equivalent of the original closure-based version with the instance-method version, and both tests pass without any problems:
https://gist.github.com/takuma104/1263383cdab8f54bb14f389facdbe960

I suspect it might be a memory-related issue with Python or PyTorch garbage collection, but I'm not sure yet. I'm thinking of adjusting the test to be a bit closer to the current situation and investigating.
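For reference, the two patching styles under discussion boil down to the following self-contained sketch (simplified from the gist; both produce identical outputs in this toy case):

```python
import torch
import torch.nn as nn

def patch_with_closure(module, lora_layer):
    # Original approach: new_forward closes over local variables.
    old_forward = module.forward

    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    module.forward = new_forward

def patch_with_bound_method(module, lora_layer):
    # Workaround: state lives on the module; __get__ binds new_forward
    # to the instance so it receives `self` like a regular method.
    module.lora_layer = lora_layer
    module.old_forward = module.forward

    def new_forward(self, x):
        return self.old_forward(x) + self.lora_layer(x)

    module.forward = new_forward.__get__(module)

m1, m2 = nn.Linear(4, 4), nn.Linear(4, 4)
m2.load_state_dict(m1.state_dict())
lora = nn.Linear(4, 4, bias=False)
patch_with_closure(m1, lora)
patch_with_bound_method(m2, lora)
x = torch.randn(1, 4)
assert torch.allclose(m1(x), m2(x))
```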
Looking into it.
```python
if self.network_alpha is not None:
    up_hidden_states *= self.network_alpha / self.rank
```
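In context, a simplified LoRA linear layer with this scaling looks roughly like the sketch below; this is an illustration, not the exact diffusers implementation:

```python
import torch.nn as nn

class LoRALinearLayer(nn.Module):
    # Simplified sketch; the real diffusers layer differs in details.
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.network_alpha = network_alpha
        self.rank = rank

    def forward(self, hidden_states):
        down_hidden_states = self.down(hidden_states)
        up_hidden_states = self.up(down_hidden_states)
        if self.network_alpha is not None:
            # Same scaling as kohya-ss: alpha / rank (see the diff above).
            up_hidden_states *= self.network_alpha / self.rank
        return up_hidden_states
```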
Okay for me!
Thanks @takuma104! Regarding this Colab, I see that not all the attention layers in the CLIP model are being fitted with LoRA. Specifically, 288 layers from the self-attention blocks are missing LoRA application. Could you help me understand why that is the case? In this Colab, we see the images are corrupted. Could the above be related to that? Agree that the RAM increase should be investigated in a separate PR.
The SD1.5 TextEncoder ([…]). When trying to apply LoRA to all of these […]. On the other hand, the current implementation of Diffusers creates a […].
Yes. Applying this commit should eliminate the image corruption.
@patrickvonplaten I have commented on each issue.
prepared_inputs["input_ids"] = inputs | ||
return prepared_inputs | ||
|
||
def test_text_encoder_lora_monkey_patch(self): |
nice!
```diff
@@ -893,6 +899,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         else:
             state_dict = pretrained_model_name_or_path_or_dict

+        # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
+        network_alpha = None
+        if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
```
this works for me!
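For illustration, with this prefix check in place a kohya-ss style file can be loaded directly; the repository path and filename below are assumptions (the "Light and Shadow" LoRA from the PR description is one such file):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Every key starts with "lora_te_"/"lora_unet_", so the checkpoint is detected
# as kohya-ss style and converted to diffusers attn procs before loading.
pipe.load_lora_weights(".", weight_name="light_and_shadow.safetensors")
```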
```diff
@@ -847,9 +847,9 @@ def main(args):
     if args.train_text_encoder:
         text_lora_attn_procs = {}
         for name, module in text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
```
Why the change here? Is this necessary? It makes this script a bit out of sync with `train_text_to_image_lora.py`.
Refer to #3437 (comment). Note that this is only for the text encoder training part. For `train_text_to_image_lora.py`, we don't train the text encoder anyway. So, I think it's okay.
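For context, a sketch of what the new check matches; the constant's exact value here is an assumption:

```python
TEXT_ENCODER_ATTN_MODULE = ".self_attn"  # assumed value of the diffusers constant

# name.endswith(TEXT_ENCODER_ATTN_MODULE) matches whole attention blocks, e.g.:
#   text_model.encoder.layers.0.self_attn
#   ...
#   text_model.encoder.layers.11.self_attn
# whereas the old TEXT_ENCODER_TARGET_MODULES check matched every individual
# q_proj / k_proj / v_proj / out_proj Linear layer inside those blocks.
```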
This PR looks more or less good to go for me. Could we just try to fix this one: https://github.com/huggingface/diffusers/pull/3437/files#r1210711878
Co-authored-by: Takuma Mori <[email protected]>
@patrickvonplaten see: https://github.com/huggingface/diffusers/pull/3437/files#r1210995688
@patrickvonplaten would be cool to get a final review here and I think then we can merge. @takuma104 any final things you think we are missing in this PR?
@sayakpaul I'm fine with everything except issue #3621. Shall we make #3621 a separate PR?
Yes, that would make sense.
Very cool PR, this all looks good to me :-)
Think we can merge this one! Agree that we should handle the problem of recursive LoRA loading in a new PR! Let's discuss possible solutions here.
@sayakpaul @patrickvonplaten Thanks a lot!
Thank you!
@takuma104 @patrickvonplaten […]
@sayakpaul thanks for the link, amazing! Just one thing: how do you unload it?
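Not answered in this thread, but given the monkey-patching shown earlier, a minimal sketch of undoing it by hand; `undo_text_encoder_lora_patch` is a hypothetical helper, not a diffusers API at this point:

```python
def undo_text_encoder_lora_patch(text_encoder):
    # Restore the forwards saved as `old_forward` by the monkey-patch
    # and drop the attached LoRA layers.
    for module in text_encoder.modules():
        if hasattr(module, "old_forward"):
            module.forward = module.old_forward
            del module.old_forward
            del module.lora_layer
```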
…ingface#3437)

* add _convert_kohya_lora_to_diffusers
* make style
* add scaffold
* match result: unet attention only
* fix monkey-patch for text_encoder
* with CLIPAttention

  While the terrible images are no longer produced, the results do not match those from the hook ver. This may be due to not setting the network_alpha value.

* add to support network_alpha
* generate diff image
* fix monkey-patch for text_encoder
* add test_text_encoder_lora_monkey_patch()
* verify that it's okay to release the attn_procs
* fix closure version
* add comment
* Revert "fix monkey-patch for text_encoder"

  This reverts commit bb9c61e.

* Fix to reuse utility functions
* make LoRAAttnProcessor targets to self_attn
* fix LoRAAttnProcessor target
* make style
* fix split key
* Update src/diffusers/loaders.py
* remove TEXT_ENCODER_TARGET_MODULES loop
* add print memory usage
* remove test_kohya_loras_scaffold.py
* add: doc on LoRA civitai
* remove print statement and refactor in the doc.
* fix state_dict test for kohya-ss style lora
* Apply suggestions from code review

Co-authored-by: Takuma Mori <[email protected]>

---------

Co-authored-by: Sayak Paul <[email protected]>
What's this?
Discussed in #3064. The LoRA file format currently supported by the A1111 webui is not compatible with the LoRA file format used in current Diffusers. This format is gaining popularity on sites such as CivitAI and on the Hugging Face Hub. These LoRA files have the following features:
- The current implementation of Diffusers covers only Attention, but this format extends LoRA to ClipMLP, Transformer2DModel's proj_in, proj_out, and so on. Keys exist for 1x1 nn.Conv2d, not just nn.Linear. I plan to open a separate PR for this part.
- There also exists a script to rewrite the weights of the UNet/TextEncoder, but in this PR we aim to handle it without rewriting the weights, by extending the current dynamic LoRA processing of Diffusers. (Illustrative key examples below.)
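For concreteness, the two key schemes look like this; the keys are illustrative, derived from the conversion rules discussed in this thread:

```python
# kohya-ss / A1111 style key (flat, underscore-separated, one prefix per model):
kohya_key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"

# diffusers style key after conversion (dotted module path, prefixed with the model):
diffusers_key = "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
```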
Workflow:
We are considering the following proposal.
#3064 (comment)
The approach will be to proceed while checking, step by step, that the output matches a hook version created in advance as a PoC. I have added a work file called `tests/test_kohya_loras_scaffold.py`, which is the hook version, and plan to place a non-hook version that performs the same generation in `src/diffusers` as needed. This will mainly involve modifications to `src/diffusers/loaders.py`.

Todo:
- [ ] Revisit the instantiation of LoRA parameters for the text encoder in the train_dreambooth_lora.py script
- [ ] Support for Attention, Linear, Conv2d in the UNet
- [ ] Support for Attention, Linear, Conv2d in the TextEncoder
- [ ] Investigate high memory usage

Generation Comparison:
LoRA file: Light and Shadow