[LoRA] Enabling limited LoRA support for text encoder #2918
Conversation
src/diffusers/loaders.py
Outdated
@@ -288,7 +288,7 @@ def save_function(weights, filename):
     model_to_save = AttnProcsLayers(self.attn_processors)

     # Save the model
-    state_dict = model_to_save.state_dict()
+    state_dict = {"unet": model_to_save.state_dict()}
this would break a bit with the existing format, no? E.g., already-trained LoRAs have a different serialization format at the moment, I think.
src/diffusers/loaders.py
Outdated
# Load the layers corresponding to UNet.
if state_dict.get(self.unet_name, None) is not None:
    logger.info(f"Loading {self.unet_name}.")
    self.unet.load_attn_procs(state_dict[self.unet_name])
this would only work with a new format. Do you think we could maybe leave the old format as is and add some code that extracts the unet part out of it?
Love the PR, super cool work @sayakpaul & thanks for iterating so much here.

RE: we could set text_encoder_name = None in the mixin and then, in the StableDiffusionPipeline, do text_encoder_name = "text_encoder"; but it's totally fine for me to just directly do text_encoder_name = "text_encoder" in the LoraLoaderMixin and delete it here: #2918 (comment) (as you've done).

My last comment would be to make sure the existing serialization format, e.g. #2918 (comment), stays backwards compatible. More specifically, at the moment the state dict is a flat mapping from attention-processor key to tensor. We have two options now:

a) We keep the flat structure but prefix every key with the sub-model name (illustrated in the sketch below). Advantage: both unet-only LoRAs and unet + text encoder LoRAs have the same serialization format.

b) We add a new format (as you've done), but then we need to make sure that the old format can still be loaded here: https://github.com/huggingface/diffusers/pull/2918/files#r1154436980 so in case neither "text_encoder" nor "unet" is in the state_dict, we should somehow verify whether the old format can still be loaded. The same goes for the old loading function.

Both a) and b) are OK for me, but it would be important that all of the following works:

from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipeline = StableDiffusionPipeline.from_pretrained(model_id)
pipeline.load_lora_weights("<lora_unet_and_text_encoder>")
pipeline.load_lora_weights("<lora_unet_only>")
pipeline.unet.load_attn_procs("<lora_unet_only>")

So that we have 100% backward compatibility. Note that if we go for b), we should in some sense treat the old format as deprecated and always save in the new one.

Hope this makes sense, great job here. Awesome to see this complicated feature being so close to the finish line.
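To make the two options concrete, here is an illustrative sketch; the exact key names are assumptions based on diffusers' attention-processor naming, not copied from the PR:

# Old (flat, UNet-only) format: keys address attention processors directly.
old_state_dict = {
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_k_lora.down.weight": ...,
}

# Option a): stay flat, but prefix every key with the sub-model name.
flat_prefixed_state_dict = {
    "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_k_lora.down.weight": ...,
    "text_encoder.text_model.encoder.layers.0.self_attn.k_proj.lora.down.weight": ...,
}

# Option b): nest the two sub-model state dicts under root keys.
nested_state_dict = {"unet": {...}, "text_encoder": {...}}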
@patrickvonplaten with the latest changes, all the following scenarios work:

from diffusers import StableDiffusionPipeline
model_id = "runwayml/stable-diffusion-v1-5"
pipeline = StableDiffusionPipeline.from_pretrained(model_id)
# Remote also works (new format that has both `unet` and `text_encoder`).
pipeline.load_lora_weights("sayakpaul/test-lora-diffusers")
# Legacy format.
pipeline.load_lora_weights("patrickvonplaten/lora_dreambooth_dog_example")
# or
pipeline.unet.load_attn_procs("patrickvonplaten/lora_dreambooth_dog_example")

Here is my test notebook. However, the current serialization format fails with safetensors.

This works:

import safetensors
safetensors.torch.save_file(
pipeline.unet.state_dict(), "unet.safetensors", metadata={"format": "pt"}
)

This doesn't work:

safetensors.torch.save_file(
{
"unet": unet_lora_layers.state_dict(),
"text_encoder": text_encoder_lora_layers.state_dict(),
},
"lora.safetensors",
metadata={"format": "pt"},
in <module>

❱ 1 safetensors.torch.save_file(
  2     {
  3         "unet": unet_lora_layers.state_dict(),
  4         "text_encoder": text_encoder_lora_layers.state_dict(),

/Users/sayakpaul/.local/bin/.virtualenvs/diffusers-dev/lib/python3.8/site-packages/safetensors/torch.py:71 in save_file

  68     save(tensors, "model.safetensors")
  69     ```
  70     """
❱ 71     serialize_file(_flatten(tensors), filename, metadata=metadata)

/Users/sayakpaul/.local/bin/.virtualenvs/diffusers-dev/lib/python3.8/site-packages/safetensors/torch.py:221 in _flatten

  218     ptrs = defaultdict(set)
  219     for k, v in tensors.items():
  220         if not isinstance(v, torch.Tensor):
❱ 221             raise ValueError(f"Key `{k}` is invalid, expected torch.Tensor but received

ValueError: Key `unet` is invalid, expected torch.Tensor but received <class 'dict'>

I think having the new state dict carry two root keys (like the current one), namely `unet` and `text_encoder`, therefore doesn't work with safetensors serialization. Also, FWIW, with the current design it's also possible to leverage …
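A minimal sketch of the flattened alternative that safetensors does accept, reusing the unet_lora_layers / text_encoder_lora_layers objects from the snippet above (the prefix names follow this thread's discussion):

import safetensors.torch

# Prefix each key with its sub-model name so every value stays a torch.Tensor,
# which is what safetensors requires.
lora_state_dict = {}
for prefix, layers in [("unet", unet_lora_layers), ("text_encoder", text_encoder_lora_layers)]:
    for module_name, param in layers.state_dict().items():
        lora_state_dict[f"{prefix}.{module_name}"] = param

safetensors.torch.save_file(lora_state_dict, "lora.safetensors", metadata={"format": "pt"})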
Perhaps unrelated, but is loading multiple LoRAs on the menu?
I see! How about we just add another prefix? It'd then be very simple to figure out whether it's the old or the new format:

if not all(
    key.startswith(self.unet_name) or key.startswith(self.text_encoder_name)
    for key in state_dict.keys()
):
    old_format = True

Would this work for you?
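As a standalone sketch, the check could be packaged like this (the default prefix values are assumptions based on this thread):

def is_legacy_lora_format(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
    # New-format checkpoints prefix every key with a sub-model name;
    # anything else is treated as the old, flat, UNet-only format.
    return not all(
        key.startswith(unet_name) or key.startswith(text_encoder_name)
        for key in state_dict.keys()
    )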
# Load the layers corresponding to UNet.
if all(key.startswith(self.unet_name) for key in keys):
    logger.info(f"Loading {self.unet_name}.")
    unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)}
clean!
Very nice! This is good to merge for me :-)
For follow-up PRs:
1.)
diffusers/src/diffusers/loaders.py
Line 210 in 6e8e1ed
state_dict = pretrained_model_name_or_path_or_dict
We should double-check how this line handles the `unet.` prefix.
2.) We can now adapt the LoRA training script to also allow training the text encoder.
3.) We can add the loader to more pipelines than just Stable Diffusion (essentially all pipelines that can use checkpoints for which LoRA can be trained: img2img, ...).
@pcuenca @williamberman @patil-suraj could one of you maybe also take a quick look?
src/diffusers/loaders.py
Outdated
    key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
):
    self.unet.load_attn_procs(state_dict)
    logger.warning(
Super nice warning, I'd maybe just use the deprecate function here:

deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)

deprecate throws a FutureWarning, which is better IMO compared to logger.warning here.
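For illustration, applying that suggestion to this PR's warning could look roughly like the following; the deprecation name and message are made up here, not the PR's exact strings:

from diffusers.utils import deprecate

deprecation_message = "You have saved the LoRA weights using the old format. Please switch to the new prefixed format."
deprecate("legacy LoRA format", "1.0.0", deprecation_message, standard_warn=False)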
Oh cool! See if the latest updates are good?
src/diffusers/loaders.py
Outdated
" deprecated soon. To convert the old LoRA weights to the new format, you can first load them" | ||
" in a dictionary and then create a new dictionary like the following:" | ||
" `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." | ||
) |
(For the future):
I think we can also allow loading the A1111 format here.
For sure. And maybe civitai as well.
Could you help me with a few checkpoints that I could investigate?
    return old_forward(x) + lora_layer(x)

# Monkey-patch.
module.forward = new_forward
Ok for me given the circumstances. We simply can't fiddle too much with the text encoder as it's not part of diffusers
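For readers following along, here is a self-contained sketch of the monkey-patching idea; the class and helper names are illustrative, not the exact PR code:

import torch.nn as nn

class LoRALinearLayer(nn.Module):
    # Low-rank residual: x -> up(down(x)); up starts at zero so training begins
    # from the unmodified text encoder output.
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, x):
        return self.up(self.down(x))

def add_lora_to_linear(module, rank=4):
    lora_layer = LoRALinearLayer(module.in_features, module.out_features, rank)
    old_forward = module.forward  # capture the original bound method

    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    # Monkey-patch.
    module.forward = new_forward
    return lora_layer  # keep a handle so its parameters can be trained and saved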
@patrickvonplaten, thanks a lot for all the guidance and also for always encouraging friendly constructive discussions.
This is already supported (as in being introduced in this PR): diffusers/src/diffusers/loaders.py Lines 737 to 748 in 6bf6eef
Expect a PR next week :) I will also work on adding multiple LoRAs (will probably be limited in some sense); stay tuned 😉
LoRAs trained by Automatic1111 etc. will be in the old format for a long time. Could there be a script, convert_old_lora_to_new_format.py? This could help avoid lots of potential issues in the future.
@adhikjoshi One option is to update the LoRA conversion script to use the new LoRA loading API from this PR (instead of the custom weight-update logic it currently uses): https://github.com/huggingface/diffusers/blob/main/scripts/convert_lora_safetensor_to_diffusers.py That script currently reads the auto1111 format and generates a diffusers-format model. It outputs the entire SD model instead of just the LoRA, so we could modify it to generate only a LoRA file (maybe via a CLI argument). I believe this was also discussed in #2829
@cmdr2 #2866 (comment) might be relevant for the LoRA safetensors part. Let's discuss in that thread.
" in a dictionary and then create a new dictionary like the following:" | ||
" `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." | ||
) | ||
deprecation_message = "You have saved the LoRA weights using the old format. This will be" |
nice!
Yeah, here I meant that this line:
diffusers/src/diffusers/loaders.py
Line 220 in 6e8e1ed
should also account for the `unet.` prefix, but we can probably figure this out better in a follow-up PR :-)

Ah, I see what you are saying. Yeah, probably better done in a follow-up PR.
* add: first draft for a better LoRA enabler. * make fix-copies. * feat: backward compatibility. * add: entry to the docs. * add: tests. * fix: docs. * fix: norm group test for UNet3D. * feat: add support for flat dicts. * add deprecation message instead of warning.
@patrickvonplaten @sayakpaul
Is it also the same for this PR, i.e., do we need more RAM to train with the text encoder? Also, it's mentioned here that cloneofsimo's LoRA implementation offers text encoder training, yet 6 or 8 GB is enough to run the model. Why this gap?
Good point. I think it should be possible with 8-bit Adam, but I didn't test it yet.
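For reference, swapping in 8-bit Adam would look roughly like this; it assumes bitsandbytes is installed, and lora_parameters is an illustrative name for the trainable LoRA weights:

import bitsandbytes as bnb

# The 8-bit AdamW variant stores optimizer state in 8 bits,
# substantially reducing optimizer memory versus torch.optim.AdamW.
optimizer = bnb.optim.AdamW8bit(lora_parameters, lr=1e-4)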
It doesn't work with 8-bit Adam.
Is there a tutorial on using this for precise CLIP tuning on stylistic images such as hand sketches?
Builds on top of #2882 (I had to close that because the conflicts were nasty).
Example usage is as follows.
Initialization
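A sketch of the initialization step, reconstructed from the pattern the existing LoRA training script uses (the import path matched the diffusers version at the time; treat details as approximate):

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no cross-attention dim); attn2 attends to the text embeddings.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)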
Users will do this manually from their training scripts following this. I think this is fine.
Next step would be training. Easiest part I think.
Serialization
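A sketch of the serialization step; save_lora_weights and its argument names reflect my best reading of the API this PR introduces and may differ slightly:

from diffusers.loaders import LoraLoaderMixin

# Writes a single checkpoint whose keys are prefixed with `unet.` / `text_encoder.`.
LoraLoaderMixin.save_lora_weights(
    save_directory="lora-checkpoint",
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers,
)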
Loading into a pipeline
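Loading then reuses the API exercised earlier in this thread (the repo id is the test checkpoint mentioned above):

from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Populates LoRA layers in both the UNet and the text encoder.
pipeline.load_lora_weights("sayakpaul/test-lora-diffusers")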
@patrickvonplaten from #2882 (review):
Could you please elaborate more on that? Did you mean how it's done here?
TODOs
I suggest we update the train_dreambooth_lora.py example in a follow-up PR.