Fix monkey-patch for text_encoder LoRA #3490
Changes from commits: fb708fb, 6e8f3ab, 8511755, 81915f4, 88db546, 1da772b, 8c0926c
In src/diffusers/loaders.py, _modify_text_encoder:

@@ -946,11 +946,15 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
                 lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
                 old_forward = module.forward
 
-                def new_forward(x):
-                    return old_forward(x) + lora_layer(x)
+                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+                def make_new_forward(old_forward, lora_layer):
+                    def new_forward(x):
+                        return old_forward(x) + lora_layer(x)
Contributor: Small comment: if you load lora with e.g. …

Author: Thanks for your suggestion! Indeed, it seems that might be the case. I wonder if it might be better to create a mechanism to remove the monkey-patch. @sayakpaul WDYT?

Best is indeed to do something other than monkey-patching, but a flag like …

Member: We're currently discussing this internally, and will keep y'all posted.

Member: Meanwhile, @rvorias could you elaborate what you mean here?

Have … This is useful in contexts where you want to load arbitrary LoRA weights on the fly in a long-running SD inference engine. If you add the flag + condition, you can still load the new LoRA weights, but you don't override the forward again and again.

Member: Would you be willing to open a PR for this? We're more than happy to help you with that :-)
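To make the suggestion in the thread above concrete, here is a minimal sketch of a guard that patches a module's forward only once and merely swaps the LoRA layer on subsequent loads. The function name, the `_original_forward` attribute, and the overall shape are illustrative assumptions, not part of this PR or of the diffusers API:

```python
# Sketch only: a hypothetical guard against stacking the monkey-patch when
# LoRA weights are loaded repeatedly (e.g. in a long-running inference server).
def patch_module_once(module, lora_layer):
    # Remember the unpatched forward the first time this module is touched.
    if not hasattr(module, "_original_forward"):
        module._original_forward = module.forward

    old_forward = module._original_forward

    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    # Re-assigning replaces the previous LoRA wrapper instead of wrapping it again.
    module.forward = new_forward
```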
The same hunk, continued:

+
+                    return new_forward
 
                 # Monkey-patch.
-                module.forward = new_forward
+                module.forward = make_new_forward(old_forward, lora_layer)
 
     def _get_lora_layer_attribute(self, name: str) -> str:
         if "q_proj" in name:
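The new in-code comment ("create a new scope that locks in the old_forward, lora_layer value") points at Python's late-binding closures: a function defined inside a loop looks up the loop variables when it is called, not when it is defined, so every patched module would otherwise end up using the values from the last iteration. A self-contained toy example of the pitfall and of the factory-function fix used in the diff:

```python
# Toy demonstration of late-binding closures; no diffusers code involved.
def build_functions():
    buggy, fixed = [], []

    for i in range(3):
        def get_i():
            # `i` is resolved at call time, so all three closures see i == 2.
            return i

        def make_get_i(i):
            def get_i_fixed():
                # `i` here is the argument captured when make_get_i(i) was called.
                return i
            return get_i_fixed

        buggy.append(get_i)
        fixed.append(make_get_i(i))

    return buggy, fixed


buggy, fixed = build_functions()
print([f() for f in buggy])  # [2, 2, 2]
print([f() for f in fixed])  # [0, 1, 2]
```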
In tests/models/test_lora_layers.py:

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import os
 import tempfile
 import unittest

@@ -212,3 +213,75 @@ def test_lora_save_load_legacy(self):
 
         # Outputs shouldn't match.
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+
+    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
+    def get_dummy_tokens(self):
+        max_seq_length = 77
+
+        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
+
+        prepared_inputs = {}
+        prepared_inputs["input_ids"] = inputs
+        return prepared_inputs
+
+    def get_text_lora_attn_procs(self, text_encoder: nn.Module, randn_weight=False):
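The helper above only builds random input_ids; the typical use of such dummy tokens is to run the text encoder before and after loading LoRA weights and compare the outputs. The following is a hypothetical sketch of that pattern (the helper name and test wiring are assumptions, not the exact code added in this PR):

```python
import torch


def text_encoder_outputs_match(pipe, lora_state_dict, dummy_tokens, atol=1e-4):
    """Hypothetical helper: compare text encoder outputs before/after LoRA loading."""
    with torch.no_grad():
        original = pipe.text_encoder(**dummy_tokens)[0]

    # load_lora_weights() monkey-patches the text encoder's projection forwards.
    pipe.load_lora_weights(lora_state_dict)

    with torch.no_grad():
        patched = pipe.text_encoder(**dummy_tokens)[0]

    # Zero-initialized LoRA weights should leave the outputs unchanged;
    # random LoRA weights should make them differ.
    return torch.allclose(original, patched, atol=atol)
```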
Member: Could we reuse this function (making changes to it is completely fine)?

    def create_text_encoder_lora_layers(text_encoder: nn.Module):

Also, from our discussions in Support Kohya-ss style LoRA file format (in a limited capacity) #3437 (particularly this comment), it seems we also need to change the target modules for which we're applying LoRA, no?

Author: I missed the existence of this function. I have made changes to reuse some of it in this commit: 1da772b.

> Also, from our discussions in #3437 (particularly this #3437 (comment)), it seems we also need to change the target modules for which we're applying LoRA, no?

This modification might result in losing compatibility with already serialized files and might also require changes to the training code, so it might be better to do it in a separate PR. I'm thinking about opening another draft PR for that.
Member:

> This modification might result in losing compatibility with already serialized files and might also require changes to the training code, so it might be better to do it in a separate PR.

From what I can tell, the LoRA checkpoints on the Hub (the most useful ones) produced by our training script do not include the text encoder. So I think it's fine as is. But if we want to do it in a separate PR with changes to the training script, I am fine with that.
Author: Ok! I just opened #3505.
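To make the compatibility concern concrete: LoRA weights are serialized as a state dict keyed by module names, so changing the target modules changes those keys and previously saved files no longer load strictly. A toy illustration with made-up module names (not the actual diffusers key layout):

```python
import torch.nn as nn

# Weights saved under the old target-module name ("q_proj" here, made up).
old_model = nn.ModuleDict({"q_proj": nn.Linear(4, 4, bias=False)})
saved = old_model.state_dict()  # keys like "q_proj.weight"

# A model whose target module was renamed no longer accepts those keys as-is.
new_model = nn.ModuleDict({"to_q": nn.Linear(4, 4, bias=False)})
try:
    new_model.load_state_dict(saved)  # strict loading fails on mismatched keys
except RuntimeError as err:
    print("strict load failed:", err)

# Loading old files would then need an explicit remapping step.
remapped = {key.replace("q_proj", "to_q"): value for key, value in saved.items()}
new_model.load_state_dict(remapped)
```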
Outdated

This one doesn't have the std scaling of the LoRA module: https://github.com/huggingface/diffusers/blob/49ad61c2045a3278ea0b6648546c0824e9d89c0f/src/diffusers/models/attention_processor.py#LL490C56-L491C1

For test purposes, that is not needed.
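For context, the "std scaling" being referenced is the initialization used by diffusers' LoRA linear layers (as far as I can tell from the linked lines, the down projection gets a normal init with std 1/rank and the up projection starts at zero, so a freshly added LoRA layer is a no-op), whereas the test helper assigns plain random or zero weights. A generic sketch of the difference, written from scratch rather than taken from attention_processor.py:

```python
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    """Generic rank-decomposed layer for illustration; not the diffusers class."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4, scaled_init: bool = True):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)

        if scaled_init:
            # "std scaling": small normal init for down, zeros for up,
            # so the layer initially adds nothing to the base output.
            nn.init.normal_(self.down.weight, std=1 / rank)
            nn.init.zeros_(self.up.weight)
        else:
            # Plain random weights, e.g. to force a visible change in a test.
            nn.init.normal_(self.down.weight)
            nn.init.normal_(self.up.weight)

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))
```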
Outdated

Clever!
Outdated

Member: Could we maybe use a smaller pipeline, like the following?

    diffusers/tests/models/test_lora_layers.py, line 140 in 85eff63:
    sd_pipe = StableDiffusionPipeline(**pipeline_components)

It helps us run the tests faster but does the job of proper testing at the same time.

Author: Fixed in 1da772b.
Outdated

Very important check!
Member: Here I'd actually not mind referring users to read this issue comment you made: #3490 (comment)