Support Kohya-ss style LoRA file format (in a limited capacity) #3437

Merged: 33 commits merged into huggingface:main on Jun 2, 2023

Conversation

Contributor

@takuma104 commented May 15, 2023

What's this?

Discussed in #3064. The LoRA file format currently supported by the A1111 webui is not compatible with the LoRA file format used in current Diffusers. This format is gaining popularity on sites such as CivitAI and on the Hugging Face Hub. These LoRA files have the following features:

  • These files are most likely produced by the kohya-ss/sd-scripts trainer or its derivatives.
  • An alpha key exists alongside every LoRA weight pair. This value is used to scale the output, computed as alpha/dim (see the sketch after this list).
  • The current Diffusers implementation covers Attention only, but this format extends to ClipMLP, Transformer2DModel's proj_in, proj_out, and so on. Keys exist for 1x1 nn.Conv2d, not just nn.Linear. I plan to open a separate PR for that part.
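
As a minimal sketch of that scaling (hypothetical shapes; down and up stand in for a stored lora_down.weight / lora_up.weight pair):

import torch

dim = 4                       # LoRA rank == first dimension of *.lora_down.weight
alpha = 1.0                   # value stored under the matching "*.alpha" key
down = torch.randn(dim, 768)  # stands in for *.lora_down.weight
up = torch.randn(768, dim)    # stands in for *.lora_up.weight

def lora_delta(x: torch.Tensor, multiplier: float = 1.0) -> torch.Tensor:
    # Kohya-ss convention: scale the low-rank update by alpha / dim,
    # with `multiplier` as the user-facing strength at inference time.
    return multiplier * (alpha / dim) * (x @ down.T) @ up.T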

There is also a script that rewrites the Unet/TextEncoder weights directly, but in this PR we aim to handle these files without rewriting those weights, by extending Diffusers' current dynamic LoRA processing.

Workflow:

We are considering the following proposal.
#3064 (comment)

The approach will be to proceed step by step, checking that output results match a hook version created in advance as a PoC. I have added a work file, tests/test_kohya_loras_scaffold.py (the hook version), and plan to place a non-hook version that performs the same generation in src/diffusers as needed. This will mainly involve modifications to src/diffusers/loaders.py.

Todo:

- [x] Support for Attention only in Unet
- [x] Support for Attention only in TextEncoder
- [x] Add .alpha member to LoRALinearLayer
- [ ] Revisit the instantiation of LoRA parameters for the text encoder in the train_dreambooth_lora.py script
- [ ] Support for Attention, Linear, Conv2d in Unet
- [ ] Support for Attention, Linear, Conv2d in TextEncoder
- [ ] Investigate high memory usage

Generation Comparison:

LoRA file: Light and Shadow

| Without LoRA | Fully Applied LoRA (hook) | Partially Applied LoRA (hook) | This PR (current) |
| --- | --- | --- | --- |
| test_orig | test_lora_hook | test_lora_hook_ | test_lora_dev |

@HuggingFaceDocBuilderDev commented May 15, 2023

The documentation is not available anymore as the PR was closed or merged.

Contributor Author

@takuma104 commented May 15, 2023

@sayakpaul I'm trying to convert the text_encoder part, but it seems that the generated images are getting corrupted, and I'm currently investigating the cause. I'm not sure if it's the direct cause, but there's something I can't immediately understand in the current Diffusers implementation, and I would appreciate your advice.

Below is part of the sayakpaul/test-lora-diffusers file from PR #2918, dumped with this script. There appear to be unused keys for each self_attn module: from reading the code, only to_k_lora is used for k_proj, only to_out_lora for out_proj, and so on, so 3/4 of the keys seem to be wasted. Is my understanding correct, or are all keys being used?

text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_k_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_k_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_out_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_out_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_q_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_q_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_v_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_v_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_k_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_k_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_out_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_out_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_q_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_q_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_v_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_v_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_k_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_k_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_out_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_out_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_q_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_q_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_v_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_v_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.v_proj.to_k_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.v_proj.to_k_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.v_proj.to_out_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.v_proj.to_out_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.v_proj.to_q_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.v_proj.to_q_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.v_proj.to_v_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.v_proj.to_v_lora.up.weight [768, 4]

In the Kohya-ss style LoRA, there are only the following equivalent keys. Initially, I tried to convert this 1:1, but when I put it into pipe.load_lora_weights(), it raised an error. Therefore, I am currently duplicating each key four times (see the sketch after this key list).

text_encoder.text_model.encoder.layers.0.self_attn.to_k_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.to_k_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.to_out_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.to_out_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.up.weight [768, 4]
text_encoder.text_model.encoder.layers.0.self_attn.to_v_lora.down.weight [4, 768]
text_encoder.text_model.encoder.layers.0.self_attn.to_v_lora.up.weight [768, 4]
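
As a hedged illustration of that four-fold duplication (a hypothetical helper that mirrors the conversion code discussed further down), one converted text-encoder entry is fanned out to all four projection modules:

# Hypothetical sketch of the 4x "inflation": a single converted text-encoder
# LoRA key is duplicated for q_proj, k_proj, v_proj and out_proj.
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "out_proj"]

def inflate_te_key(diffusers_name: str, value, te_state_dict: dict) -> None:
    # e.g. diffusers_name = "text_model.encoder.layers.0.self_attn.to_k_lora.down.weight"
    prefix = ".".join(diffusers_name.split(".")[:-3])  # ...layers.0.self_attn
    suffix = ".".join(diffusers_name.split(".")[-3:])  # to_k_lora.down.weight
    for module_name in TEXT_ENCODER_TARGET_MODULES:
        te_state_dict[f"{prefix}.{module_name}.{suffix}"] = value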

Comment on lines 890 to 891
if any("alpha" in k for k in state_dict.keys()):
state_dict = self._convert_kohya_lora_to_diffusers(state_dict)
Member

For my own understanding, this is the primary criterion deciding that this is a non-diffusers checkpoint, yeah?

Contributor Author

Yes, you are correct. I'm currently using what seems to be the simplest criterion, but it might be better to perform a more precise check in a separate function or something similar.

Member

I think it's a fairly decent starting point. We can tackle other criteria as more such conditions come up. WDYT?
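
For reference, a stricter check along the lines later adopted in this PR (a sketch only; the helper name is hypothetical) could combine the key-prefix test with the alpha test:

def looks_like_kohya_lora(state_dict) -> bool:
    # Kohya-ss / A1111 checkpoints prefix every key with "lora_te_" or "lora_unet_"
    # and ship a companion "*.alpha" tensor for each up/down weight pair.
    keys = list(state_dict.keys())
    return all(k.startswith(("lora_te_", "lora_unet_")) for k in keys) and any(
        k.endswith(".alpha") for k in keys
    )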

Comment on lines 55 to 57
def forward(self, x):
    scale = self.alpha / self.lora_dim
    return self.multiplier * scale * self.lora_up(self.lora_down(x))
Member

Questions:

  • If we check here

key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)

we can notice that the computation of the scale is different; there it's a runtime argument. So, if I am reading it correctly, to get to a scale=1 (which is the default in Diffusers LoRA), it's self.alpha and self.lora_dim that need to be set accordingly.

  • Why is there an additional multiplier, though?

Contributor Author

@takuma104 May 16, 2023

Yes, that's correct. The variable name alpha might be confusing, but it originates from the original code. Its value is set during training, equivalent to --network_alpha in this Japanese document. In this PR, I'm thinking it would be good to set the value of self.alpha / self.lora_dim in LoRALinearLayer. self.multiplier corresponds to the scale value in Diffusers, which is set by the user at inference time.
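
To line the two conventions up (a minimal sketch, not the PR's final code; the class name is hypothetical): the Kohya-style layer folds alpha/dim into its own forward, while Diffusers applies a user-supplied scale at the call site, so multiplier plays the role of scale.

import torch.nn as nn

class KohyaLoRAModule(nn.Module):
    # Kohya-ss convention: output = multiplier * (alpha / lora_dim) * up(down(x))
    def __init__(self, in_features, out_features, lora_dim=4, alpha=1.0, multiplier=1.0):
        super().__init__()
        self.lora_down = nn.Linear(in_features, lora_dim, bias=False)
        self.lora_up = nn.Linear(lora_dim, out_features, bias=False)
        self.lora_dim, self.alpha, self.multiplier = lora_dim, alpha, multiplier

    def forward(self, x):
        scale = self.alpha / self.lora_dim
        return self.multiplier * scale * self.lora_up(self.lora_down(x))

# Diffusers convention (schematically): the attention processor computes
#   key = attn.to_k(hidden_states) + scale * to_k_lora(hidden_states)
# so `multiplier` corresponds to `scale`, and alpha / lora_dim has to be folded
# into the LoRA layer itself (the network_alpha handling added in this PR).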

Member

Hmm cool! Fair enough. Thanks for explaining!

def forward(self, x):
    if len(self.lora_modules) == 0:
        return self.orig_forward(x)
    lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
Member

Beautiful!

Seems like it should be able to handle multiple LoRAs, no?

Contributor Author

Yes, this hook version supports multiple LoRAs. Please refer to the comments in this code for usage. I would also like to see Diffusers support multiple LoRAs. I'm looking forward to future PRs for this.
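
Schematically, the hook approach looks like this (a simplified sketch with hypothetical names; the actual scaffold lives in tests/test_kohya_loras_scaffold.py):

import torch

class LoRAHook:
    # Sketch: wrap a module's forward so that every registered LoRA module
    # contributes additively on top of the original output.
    def __init__(self, module):
        self.orig_forward = module.forward
        self.lora_modules = []
        module.forward = self.forward

    def append(self, lora_module):
        self.lora_modules.append(lora_module)

    def forward(self, x):
        if len(self.lora_modules) == 0:
            return self.orig_forward(x)
        lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
        return self.orig_forward(x) + lora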

Member

Definitely. We will keep you posted on this.

@@ -0,0 +1,347 @@
#
#
# TODO: REMOVE THIS FILE
Member

Whether or not we add this approach to Diffusers, I would like to thank you for including it in this PR. I learned TONS. The design, the thought process, the execution -- everything is very clean and readable. Especially the readability: I understand it's a complex problem to solve, yet the code is fairly readable and simple to follow.

Really, thanks! I learned many things from it.

@sayakpaul
Member

#3437 (comment)

@takuma104 I don't fully understand the concern here.

I consolidated how we handle the initialization, serialization, and loading of the LoRA modules related to the text encoder in this Colab Notebook:
https://colab.research.google.com/gist/sayakpaul/300bbeaa748849308d146de97346bef9/scratchpad.ipynb

On the surface, nothing seems off to me, i.e., all the keys are used at that point. Could you elaborate on what I am missing out on?

if "lora_down" in key:
lora_name = key.split(".")[0]
value.size()[0]
lora_name_up = lora_name + ".lora_up.weight"
Member

Clever and clean!

@sayakpaul
Member

sayakpaul commented May 16, 2023

I am also trying to use your conversion method to load the final state dict into a pipeline with load_lora_weights(), using the following code:

Retrieve the checkpoint:

!wget https://civitai.com/api/download/models/15603 -O a1111_lora.safetensors -q

Load it:

import safetensors
from safetensors import torch

a1111_state_dict = safetensors.torch.load_file("a1111_lora.safetensors")

Convert:

unet_state_dict, te_state_dict = {}, {}

for key, value in a1111_state_dict.items():
    if "lora_down" in key:
        lora_name = key.split(".")[0]
        value.size()[0]
        lora_name_up = lora_name + ".lora_up.weight"
        lora_name_alpha = lora_name + ".alpha"
        if lora_name_alpha in a1111_state_dict:
            print(f"Alpha found: {a1111_state_dict[lora_name_alpha].item()}")
            # print(lora_name_alpha, alpha, lora_dim, alpha / lora_dim)

        if lora_name.startswith("lora_unet_"):
            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
            diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
            diffusers_name = diffusers_name.replace("mid.block", "mid_block")
            diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
            diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
            diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
            diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
            if "transformer_blocks" in diffusers_name:
                if "attn1" in diffusers_name or "attn2" in diffusers_name:
                    diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                    diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                    unet_state_dict[diffusers_name] = value
                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = a1111_state_dict[lora_name_up]
        elif lora_name.startswith("lora_te_"):
            diffusers_name = key.replace("lora_te_", "").replace("_", ".")
            diffusers_name = diffusers_name.replace("text.model", "text_model")
            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
            if "self_attn" in diffusers_name:
                prefix = ".".join(
                    diffusers_name.split(".")[:-3]
                )  # e.g.: text_model.encoder.layers.0.self_attn
                suffix = ".".join(diffusers_name.split(".")[-3:])  # e.g.: to_k_lora.down.weight
                for module_name in TEXT_ENCODER_TARGET_MODULES:
                    diffusers_name = f"{prefix}.{module_name}.{suffix}"
                    te_state_dict[diffusers_name] = value
                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = a1111_state_dict[lora_name_up]

Prepare:

unet_state_dict = {f"unet.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
a1111_lora_state_dict_for_diffusers = {**unet_state_dict, **te_state_dict}

Save:

import os

a1111_diffusers_path = "a1111-diffusers-lora"
os.makedirs(a1111_diffusers_path, exist_ok=True)

filename = "pytorch_lora_weights.safetensors"
safetensors.torch.save_file(
    a1111_lora_state_dict_for_diffusers, os.path.join(a1111_diffusers_path, filename), metadata={"format": "pt"}
)

But here, I am meeting with:

RuntimeError: 
            Some tensors share memory, this will lead to duplicate memory on disk and potential differences when 
loading them again: [{'text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_k_lora.down.weight', 
'text_encoder.text_model.encoder.layers.0.self_attn.v_proj.to_k_lora.down.weight', 
'text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_k_lora.down.weight', 
'text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_k_lora.down.weight'}, 
{'text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_k_lora.up.weight', 
'text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_k_lora.up.weight', 
...

But I am able to load the state dict directly like so:

from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.load_lora_weights(a1111_lora_state_dict_for_diffusers)

@sayakpaul
Member

Based on the discussions of #3437 (comment), I am able to generate the following:

[image]

Without LoRA:

[image]

I don't see anything immediately off.

Here's my Colab Notebook:
https://colab.research.google.com/gist/sayakpaul/f0a1b81c78b205e466161c484ec5ffa5/scratchpad.ipynb

Could you comment?

@takuma104
Contributor Author

> #3437 (comment)
>
> @takuma104 I don't fully understand the concern here.
>
> I consolidated how we handle the initialization, serialization, and loading of the LoRA modules related to the text encoder in this Colab Notebook: https://colab.research.google.com/gist/sayakpaul/300bbeaa748849308d146de97346bef9/scratchpad.ipynb
>
> On the surface, nothing seems off to me, i.e., all the keys are used at that point. Could you elaborate on what I am missing out on?

My concern was not about serialization; rather, I had doubts about the original design of applying LoRA to CLIPAttention. I've rewritten the code a bit:

for name, module in text_encoder.named_modules():
    # if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
    if name.endswith('self_attn'):
        print(name)
        text_lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=module.out_proj.out_features, cross_attention_dim=None
        )

self_attn is the CLIPAttention class, which is equivalent to the Attention class in Diffusers. I thought it would be appropriate to create LoRAAttnProcessors in a 1:1 relationship with these modules. What do you think?

@takuma104
Contributor Author

> But here, I am meeting with:
>
> RuntimeError:
>             Some tensors share memory, this will lead to duplicate memory on disk and potential differences when
> loading them again: [{'text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_k_lora.down.weight',
> 'text_encoder.text_model.encoder.layers.0.self_attn.v_proj.to_k_lora.down.weight',
> 'text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_k_lora.down.weight',
> 'text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_k_lora.down.weight'},
> {'text_encoder.text_model.encoder.layers.0.self_attn.out_proj.to_k_lora.up.weight',
> 'text_encoder.text_model.encoder.layers.0.self_attn.k_proj.to_k_lora.up.weight',
> ...

Ah, this issue seems to be caused by the inflation part, in this for loop.

for module_name in TEXT_ENCODER_TARGET_MODULES:
  diffusers_name = f"{prefix}.{module_name}.{suffix}"
  te_state_dict[diffusers_name] = value
  te_state_dict[diffusers_name.replace(".down.", ".up.")] = a1111_state_dict[lora_name_up]

Using .copy() might work, but I'm not sure if it's the best solution. This is also related to the discussion in the previous comment.
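
A minimal sketch of that workaround (.copy() in the comment above; for PyTorch tensors the equivalent is .clone(), and contiguous storage keeps safetensors happy), applied just before saving:

# Sketch: give each duplicated entry its own storage so safetensors does not
# reject the state dict for shared memory when serializing.
a1111_lora_state_dict_for_diffusers = {
    key: tensor.clone().contiguous()
    for key, tensor in a1111_lora_state_dict_for_diffusers.items()
}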

@takuma104
Contributor Author

takuma104 commented May 16, 2023

> Based on the discussions of #3437 (comment), I am able to generate the following:

It appears that LoRA isn't being applied at all. I've reproduced this in my local environment, which I've tried to keep as clean as possible. The code itself isn't the problem, but there seems to be an issue with the environment. Once I switched to using the latest dev version of Diffusers from the main branch, LoRA was applied (although partially). I think this also needs to be investigated, but for now, here's the revised version. However, there are still a few issues.

  • In the latest dev version of Diffusers, the behavior of .from_ckpt seems odd, and the model loaded with .from_pretrained also gets replaced with a different model. This is a serious bug for me as I often use this feature, so I just opened issue #3450: StableDiffusionPipeline.from_ckpt is not working on dev version.
  • As it stands with the code you provided, there's an issue with applying LoRA to the text_encoder, which results in a distorted image being generated. I believe this issue is fixed in this commit. However, I question the necessity of this commit itself. I think the original code should work fine.
  • For reasons that are unclear, it currently uses a lot of VRAM. I couldn't run it on a T4 instance on Colab. It was able to generate without any issues on an A100 instance. I think this issue needs to be investigated separately.

takuma104 added 3 commits May 17, 2023 02:24
While the terrible images are no longer produced,
the results do not match those from the hook ver.
This may be due to not setting the network_alpha value.
Comment on lines 946 to 966
-            lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
-            old_forward = module.forward
-
-            def new_forward(x):
-                return old_forward(x) + lora_layer(x)
-
-            # Monkey-patch.
-            module.forward = new_forward
+            if name in attn_processors:
+                module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+                module.old_forward = module.forward
+
+                def new_forward(self, x):
+                    return self.old_forward(x) + self.lora_layer(x)
+
+                # Monkey-patch.
+                module.forward = new_forward.__get__(module)
Member

Could I get a brief explanation summarizing the changes?

Contributor Author

This change, which I believe should not be necessary in principle, is one that I would like to revert once I understand the underlying cause and confirm that it's not a problem.

When I investigated the cause of the severe output distortion when enabling LoRA for the text_encoder, I found that the output was still distorted even when I removed the + lora_layer(x) part. Therefore, I concluded that this monkey-patch itself was the cause. As a countermeasure, I tried making it an instance method, and for some reason it worked, although I'm not entirely sure why it was fixed.

I compared the original code equivalent with the method of making it an instance method in simplified test code, but both of these tests pass without any problems.
https://gist.github.com/takuma104/1263383cdab8f54bb14f389facdbe960

I suspect that it might be a memory-related issue with Python or PyTorch garbage collection, but I'm not sure yet. I'm thinking of adjusting the test a bit closer to the current situation and investigating.
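
To make the difference concrete, here is a simplified sketch of the two patching styles being compared (not the exact loaders.py code):

def patch_with_closure(module, lora_layer):
    # Original approach: the replacement forward captures old_forward and
    # lora_layer through a closure.
    old_forward = module.forward

    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    module.forward = new_forward


def patch_with_bound_method(module, lora_layer):
    # Revised approach: store everything on the module and bind the new
    # forward as an instance method via __get__.
    module.lora_layer = lora_layer
    module.old_forward = module.forward

    def new_forward(self, x):
        return self.old_forward(x) + self.lora_layer(x)

    module.forward = new_forward.__get__(module)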

Member

Looking into it.

Comment on lines +492 to +494
if self.network_alpha is not None:
    up_hidden_states *= self.network_alpha / self.rank

Member

Okay for me!

@sayakpaul
Member

sayakpaul commented May 17, 2023

Thanks @takuma104!

Regarding this Colab, I see that not all the attention layers in the CLIP model are being tied with LoRA. Specifically, 288 layers from the self-attention blocks are missing LoRA application. Could you help me understand why that is the case?

In this Colab, we see the images are corrupted. Could the above be related to that?

Agree that the RAM increase should be investigated in a separate PR.

@takuma104
Contributor Author

@sayakpaul

> Regarding this Colab, I see that not all the attention layers in the CLIP model are being tied with LoRA. Specifically, 288 layers from the self-attention blocks are missing LoRA application. Could you help me understand why that is the case?

The SD1.5 TextEncoder (CLIPEncoder) is configured with 12 CLIPEncoderLayers according to the config.json, and each CLIPEncoderLayer has a single CLIPAttention in its .self_attn member, resulting in a total of 12 CLIPAttention instances.

When applying LoRA to all of these CLIPAttention instances, one LoRAAttnProcessor should be enough per Attention, so that's what I did in my Colab code. Since each LoRAAttnProcessor has four LoRALinearLayers, and each LoRALinearLayer has two nn.Linear instances, multiplying these gives 12*4*2=96 instances (=keys).

On the other hand, the current implementation of Diffusers creates a LoRAAttnProcessor for each of .q_proj, .k_proj, .v_proj, .out_proj members of CLIPAttention, which quadruples the number to 12*4*4*2=384 instances (=keys). This difference accounts for the 288 you mentioned.
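
A quick way to sanity-check those counts (a sketch that simply enumerates module names on the SD 1.5 text encoder):

from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
)

attn_modules = [n for n, _ in text_encoder.named_modules() if n.endswith("self_attn")]
proj_modules = [
    n for n, _ in text_encoder.named_modules()
    if n.endswith(("q_proj", "k_proj", "v_proj", "out_proj"))
]

print(len(attn_modules))  # 12 -> 12 * 4 LoRALinearLayers * 2 nn.Linear = 96 keys
print(len(proj_modules))  # 48 -> 48 * 4 * 2 = 384 keys; the difference is 288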

> In this Colab, we see the images are corrupted. Could the above be related to that?

Yes, applying this commit should eliminate this image corruption.

@takuma104
Contributor Author

@patrickvonplaten I have commented on each issue.

#3437 (comment)
#3437 (comment)

@takuma104 takuma104 changed the title Support Kohya-ss style LoRA file format Support Kohya-ss style LoRA file format (in a limited capacity) May 29, 2023
        prepared_inputs["input_ids"] = inputs
        return prepared_inputs

    def test_text_encoder_lora_monkey_patch(self):
Contributor

nice!

@@ -893,6 +899,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
+        # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
+        network_alpha = None
+        if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
Contributor

this works for me!

@@ -847,9 +847,9 @@ def main(args):
     if args.train_text_encoder:
         text_lora_attn_procs = {}
         for name, module in text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
Contributor

Why the change here? Is it necessary? This makes this script a bit out of sync with train_text_to_image_lora.py.

Member

Refer to #3437 (comment). Note that this only affects the text encoder training part. In train_text_to_image_lora.py we don't train the text encoder anyway, so I think it's okay.

Contributor

@patrickvonplaten left a comment

This PR looks more or less good to go for me. Could we just try to fix this one: https://github.com/huggingface/diffusers/pull/3437/files#r1210711878

@sayakpaul
Member

@patrickvonplaten would be cool to get a final review here and I think then we can merge.

@takuma104 any final things you think we are missing in this PR?

@takuma104
Contributor Author

@sayakpaul I'm fine with everything except the issue #3621. Shall we make #3621 a separate PR?

@sayakpaul
Member

> @sayakpaul I'm fine with everything except the issue #3621. Shall we make #3621 a separate PR?

Yes, that would make sense.

Contributor

@patrickvonplaten left a comment

Very cool PR, this all looks good to me :-)

Think we can merge this one! Agree that we should handle the problem of recursive LoRA loading in a new PR! Let's discuss possible solutions here.

@sayakpaul sayakpaul merged commit 8e552bb into huggingface:main Jun 2, 2023
@takuma104
Contributor Author

@sayakpaul @patrickvonplaten Thanks a lot!

@patrickvonplaten
Contributor

Thank you!

@sarmientoj24

@takuma104 @patrickvonplaten
is there documentation on

  • how to use this? loading and unloading the LoRA?
  • how to use it on prompts?

@sayakpaul
Member

@sarmientoj24

@sayakpaul thanks for the link, amazing! Just one thing: how do you unload it?

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…ingface#3437)

* add _convert_kohya_lora_to_diffusers

* make style

* add scaffold

* match result: unet attention only

* fix monkey-patch for text_encoder

* with CLIPAttention

While the terrible images are no longer produced,
the results do not match those from the hook ver.
This may be due to not setting the network_alpha value.

* add to support network_alpha

* generate diff image

* fix monkey-patch for text_encoder

* add test_text_encoder_lora_monkey_patch()

* verify that it's okay to release the attn_procs

* fix closure version

* add comment

* Revert "fix monkey-patch for text_encoder"

This reverts commit bb9c61e.

* Fix to reuse utility functions

* make LoRAAttnProcessor targets to self_attn

* fix LoRAAttnProcessor target

* make style

* fix split key

* Update src/diffusers/loaders.py

* remove TEXT_ENCODER_TARGET_MODULES loop

* add print memory usage

* remove test_kohya_loras_scaffold.py

* add: doc on LoRA civitai

* remove print statement and refactor in the doc.

* fix state_dict test for kohya-ss style lora

* Apply suggestions from code review

Co-authored-by: Takuma Mori <[email protected]>

---------

Co-authored-by: Sayak Paul <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…ingface#3437)
