Fix monkey-patch for text_encoder LoRA #3490
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Very nice! @sayakpaul I'd be in favor of merging this one here as is to fix some existing issues. Ok for you?
Just one test seems to be failing, cc @sayakpaul as well.
I've been able to simplify and compare the vanilla and instance_method tests a bit more. It seems that things change when looping more than twice in a for loop? Well, it's still unclear what this means. Also, it seems that whether it's CUDA or CPU doesn't matter, so I'm thinking of making the execution device for the current test the default (CPU).
I showed the code to Professor GPT-4 and got some advice. This makes sense:
```python
import torch

def test_monkey_patch_vanilla():
    def monkey_patch(target):
        for name, module in target.named_modules():
            if isinstance(module, torch.nn.Linear):
                print(f'monkey patching to {name}')
                old_forward = module.forward

                def make_new_forward(old_forward):
                    def new_forward(x):
                        return old_forward(x)

                    return new_forward

                module.forward = make_new_forward(old_forward)
```
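For context (not part of the PR itself), the reason the `make_new_forward` factory matters is Python's late binding: an inner function defined directly in the loop captures the loop variable, not its value at that iteration. A minimal sketch, with made-up toy modules, of the difference:

```python
import torch

# Without a factory, every patched forward closes over the *same* variable,
# so after the loop all of them end up calling the last module visited.
modules = [torch.nn.Linear(4, 4) for _ in range(3)]

broken = []
for m in modules:
    old_forward = m.forward
    broken.append(lambda x: old_forward(x))  # late binding: all share `old_forward`

x = torch.randn(1, 4)
print([torch.allclose(f(x), modules[-1](x)) for f in broken])  # [True, True, True]

# Wrapping in a factory freezes the value captured at each iteration.
fixed = []
for m in modules:
    def make_forward(fwd):
        return lambda x: fwd(x)
    fixed.append(make_forward(m.forward))

print([torch.allclose(f(x), m(x)) for f, m in zip(fixed, modules)])  # [True, True, True]
```

With only one or two patched modules the bug can go unnoticed, which may be what the "looping more than twice" observation was hitting.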
I revised it to the closure version because it is closer to the original code. I've also adjusted the tests accordingly. In addition, I made it run on the default (CPU) device.
@takuma104, brilliant stuff! Using GPT-4 for that explanation is also quite clever IMO. May we know the prompt you used?
```python
return old_forward(x) + lora_layer(x)
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
```
Here I'd not actually mind referring users to read this issue comment you made:
#3490 (comment)
tests/models/test_lora_layers.py (Outdated)
```python
    prepared_inputs["input_ids"] = inputs
    return prepared_inputs

def get_text_lora_attn_procs(self, text_encoder: nn.Module, randn_weight=False):
```
A few things:
- Could we reuse this function (making changes to it is completely fine)?
  `def create_text_encoder_lora_layers(text_encoder: nn.Module):`
- Also, from our discussions in Support Kohya-ss style LoRA file format (in a limited capacity) #3437 (particularly this comment), it seems we also need to change the target modules for which we're applying LoRA, no?
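As an illustrative aside (not from the thread), "target modules" here means the Linear projections inside the text encoder that receive a LoRA layer. A quick, hedged way to list the usual attention-projection candidates on a randomly initialized CLIP text encoder (the module layout matches the text encoder Stable Diffusion uses; the exact set discussed in #3437 may differ):

```python
from transformers import CLIPTextConfig, CLIPTextModel

# Randomly initialized default config, just to inspect module names without
# downloading any weights.
text_encoder = CLIPTextModel(CLIPTextConfig())

targets = ("q_proj", "k_proj", "v_proj", "out_proj")
names = sorted({name.split(".")[-1] for name, _ in text_encoder.named_modules() if name.endswith(targets)})
print(names)  # ['k_proj', 'out_proj', 'q_proj', 'v_proj']
```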
> Could we reuse this function (making changes to it is completely fine)?

I missed the existence of this function. I have made changes to reuse some of it in this commit: 1da772b

> Also, from our discussions in #3437 (particularly this #3437 (comment)), it seems we also need to change the target modules for which we're applying LoRA, no?

This modification might break compatibility with already serialized files and might also require changes to the training code, so it might be better to handle it in a separate PR. I'm thinking about opening another draft PR for that.
> This modification might break compatibility with already serialized files and might also require changes to the training code, so it might be better to handle it in a separate PR. I'm thinking about opening another draft PR for that.

From what I can tell, the LoRA checkpoints on the Hub (the most useful ones) produced by our training script do not include the text encoder. So, I think it's fine as is. But if we want to do it in a separate PR with changes to the training script, I am fine with that.
Ok! I just opened #3505
tests/models/test_lora_layers.py (Outdated)
```python
    return text_lora_attn_procs

def test_text_encoder_lora_monkey_patch(self):
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
```
Could we maybe use a smaller pipeline like the following?
diffusers/tests/models/test_lora_layers.py, line 140 in 85eff63:
```python
sd_pipe = StableDiffusionPipeline(**pipeline_components)
```
It helps us run the tests faster while still doing the job of proper testing.
Fixed: 1da772b
tests/models/test_lora_layers.py (Outdated)
```python
# verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor.
del text_lora_attn_procs
gc.collect()
```
Very important check!
tests/models/test_lora_layers.py (Outdated)
```python
if layer_name.endswith("_lora"):
    weight = (
        torch.randn_like(layer_module.up.weight)
        if randn_weight
        else torch.zeros_like(layer_module.up.weight)
    )
```
Clever!
LGTM! Thanks so much for your help and work here, @takuma104!
I just left a clarification question around the application of the LoRA layers to the text encoder attention blocks. Once that is sorted, there are very minor things to fix! I can take care of them :)
Have you guys checked that the added LoRA layers actually get updated during training? I've added the monkey-patch correction but still seem to be missing something.
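A minimal sketch (toy layer, not the actual training script) of one way to check whether LoRA parameters receive updates: snapshot them, run a single optimizer step, and diff.

```python
import torch

# Toy stand-in for a LoRA layer; in a real script you would collect the
# actual LoRA parameters from the patched text encoder / attention processors.
lora = torch.nn.Linear(8, 8, bias=False)
optimizer = torch.optim.AdamW(lora.parameters(), lr=1e-3)

before = {name: p.detach().clone() for name, p in lora.named_parameters()}

# One dummy training step.
loss = lora(torch.randn(2, 8)).pow(2).mean()
loss.backward()
optimizer.step()

# If nothing changed, the LoRA layers are not in the optimizer or are not
# reached by the backward pass.
for name, param in lora.named_parameters():
    print(name, "updated:", not torch.equal(before[name], param.detach()))
```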
tests/models/test_lora_layers.py (Outdated)
```python
for layer_name, layer_module in attn_proc.named_modules():
    if layer_name.endswith("_lora"):
        weight = (
            torch.randn_like(layer_module.up.weight)
```
This one doesn't have the std scaling that the LoRA module uses: https://github.com/huggingface/diffusers/blob/49ad61c2045a3278ea0b6648546c0824e9d89c0f/src/diffusers/models/attention_processor.py#LL490C56-L491C1
For test purposes, that is not needed.
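For reference, a minimal sketch of the init convention the linked LoRA layer is assumed to follow (down projection drawn with std = 1/rank, up projection zero-initialized so the layer starts as a no-op; the exact details live in attention_processor.py):

```python
import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # down: small random init scaled by the rank; up: zeros, so the
        # residual branch contributes nothing until training moves it.
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, x):
        return self.up(self.down(x))

layer = TinyLoRALinear(32, 32)
print(torch.allclose(layer(torch.randn(2, 32)), torch.zeros(2, 32)))  # True: zero-init up => no-op
```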
Not yet, since there's an open question we're addressing. After that, we will proceed to training and see if things are working as expected. Meanwhile, it would be great if you could share your setup for the test you conducted.
-> no updates happening. This is what I get from inspecting the produced weights .bin.

I might not fully understand the implementation, but it seems that each k, v, q, out_proj gets a k, v, q, and out_proj LoRA layer. Also:

```python
def create_new_forward():
    def new_forward(x):
        return lora_layer(x)
    return new_forward
```
Thanks for providing the additional context! Which script are you using? You should be looking into the
From here, we merge the learned LoRA parameters with the corresponding attention modules. So, we need to be able to perform the merging the way we're currently doing it. Hope this clarifies your doubt.
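To make the "merging" concrete: for a LoRA pair (down, up), the effective weight of the patched linear layer is W + scale * (up @ down). A hedged toy sketch (variable names are illustrative, not the training script's):

```python
import torch

def merge_lora_into_linear(linear, down, up, scale=1.0):
    # Fold the low-rank update into the base weight: W <- W + scale * (up @ down).
    with torch.no_grad():
        linear.weight += scale * (up @ down)
    return linear

linear = torch.nn.Linear(8, 8, bias=False)
rank = 4
down = torch.randn(rank, 8) * 0.01  # (rank, in_features)
up = torch.randn(8, rank) * 0.01    # (out_features, rank)

x = torch.randn(2, 8)
unmerged = linear(x) + x @ down.t() @ up.t()  # base output + LoRA branch
merge_lora_into_linear(linear, down, up)
print(torch.allclose(linear(x), unmerged, atol=1e-6))  # True: merged weight == patched forward
```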
You're right, I meant
@sayakpaul, feel free to merge once you think things are looking good
@sayakpaul Thanks!
I just simplified the test code a bit and copied and pasted it as is. Right now GPT-4 is translating this sentence into English for me; I can't imagine life without GPT-4 anymore ;) prompt:
@takuma104 that was a simpler prompt than I would have expected. Damn!
I will review this PR again tomorrow in detail.
```python
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
def make_new_forward(old_forward, lora_layer):
    def new_forward(x):
        return old_forward(x) + lora_layer(x)
```
Small comment: if you load LoRA with e.g. `pipeline.load_lora_weights("experiments/base_experiment")` more than once, then this monkey patch becomes recursive!
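A minimal sketch (toy Linear, not the actual pipeline code) of the nesting being described: the second patch wraps the already-patched forward, so the LoRA term is added twice.

```python
import torch

def patch(module, lora_layer):
    # Same closure pattern as the PR: wrap whatever forward is currently installed.
    old_forward = module.forward
    def new_forward(x):
        return old_forward(x) + lora_layer(x)
    module.forward = new_forward

linear = torch.nn.Linear(4, 4)
lora_layer = torch.nn.Linear(4, 4, bias=False)

x = torch.randn(1, 4)
expected = linear(x) + lora_layer(x)

patch(linear, lora_layer)
print(torch.allclose(linear(x), expected))                  # True after one patch

patch(linear, lora_layer)                                   # e.g. a second load_lora_weights call
print(torch.allclose(linear(x), expected + lora_layer(x)))  # True: LoRA term applied twice
```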
Thanks for your suggestion! Indeed, it seems that might be the case. I wonder if it might be better to create a mechanism to remove the monkey-patch. @sayakpaul WDYT?
The best option is indeed to do something other than monkey-patching, but a flag like `override_forward=False` as an arg would also be helpful to disable the monkey-patching.
We're currently discussing this internally, and will keep y'all posted.
Meanwhile:

> The best option is indeed to do something other than monkey-patching, but a flag like `override_forward=False` as an arg would also be helpful to disable the monkey-patching.

@rvorias could you elaborate on what you mean here?
Have `load_lora_weights` parse a kwarg `override_text_encoder_forward: bool`, then pass that to `_modify_text_encoder`, and condition line 957 on this flag.

This is useful in contexts where you want to load arbitrary LoRA weights on the fly in a long-running SD inference engine. Right now, calling `load_lora_weights` multiple times causes you to override the forward function multiple times, and thus the LoRA addition term gets nested. If you add the flag + condition, you can still load the new LoRA weights, but you don't override the forward again and again.
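A hedged sketch of that idea (`override_forward` and the `_lora_patched` marker are illustrative names, not existing diffusers arguments): always swap in the new LoRA layer, but only wrap `forward` the first time.

```python
import torch

def apply_lora_patch(module, lora_layer, override_forward=True):
    module.lora_layer = lora_layer  # always refresh the LoRA weights
    if not override_forward or getattr(module, "_lora_patched", False):
        return  # forward already wrapped: don't nest the LoRA term
    old_forward = module.forward

    def new_forward(x):
        # look up module.lora_layer at call time so later loads take effect
        return old_forward(x) + module.lora_layer(x)

    module.forward = new_forward
    module._lora_patched = True

linear = torch.nn.Linear(4, 4)
x = torch.randn(1, 4)
apply_lora_patch(linear, torch.nn.Linear(4, 4, bias=False))
apply_lora_patch(linear, torch.nn.Linear(4, 4, bias=False))  # second "load": weights swapped, no nesting
```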
Would you be willing to open a PR for this? We're more than happy to help you with that :-)
Closing in favor of #3437.
Related issue: #3445 and PR #3437. See #3437 (comment).
Workflow plan: #3437 (comment)