[LoRA] Enabling limited LoRA support for text encoder #2918


Merged
sayakpaul merged 18 commits into main from feat/loraloadermixin on Apr 12, 2023

Conversation

@sayakpaul (Member) commented Mar 31, 2023

Builds on top of #2882 (I had to close that because the conflicts were nasty).

Example usage is as follows.

Initialization

Users will do this manually from their training scripts following this. I think this is fine.

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

def get_text_encoder():
    return CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
    )

def get_unet():
    return UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )

text_encoder = get_text_encoder()
unet = get_unet()
# UNet LoRA layers. 
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # `attn1` is self-attention (no cross-attention dim); `attn2` attends to the text embeddings.
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    # Infer the hidden size of the block this processor belongs to.
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )
    
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
# Text encoder LoRA layers.
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES

text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
    # Target the attention projection Linear layers of the text encoder.
    if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
        text_lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=module.out_features, cross_attention_dim=None
        )

text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)

The next step would be training. That's the easiest part, I think.
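For illustration only, here is a rough sketch of how the LoRA parameters from both models could be handed to an optimizer in such a training script. The learning rate and the loss computation are placeholders, not part of this PR; `lora_layers` and `text_encoder_lora_layers` are the objects built above.

import itertools
import torch

# Optimize only the LoRA parameters of the UNet and the text encoder.
params_to_optimize = itertools.chain(
    lora_layers.parameters(), text_encoder_lora_layers.parameters()
)
optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-4)

# Inside the training loop, the usual noise-prediction loss is used:
# loss.backward(); optimizer.step(); optimizer.zero_grad()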

Serialization

LoraLoaderMixin.save_lora_weights(
    save_directory=".",
    unet_lora_layers=lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers
)

Loading into a pipeline

from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipeline = StableDiffusionPipeline.from_pretrained(model_id)

# Local works.
pipeline.load_lora_weights(".")

# Remote also works. 
pipeline.load_lora_weights("sayakpaul/test-lora-diffusers")

@patrickvonplaten from #2882 (review):

And those are then overwritten in the StableDiffusionPipeline class

Could you please elaborate on that? Did you mean how it's done here?

TODOs

  • Tests
  • Documentation

I suggest we update the train_dreambooth_lora.py example in a follow-up PR.

@HuggingFaceDocBuilderDev commented Mar 31, 2023

The documentation is not available anymore as the PR was closed or merged.

@@ -288,7 +288,7 @@ def save_function(weights, filename):
     model_to_save = AttnProcsLayers(self.attn_processors)

     # Save the model
-    state_dict = model_to_save.state_dict()
+    state_dict = {"unet": model_to_save.state_dict()}
Contributor:

this would break a bit with the existing format, no? E.g., already trained LoRAs have a different serialization format at the moment, I think

# Load the layers corresponding to UNet.
if state_dict.get(self.unet_name, None) is not None:
    logger.info(f"Loading {self.unet_name}.")
    self.unet.load_attn_procs(state_dict[self.unet_name])
Contributor:

this would only work with the new format. Do you think we could maybe leave the old format as-is and add some code that extracts the unet part out of it?

@patrickvonplaten (Contributor)

Love the PR, super cool work @sayakpaul & thanks for iterating so much here.

RE:

And those are then overwritten in the StableDiffusionPipeline class
I meant here:
https://github.com/huggingface/diffusers/pull/2918/files#r1154437571
do:

text_encoder_name = None

and then in the StableDiffusionPipeline do:

text_encoder_name = "text_encoder"

but totally fine for me to just directly do:

text_encoder_name = "text_encoder"

in the LoraMixin and delete here: #2918 (comment) (as you've done)

My last comment would be to make sure the existing serialization format, e.g.: #2918 (comment) is backwards compatible

More specifically at the moment we have the following state_dict structure:

state_dict.keys():
dict_keys(['down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight', 
...
'mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_out_lora.up.weight'])

We have two options now:
a) We keep the same format and just append the text encoder weights on the same root level

state_dict.keys():
dict_keys(['down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight', 
...
'mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_out_lora.up.weight', # now unet ends
'model.block.0.attention.query.weight',  # now text encoder starts
...
])

Advantage: both unet-only LoRA and unet + text encoder LoRA have the same serialization format.
Disadvantage: we'd need some extra code to "extract" all unet layers from the flat state dict (e.g., by assuming that the up_blocks, down_blocks, and mid_block prefixes correspond to the unet); see the sketch after option b) below.

b) We add a new format (as you've done), but then we need to make sure that the old format can still be loaded here: https://github.com/huggingface/diffusers/pull/2918/files#r1154436980. So in case neither "text_encoder" nor "unet" is in the state_dict, we should somehow verify whether the old format can still be loaded. Similarly, the old loading function self.unet.load_attn_procs should also work with checkpoints that have a `"unet": ...` serialization format.
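To make option a) concrete, here is a minimal sketch of such an extraction, assuming the unet keys are exactly the ones starting with the block prefixes (the helper name is illustrative, not the eventual implementation):

UNET_KEY_PREFIXES = ("down_blocks", "mid_block", "up_blocks")

def split_flat_lora_state_dict(state_dict):
    # Keys that look like UNet attention-processor keys go into the unet bucket;
    # everything else is assumed to belong to the text encoder.
    unet_sd = {k: v for k, v in state_dict.items() if k.startswith(UNET_KEY_PREFIXES)}
    text_encoder_sd = {k: v for k, v in state_dict.items() if not k.startswith(UNET_KEY_PREFIXES)}
    return unet_sd, text_encoder_sd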

Both a) and b) are ok for me, but it would be important that all of the following works:

from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipeline = StableDiffusionPipeline.from_pretrained(model_id)

pipeline.load_lora_weights("<lora_unet_and_text_encoder>")
pipeline.load_lora_weights("<lora_unet_only>")
pipeline.unet.load_attn_procs("<lora_unet_only>")

So that we have 100% backward compatibility. Note that if we go for b), we should in some sense treat the old format as "deprecated", always save as {"unet": ...}, and then probably throw a warning if the old format is loaded, with a hint on how to convert it to the new format.

Hope this makes sense - great job here. Awesome to see this complicated feature being so close to the finish line.

@sayakpaul sayakpaul marked this pull request as ready for review April 3, 2023 06:18
@sayakpaul (Member, Author) commented Apr 3, 2023

@patrickvonplaten with the latest changes, all the following scenarios work:

from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipeline = StableDiffusionPipeline.from_pretrained(model_id)

# Remote also works (new format that has both `unet` and `text_encoder`). 
pipeline.load_lora_weights("sayakpaul/test-lora-diffusers")

# Legacy format.
pipeline.load_lora_weights("patrickvonplaten/lora_dreambooth_dog_example")
# or
pipeline.unet.load_attn_procs("patrickvonplaten/lora_dreambooth_dog_example")

Here is my test notebook.

However, the current serialization format fails with safetensors.

This works:

import safetensors

safetensors.torch.save_file(
    pipeline.unet.state_dict(), "unet.safetensors", metadata={"format": "pt"}
)

This doesn't work:

safetensors.torch.save_file(
    {
        "unet": unet_lora_layers.state_dict(),
        "text_encoder": text_encoder_lora_layers.state_dict(),
    },
    "lora.safetensors",
    metadata={"format": "pt"},
)
in <module>
❱ 1 safetensors.torch.save_file(
  2     {
  3         "unet": unet_lora_layers.state_dict(),
  4         "text_encoder": text_encoder_lora_layers.state_dict(),

/Users/sayakpaul/.local/bin/.virtualenvs/diffusers-dev/lib/python3.8/site-packages/safetensors/torch.py:71 in save_file
❱ 71     serialize_file(_flatten(tensors), filename, metadata=metadata)

/Users/sayakpaul/.local/bin/.virtualenvs/diffusers-dev/lib/python3.8/site-packages/safetensors/torch.py:221 in _flatten
❱ 221     raise ValueError(f"Key `{k}` is invalid, expected torch.Tensor but received ...

ValueError: Key `unet` is invalid, expected torch.Tensor but received <class 'dict'>

I think having the new state dict have two root keys (like the current one), namely unet and text_encoder, is a cleaner design than a state dict that has all the module keys at the root. So maybe this should be solved at the safetensors level.

Also, FWIW, with the current design it's also possible to leverage LoraLoaderMixin.save_lora_weights to serialize just one of unet_lora_layers and text_encoder_lora_layers as well.
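For instance, to serialize only the text encoder LoRA layers (a usage sketch mirroring the call shown in the PR description above):

# The unet entry is simply omitted; only the text encoder layers get saved.
LoraLoaderMixin.save_lora_weights(
    save_directory=".",
    text_encoder_lora_layers=text_encoder_lora_layers,
)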

@NicholasKao1029:
Perhaps unrelated, but is loading multiple LoRAs on the menu? Vis-à-vis:

#2613

@patrickvonplaten (Contributor) commented Apr 4, 2023

(Quoting @sayakpaul's comment above in full: the new serialization format works for all the loading scenarios, but safetensors.torch.save_file fails on the nested unet/text_encoder state dict.)
I see! How about we just add a prefix, "unet." or "text_encoder.", to every weight in the module? We need safetensors to work, and to me it's equally nice (maybe even nicer) to have a flat module. Wdyt?

It would then be very simple to figure out whether it's the old or the new format:

if not all(
    key.startswith(self.unet_name) or key.startswith(self.text_encoder_name)
    for key in state_dict.keys()
):
    old_format = True

Would this work for you?
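To illustrate the proposal (just a sketch reusing the variable names from the earlier examples, not the code that was merged): prefix every key, keep the dict flat, and safetensors can serialize it.

import safetensors.torch

# Flatten by prefixing each sub-model's keys instead of nesting dictionaries.
flat_state_dict = {f"unet.{k}": v for k, v in unet_lora_layers.state_dict().items()}
flat_state_dict.update(
    {f"text_encoder.{k}": v for k, v in text_encoder_lora_layers.state_dict().items()}
)

# Every value is now a torch.Tensor, so this no longer raises the ValueError above.
safetensors.torch.save_file(flat_state_dict, "lora.safetensors", metadata={"format": "pt"})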

# Load the layers corresponding to UNet.
if all(key.startswith(self.unet_name) for key in keys):
    logger.info(f"Loading {self.unet_name}.")
    unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)}
Contributor:

clean!

@patrickvonplaten (Contributor) left a comment

Very nice! This is good to merge for me :-)

For follow-up PRs:

1.) `state_dict = pretrained_model_name_or_path_or_dict` should, in my opinion, also support checkpoints that have the `unet.` prefix.
2.) We can now adapt the LoRA training script to also allow training the text encoder.
3.) We can add the loader to more pipelines than just Stable Diffusion (essentially all pipelines that can use checkpoints for which LoRA can be trained: img2img, ...).

@pcuenca @williamberman @patil-suraj could one of you maybe also take a quick look?

    key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
):
    self.unet.load_attn_procs(state_dict)
    logger.warning(
Contributor:

Super nice warning, I'd maybe just use the deprecate function here:

deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)

deprecate throws a FutureWarning, which IMO is better than logger.warning here

Member (Author):

Oh cool! Could you check whether the latest updates are good?

" deprecated soon. To convert the old LoRA weights to the new format, you can first load them"
" in a dictionary and then create a new dictionary like the following:"
" `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
)
Contributor:

(For the future):
I think we can also allow loading the A1111 format here.

Member (Author):

For sure. And maybe civitai as well.

Could you help me with a few checkpoints that I could investigate?

        return old_forward(x) + lora_layer(x)

    # Monkey-patch.
    module.forward = new_forward
Contributor:

Ok for me given the circumstances. We simply can't fiddle too much with the text encoder as it's not part of diffusers.
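In isolation, the monkey-patching pattern under review looks roughly like this (a simplified sketch; the helper name and lora_layer are illustrative, not the exact code in the PR):

import torch.nn as nn

def add_lora_to_linear(module: nn.Linear, lora_layer: nn.Module):
    old_forward = module.forward

    # Capture `old_forward` and `lora_layer` as default arguments to avoid the
    # usual late-binding pitfall when patching several modules in a loop.
    def new_forward(x, old_forward=old_forward, lora_layer=lora_layer):
        return old_forward(x) + lora_layer(x)

    # Monkey-patch: the wrapped module now adds the LoRA delta to its output.
    module.forward = new_forward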

@sayakpaul (Member, Author)

@patrickvonplaten, thanks a lot for all the guidance and also for always encouraging friendly constructive discussions.

should in my opinion also support checkpoints that have the `unet.` prefix

This is already supported (i.e., it is being introduced in this PR):

# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
elif not all(
    key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
):
    self.unet.load_attn_procs(state_dict)
    logger.warning(
        "You have saved the LoRA weights using the old format. This will be"
        " deprecated soon. To convert the old LoRA weights to the new format, you can first load them"
        " in a dictionary and then create a new dictionary like the following:"
        " `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
    )

2.) We can now adapt the LoRA training script to also allow training the text encoder

Expect a PR next week :) I will also work on adding multiple LoRAs (will probably be limited in some sense); stay tuned 😉

@adhikjoshi:

(Quoting @patrickvonplaten's review comment above about the follow-up PRs: supporting the unet. prefix, training the text encoder, and adding the loader to more pipelines.)

LoRAs which are trained by automatic1111 etc. will be in the old format for a long time.

Could there be a script, "convert_old_lora_to_new_format.py"?

This would help avoid lots of potential issues in the future.

@cmdr2 (Contributor) commented Apr 7, 2023

(Quoting @adhikjoshi's question above about a conversion script.)

@adhikjoshi One option is to update the LoRA conversion script to use the new LoRA loading API from this PR (instead of the custom weight-update logic it currently uses) - https://github.com/huggingface/diffusers/blob/main/scripts/convert_lora_safetensor_to_diffusers.py

That script currently reads the auto1111 format and generates a diffusers-format model; it outputs the entire SD model instead of just the LoRA. So we could modify it to generate only a LoRA file (maybe via a CLI argument). I believe this was also discussed in #2829.

@sayakpaul (Member, Author)

@cmdr2 #2866 (comment) might be relevant for the LoRA safetensors part.

Let's discuss in that thread.

" in a dictionary and then create a new dictionary like the following:"
" `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
)
deprecation_message = "You have saved the LoRA weights using the old format. This will be"
Contributor:

nice!

@patrickvonplaten (Contributor)

(Quoting @sayakpaul's reply above: the old format is already supported with a warning, and the training-script update is planned as a follow-up.)

Yeah, here I meant that this line:

attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])

might not work with checkpoints that are prefixed with `unet.`, but we can probably figure this out better in a follow-up PR :-)
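One possible shape for that follow-up (purely a sketch; the helper name is illustrative and the actual handling was deferred to the later PR):

def split_processor_key(key, prefix="unet"):
    # Strip a leading "unet." prefix, if present, before splitting the key into
    # the attention-processor part and the LoRA sub-key part.
    if key.startswith(f"{prefix}."):
        key = key[len(prefix) + 1 :]
    attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
    return attn_processor_key, sub_key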

@sayakpaul (Member, Author)

might not work with checkpoints that are prefixed with `unet.`, but we can probably figure this out better in a follow-up PR :-)

Ah, I see what you are saying. Yeah, probably better done in a follow-up PR.

@sayakpaul sayakpaul merged commit a89a14f into main Apr 12, 2023
@sayakpaul sayakpaul deleted the feat/loraloadermixin branch April 12, 2023 02:59
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
* add: first draft for a better LoRA enabler.

* make fix-copies.

* feat: backward compatibility.

* add: entry to the docs.

* add: tests.

* fix: docs.

* fix: norm group test for UNet3D.

* feat: add support for flat dicts.

* add deprecation message instead of warning.
dg845 pushed a commit to dg845/diffusers that referenced this pull request May 6, 2023
@pure-rgb:

@patrickvonplaten @sayakpaul
In this doc, it is mentioned that

Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.

Is it also the same for this PR? Do we need more RAM to train the text encoder?

Also here it's mentioned that,

cloneofsimo was the first to try out LoRA training for Stable Diffusion in the popular lora GitHub repository.

cloneofsimo's lora implementation offers text encoder training, but 6 or 8 GB is enough to train the model. Why this gap?

@sayakpaul (Member, Author)

Good point. I think it should be possible with 8-bit Adam, but I haven't tested it yet.
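For context, switching to 8-bit Adam in these training scripts usually just means swapping the optimizer class (a sketch, assuming bitsandbytes is installed and reusing the parameter wiring from the earlier example):

import itertools
import bitsandbytes as bnb

# 8-bit AdamW keeps the optimizer state quantized, which is where most of the
# memory savings over regular AdamW come from.
optimizer = bnb.optim.AdamW8bit(
    itertools.chain(lora_layers.parameters(), text_encoder_lora_layers.parameters()),
    lr=1e-4,
)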

@pure-rgb:

It doesn't work with 8-bit Adam.
Here, do you set gradient checkpointing for all model components, including the text encoder? In cloneofsimo's code he did. Not sure if that's the main reason or not.

@shkr commented Aug 11, 2023

Is there a tutorial to use this for precise CLIP on stylistic images such as hand sketches?

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024