
Loading .safetensors Lora #3064


Closed
adhikjoshi opened this issue Apr 12, 2023 · 85 comments
Labels: bug (Something isn't working), stale (Issues that haven't received updates)

Comments

@adhikjoshi

Describe the bug

I have downloaded a LoRA from CivitAI, which is in .safetensors format.

When I load it using the code below,

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.unet.load_attn_procs("lora.safetensors")

it throws the following error:

File "/workspace/server/tasks.py", line 346, in txt2img
self.pipe.unet.load_attn_procs(embd, use_safetensors=True)
File "/opt/conda/envs/ldm/lib/python3.8/site-packages/diffusers/loaders.py", line 224, in load_attn_procs
rank = value_dict["to_k_lora.down.weight"].shape[0]
KeyError: 'to_k_lora.down.weight'

Reproduction

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.unet.load_attn_procs("lora.safetensors")

Logs

No response

System Info

Diffusers Version: 0.15.0.dev0

adhikjoshi added the bug label on Apr 12, 2023
@patrickvonplaten
Contributor

Hey @adhikjoshi,

Thanks for the issue. We should indeed add support for loading A1111-format LoRA tensors soon. cc @sayakpaul here

@alejobrainz

alejobrainz commented Apr 12, 2023

Kohya-ss/sd-scripts has a nice mechanism for it, though it broke with 0.15. On 0.14.0 you can load A1111 LoRA tensors with the function below:

def apply_lora(pipe, lora_path, weight: float = 1.0):
    import torch
    from safetensors.torch import load_file
    # networks.lora comes from kohya-ss/sd-scripts; its repo root must be on PYTHONPATH
    from networks.lora import create_network_from_weights
    
    vae = pipe.vae
    text_encoder = pipe.text_encoder
    unet = pipe.unet

    sd = load_file(lora_path)
    lora_network, sd = create_network_from_weights(weight, None, vae, text_encoder, unet, sd)
    lora_network.apply_to(text_encoder, unet)
    lora_network.load_state_dict(sd)
    lora_network.to("cuda", dtype=torch.float16)

but as of 0.15 it fails:

assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
AssertionError: duplicated lora name: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q
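
For reference, a minimal usage sketch of the apply_lora helper above on 0.14.0 (the model ID and LoRA path are placeholders):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
apply_lora(pipe, "lora.safetensors", weight=1.0)  # placeholder path
image = pipe("a prompt using the LoRA's trigger word").images[0]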

@adhikjoshi
Author

Kohya-ss/sd-scripts has a nice mechanism for it, though it broke with 0.15. On 0.14.0 you can load A1111 LoRA tensors with the function below: […]

CC @haofanwang @sayakpaul

@sayakpaul
Member

Can someone provide a LoRA file in the A1111 format? Providing as many relevant details associated with the file as possible would be great too.

@adhikjoshi
Author

Can someone provide a LoRA file in the A1111 format? Providing as many relevant details associated with the file as possible would be great too.

I downloaded an offset-noise-trained LoRA and uploaded its .safetensors to Hugging Face:

https://huggingface.co/adhikjoshi/epi_noiseoffset

@alejobrainz

@sayakpaul here you go. This LoRA was trained using kohya-ss's scripts and works fine in A1111. I can load it on diffusers 0.14.0 with the snippet above using lora.py from sd-scripts:

caAos-000001.zip

Thanks,

Alejandro.

@sayakpaul
Member

Cc: @patrickvonplaten ^

@adhikjoshi
Author

Here is a function I made from convert_lora_safetensor_to_diffusers.py to load a LoRA at inference time.

import torch
from safetensors.torch import load_file

def load_lora_weights(pipeline, checkpoint_path):
    # load base model
    pipeline.to("cuda")
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    alpha = 0.75
    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path, device="cuda")
    visited = []

    # directly update weight in diffusers model
    for key in state_dict:
        # it is suggested to print out the key, it usually will be something like below
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # alpha is hard-coded above, so skip the per-layer .alpha entries
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while True:  # walk down the module tree; merge name parts on attribute misses
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

        # update visited list
        for item in pair_keys:
            visited.append(item)

    return pipeline

You can use it like:

lora_model = lora_models + "/" + opt.lora + ".safetensors"  # lora_models: directory of LoRA files; opt.lora: file name stem
self.pipe = load_lora_weights(self.pipe, lora_model)

@sayakpaul @patrickvonplaten

@sayakpaul
Member

Thanks! Do you have the checkpoints with which we could test this?

@adhikjoshi
Author

Can someone provide a LoRA file in the A1111 format? Providing as many relevant details associated with the file as possible would be great too.

I downloaded an offset-noise-trained LoRA and uploaded its .safetensors to Hugging Face:

https://huggingface.co/adhikjoshi/epi_noiseoffset

This uploaded safetensors LoRA, and others like it, work well with the function above.

@pdoane
Contributor

pdoane commented Apr 17, 2023

Thanks @adhikjoshi! Getting a lot further with your function, but the output is not matching what I would expect. As a first guess, I would think this is the alpha handling, as that is hard-coded to 0.75 but the LoRAs I'm using have .alpha keys in them.
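
For context, kohya-format files store a scalar .alpha per layer, and the effective scale on the up/down product is alpha divided by the LoRA rank rather than a fixed 0.75. A one-line sketch with hypothetical names:

scale = state_dict[layer + ".alpha"].item() / state_dict[layer + ".lora_down.weight"].shape[0]  # alpha / rank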

@pdoane
Contributor

pdoane commented Apr 18, 2023

I updated the function from @adhikjoshi to use the .alpha elements and also added a multiplier that can be used to weight the LoRA overall. I tested this on 4 random LoRAs downloaded from CivitAI, and it matches the output from Automatic1111:

import torch
from collections import defaultdict
from safetensors.torch import load_file

def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path, device=device)

    updates = defaultdict(dict)
    for key, value in state_dict.items():
        # it is suggested to print out the key, it usually will be something like below
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    # directly update weight in diffusers model
    for layer, elems in updates.items():

        if "text" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while True:  # walk down the module tree; merge name parts on attribute misses
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        # get elements for this layer
        weight_up = elems['lora_up.weight'].to(dtype)
        weight_down = elems['lora_down.weight'].to(dtype)
        alpha = elems.get('alpha')  # per-layer scalar; may be absent
        if alpha:
            alpha = alpha.item() / weight_up.shape[1]
        else:
            alpha = 1.0

        # update weight
        if len(weight_up.shape) == 4:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)

    return pipeline

Example usage:

pipe = load_lora_weights(pipe, lora_path, 1.0, 'cuda', torch.float32)

@alejobrainz

I tested it on several custom-created LoRAs and it works great! Excellent work, @pdoane, thanks for sharing.

Quick question for the group: is there a way to quickly unload a LoRA weight from a loaded pipeline? I want to keep the pipeline in memory and simply assign/remove LoRA embeddings on the fly after each inference. Any pointers are appreciated.

Thanks again!

Alejandro

@pdoane
Contributor

pdoane commented Apr 19, 2023

There are two options I can think of:

  • Layer updating is a linear operation so it can be reversed by passing in a negative multiplier. Because of floating-point rounding, there could be a gradual drift over time.

  • You can make a copy of the tensor for each modified layer and restore it later. As the LoRAs are small relative to the model, this is probably preferred (and I expect it to be faster); see the sketch below.
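
A sketch of both options against the load_lora_weights helper above (lora_path is a placeholder; helper names are illustrative):

import torch

# Option 1: the update is linear in the multiplier, so negating it reverses
# the change (floating-point rounding can drift over many apply/undo cycles).
pipe = load_lora_weights(pipe, lora_path, 1.0, "cuda", torch.float32)   # apply
pipe = load_lora_weights(pipe, lora_path, -1.0, "cuda", torch.float32)  # undo

# Option 2: back up each tensor the LoRA touches and restore it later.
backup = {}

def backup_layer(layer):
    backup[id(layer)] = layer.weight.data.detach().clone()

def restore_layer(layer):
    layer.weight.data.copy_(backup[id(layer)])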

@alejobrainz

I'll try approach #2

@alejobrainz

Ugly, but it worked for me. Tested by making 600 inferences, switching between 12 LoRA safetensors 50 times, on diffusers 0.15.1:

from safetensors.torch import load_file
from collections import defaultdict
from diffusers.loaders import LoraLoaderMixin
import torch 

current_pipeline = None
original_weights = {}

def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
    global current_pipeline, original_weights
    
    if (pipeline != current_pipeline):
        backup = True
        current_pipeline = pipeline
        original_weights = {}    
    else:
        backup = False
    
    # load base model
    pipeline.to(device)
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path, device=device)

    updates = defaultdict(dict)
    for key, value in state_dict.items():
        # it is suggested to print out the key, it usually will be something like below
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    index = 0
    # directly update weight in diffusers model
    for layer, elems in updates.items():
        index += 1

        if "text" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while True:  # walk down the module tree; merge name parts on attribute misses
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        # get elements for this layer
        weight_up = elems['lora_up.weight'].to(dtype)
        weight_down = elems['lora_down.weight'].to(dtype)
        alpha = elems.get('alpha')  # per-layer scalar; may be absent
        if alpha:
            alpha = alpha.item() / weight_up.shape[1]
        else:
            alpha = 1.0
        
        if (backup):
            original_weights[index] = curr_layer.weight.data.clone().detach()
        else:
            curr_layer.weight.data = original_weights[index].clone().detach()

        # update weight
        if len(weight_up.shape) == 4:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)

    return pipeline

LoraLoaderMixin.load_lora_weights = load_lora_weights  # monkey-patch diffusers' built-in loader with this version
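
Hypothetical usage, keeping one pipeline resident and swapping LoRAs between runs (paths are placeholders; this assumes the LoRAs touch the same set of layers, since the backup is keyed by iteration index):

prompt = "a portrait photo"  # placeholder
pipe = load_lora_weights(pipe, "loraA.safetensors", 1.0, "cuda", torch.float16)
image_a = pipe(prompt).images[0]
# this call restores the backed-up base weights before applying the next LoRA
pipe = load_lora_weights(pipe, "loraB.safetensors", 1.0, "cuda", torch.float16)
image_b = pipe(prompt).images[0]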

@sayakpaul
Member

@pdoane thanks so much for your inputs and investigations!

Do you mind sharing the pipe and lora_path you tested #3064 (comment) with?

@pdoane
Contributor

pdoane commented Apr 19, 2023

@sayakpaul - followed up in e-mail.

@sayakpaul
Member

Thanks. However, I think having an end-to-end open example here would help the community a great deal in understanding the nuances of the interoperability better.

@sayakpaul
Member

@pdoane come to think of it, would you be interested in improving our LoRA functionality to operate with the A1111 format as well?

@patrickvonplaten recently incorporated similar support for our textual inversion scripts: https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion

@pdoane
Contributor

pdoane commented Apr 19, 2023

My assumption is that this is just the first step toward something more official; I would be glad to help!

I have some API questions about it:

  1. Do you want support in the main API or as an example/converter?
  2. Assuming it is in the main API, the existing method unet.load_attn_procs() is not the right place, as the text encoder needs modification as well.
  3. Weight restoration is an important use case too, probably an optional dictionary parameter to store weight information and another method to re-apply.

In terms of format details:

  • The existing LoRA support has a different assumption for key names. I'm not sure what format is being assumed currently, nor how it should be reconciled with this approach. The A1111 code suggests that the layer-name convention used in the above scripts is "diffusers" and not "compvis". Are there LoRA files that use compvis layer names? (See the key sketch after this list.)
  • MultiheadAttention support is missing. Should be easy to add but I wanted to find an example first.
  • There are a variety of other formats too (e.g. LyCORIS) and I don't know how common those are.
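
To make the naming question concrete, the two key shapes seen in this thread (illustrative, not exhaustive):

# kohya-ss / A1111-style key (what the scripts above parse):
#   lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight
# diffusers attn-procs-style key (what unet.load_attn_procs() expects, hence the
# KeyError on 'to_k_lora.down.weight' at the top of this issue):
#   down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight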

@alexblattner

alexblattner commented Apr 20, 2023

@alejobrainz how do you use your code so that it works with a prompt in the same way as A1111? I put this as my prompt:

prompt="art by <lora:mngstle:1>"
n_prompt="(nsfw), out of frame, multiple people, petite, loli, side view, profile, lowres, (bad anatomy, bad hands:1.1), text, (tattoo), error, missing fingers, extra digit, fewer digits, cropped, worst quality, (((many people))), low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name,weird colors, (cartoon, 3d, bad art, poorly drawn, close up, blurry:1.5), (disfigured, deformed, extra limbs:1.5)"

but it ignored the LoRA instructions completely.

@alejobrainz

For prompt weighting you can use compel. It's great and easy to use. Just be sure to check out the syntax at https://github.com/damian0815/compel/blob/main/Reference.md
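
Roughly as in the compel README (assuming pipe is a loaded StableDiffusionPipeline; the prompt is a placeholder):

from compel import Compel

compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
conditioning = compel.build_conditioning_tensor("a cat playing with a ball++ in the forest")
image = pipe(prompt_embeds=conditioning).images[0]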

@alejobrainz

Also, be mindful that the LoRA is embedded by the script; you only need the trigger keyword your LoRA uses within the prompt. The <lora:name:weight> syntax is an A1111 feature and isn't parsed here.

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label on Jun 16, 2023
@patrickvonplaten
Contributor

I think we can close this one, no @sayakpaul?

@sayakpaul
Member

sayakpaul commented Jun 16, 2023

Yeah, I think so. We have recently introduced support for loading A1111-formatted LoRAs (thanks to @takuma104).

We will continue to iterate on top of it.
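
A minimal sketch of the now-supported path (diffusers >= 0.17; the local file name is a placeholder):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights(".", weight_name="lora.safetensors")  # A1111/kohya-format file
image = pipe("a prompt with the LoRA's trigger word").images[0]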

@notdanilo

How can I update just the multiplier in curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down) without rebuilding the network?

@sayakpaul
Member

sayakpaul commented Jun 29, 2023

How would you like to adjust the multiplier?

@notdanilo

How would you like to adjust the multiplier?

The same value is used on different layers, right? Say I initialized it with 1.0; I then just want to update it, e.g. multiplier = 0.75, everywhere it's referenced in the network.

Is that possible to achieve?

@sayakpaul
Member

I think you would need to fetch the multiplier from the corresponding state dict for that. Since the multiplier only concerns the weight-modification part, I don't think reinitialization would be required here.
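
Concretely, since the update is linear in the multiplier, going from m1 to m2 only needs the delta re-applied; a sketch reusing names from the snippets above (assumes the up/down matrices are still in memory):

# W += (m2 - m1) * alpha * (up @ down), so no rebuild is needed
delta = 0.75 - 1.0  # new multiplier minus the one already applied
curr_layer.weight.data += delta * alpha * torch.mm(weight_up, weight_down)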

@sonaterai

sonaterai commented Jul 7, 2023

Hello, I'm using v0.18 of diffusers, and I managed to load a LoRA .safetensors file from CivitAI. However, I don't notice any changes even when I use <lora:test:1> in the prompt.
Do you have any idea what could be causing this issue? Thank you in advance for your help.

@adhikjoshi
Author

Old LoRAs won't work with SDXL.

Any workarounds?

@sayakpaul
Member

@sonaterai I think there are some extra keys and corresponding weight parameters for which we don't have support yet. See #3087. It also links to a couple of other similar threads.

Old LoRAs won't work with SDXL.

It's not supposed to, as the corresponding UNet is different, as far as I understand.

@sanbuphy

sanbuphy commented Jul 8, 2023

@sonaterai I think there are some extra keys and corresponding weight parameters for which we don't have support yet. See #3087. It also links to a couple of other similar threads.

Old LoRAs won't work with SDXL.

It's not supposed to, as the corresponding UNet is different, as far as I understand.

If I train a LoRA based on SDXL myself, is there any way to load the new LoRA?

@sayakpaul
Member

sayakpaul commented Jul 8, 2023

If I train a LoRA based on SDXL myself, is there any way to load the new LoRA?

You should be able to. If not, please open a new thread with a reproducible set of instructions.
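
A minimal sketch for a LoRA trained on SDXL (diffusers >= 0.19; the repo ID and file name are placeholders):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights(".", weight_name="my_sdxl_lora.safetensors")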

@kkwhale7

kkwhale7 commented Sep 1, 2023

Here is a function I made from convert_lora_safetensor_to_diffusers.py to load a LoRA at inference time. […]

In my case, on diffusers 0.14.0, it goes wrong:

Traceback (most recent call last):
  File "demo_txt2img.py", line 88, in <module>
    demo.loadEngines(args.engine_dir, args.onnx_dir, args.onnx_opset,
  File "/workspace/demo/Diffusion/stable_diffusion_pipeline.py", line 288, in loadEngines
    model = obj.get_model()
  File "/workspace/demo/Diffusion/models.py", line 300, in get_model
    return load_lora_weights(basic_unet, self.lora, "unet")
  File "/workspace/demo/Diffusion/safetensors2bin.py", line 49, in load_lora_weights
    temp_name += "_" + layer_infos.pop(0)
IndexError: pop from empty list

Can you help me? Thanks!

@sayakpaul
Member

Please use the load_lora_weights() API for this (docs). We don't maintain that script, so sadly we cannot offer any support.

@kkwhale7

kkwhale7 commented Sep 1, 2023

Please use the load_lora_weights() API for this (docs). We don't maintain that script, so sadly we cannot offer any support.

However, I want to load the LoRA into my UNet and CLIP directly, not through the pipeline (because I reused TensorRT's code).

@kkwhale7

kkwhale7 commented Sep 1, 2023

@sayakpaul

@SlZeroth
Contributor

SlZeroth commented Oct 2, 2023

I tried load_lora_weights(), but the inference results are confusing. Does load_lora_weights load the text encoder params in the LoRA?

@kkwhale7

kkwhale7 commented Oct 2, 2023

I tried load_lora_weights(), but the inference results are confusing. Does load_lora_weights load the text encoder params in the LoRA?

Yes, both CLIP and UNet.

@alirezaomneky

I am having a similar issue. When I run pipe.load_lora_weights("lora_model_id"), where lora_model_id is the path to my model, I simply get this error:

KeyError: 'lora.down.weight'

Does anyone know how to fix this issue?

@sayakpaul
Member

Can you open a new issue with a reproducible code snippet?
