Add function to remove monkey-patch for text encoder LoRA #3649
```diff
@@ -299,6 +299,46 @@ def test_text_encoder_lora_monkey_patch(self):
             outputs_without_lora, outputs_with_lora
         ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"
 
+    def test_text_encoder_lora_remove_monkey_patch(self):
+        pipeline_components, _ = self.get_dummy_components()
+        pipe = StableDiffusionPipeline(**pipeline_components)
+
+        dummy_tokens = self.get_dummy_tokens()
+
+        # inference without lora
+        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora.shape == (1, 77, 32)
+
+        # create lora_attn_procs with randn up.weights
+        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+        set_lora_up_weights(text_attn_procs, randn_weight=True)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_attn_procs)
+
+        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+        del text_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 32)
+
+        assert not torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora outputs should be different to without lora outputs"
+
+        # remove monkey patch
+        pipe._remove_text_encoder_monkey_patch()
+
+        # inference with removed lora
+        outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora_removed.shape == (1, 77, 32)
+
+        assert torch.allclose(
+            outputs_without_lora, outputs_without_lora_removed
+        ), "remove lora monkey patch should restore the original outputs"
+
     def create_lora_weight_file(self, tmpdirname):
         _, lora_components = self.get_dummy_components()
         LoraLoaderMixin.save_lora_weights(
```
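The test above only asserts the observable behaviour: after `pipe._remove_text_encoder_monkey_patch()`, the text encoder produces the same outputs as before the patch was applied. As background, a minimal sketch of how such a removal could look is shown below; the `old_forward` attribute and the module walk are assumptions for illustration, not the PR's actual implementation.

```python
# Illustrative sketch only (assumed convention, not the PR's implementation):
# suppose the LoRA patch wrapped each projection's forward and stashed the
# original callable on the module as `old_forward`.
import torch.nn as nn


def remove_text_encoder_monkey_patch(text_encoder: nn.Module) -> None:
    """Restore every submodule whose forward was wrapped for LoRA."""
    for module in text_encoder.modules():
        if hasattr(module, "old_forward"):
            # Put the original forward back and drop the stashed reference.
            module.forward = module.old_forward
            del module.old_forward
```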
Review comment: Shouldn't this be `TEXT_ENCODER_ATTN_MODULE`? I see this doesn't get reflected here either (`diffusers/src/diffusers/loaders.py`, line 968 in 523a50a). I am asking because when initializing the LoRA layers for the text encoder, we use `TEXT_ENCODER_ATTN_MODULE` (`diffusers/examples/dreambooth/train_dreambooth_lora.py`, line 864 in 523a50a). Doesn't this create a disparity?
Author reply: This code is correct because the target of the monkey-patch is not the `CLIPAttention` module itself, but its `q/k/v/out_proj` attributes. However, as you pointed out, I thought it would be easier to understand if it were aligned with `train_dreambooth_lora.py`, so I tried refactoring to remove `TEXT_ENCODER_TARGET_MODULES` in 356a46a. WDYT?
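To make the reply concrete, here is a minimal sketch of the distinction being discussed. It is illustrative only: the suffix value and the helper name are assumptions inferred from the thread, not code from the library. It matches attention modules by a `TEXT_ENCODER_ATTN_MODULE`-style suffix and then reaches into their `q_proj`/`k_proj`/`v_proj`/`out_proj` attributes, which are the actual targets of the monkey-patch.

```python
# Illustrative sketch only. Assumption: CLIP-style attention submodules have
# names ending in ".self_attn" (the spirit of TEXT_ENCODER_ATTN_MODULE) and
# expose q_proj/k_proj/v_proj/out_proj linear layers.
TEXT_ENCODER_ATTN_MODULE = ".self_attn"  # assumed value for illustration


def iter_patch_targets(text_encoder):
    """Yield (qualified_name, linear_layer) pairs that the LoRA patch targets."""
    for name, module in text_encoder.named_modules():
        if name.endswith(TEXT_ENCODER_ATTN_MODULE):
            # The patch does not replace the attention module itself;
            # it wraps these projection attributes.
            for proj in ("q_proj", "k_proj", "v_proj", "out_proj"):
                yield f"{name}.{proj}", getattr(module, proj)
```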