Fix monkey-patch for text_encoder LoRA #3490

Closed

Conversation

takuma104
Contributor

Related issue: #3445; related PR: #3437. See #3437 (comment).

In principle, I believe this change should not be necessary; I would like to revert it once I understand the underlying cause and have confirmed that it's not a problem.

When I investigated the cause of the severe output distortion that appears when LoRA is enabled for the text_encoder, I found that the output was still distorted even after removing the + lora_layer(x) term. I therefore concluded that the monkey-patch itself was the cause. As a workaround, I tried making the patched forward an instance method, and for some reason that fixed it, although I'm not entirely sure why.
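
For reference, a rough sketch of what the instance-method variant could look like. This is illustrative only, not the exact code from that commit; the point is that the patch stores per-module state instead of relying on loop-scoped closure variables:

import types
import torch

def patch_linear_with_instance_method(module: torch.nn.Linear, lora_layer: torch.nn.Module):
    # Keep the original forward and the LoRA layer as attributes on the module
    # itself, so the patched forward reads per-instance state.
    module.old_forward = module.forward
    module.lora_layer = lora_layer

    def new_forward(self, x):
        return self.old_forward(x) + self.lora_layer(x)

    # Bind new_forward to this specific module instance.
    module.forward = types.MethodType(new_forward, module)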

In simplified test code, I compared an equivalent of the original code with the instance-method version, but both of these tests pass without any problems:
https://gist.github.com/takuma104/1263383cdab8f54bb14f389facdbe960

I suspect it might be a memory-related issue involving Python or PyTorch garbage collection, but I'm not sure yet. I'm planning to bring the test a bit closer to the actual setup and keep investigating.

Workflow plan:

#3437 (comment)

  • Monkey-patch fix (looks like it's fixed now)
  • Revisit the tests for that (I will take care of that)
  • Qualitative testing (we both can do this)
  • DreamBooth LoRA with text encoder test (qualitative)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Very nice! @sayakpaul I'd be in favor of merging this one here as is to fix some existing issues. Ok for you?

@patrickvonplaten
Contributor

Just one test seems to be failing cc @sayakpaul as well

@takuma104
Contributor Author

I've been able to simplify the vanilla and instance_method tests a bit more and compare them. It seems the behavior changes once the for loop iterates over two or more modules? It's still unclear what that means, though.
https://gist.github.com/takuma104/894dff4e48a7e1dbebedcff136da5956

Also, it seems not to matter whether the device is cuda or cpu, so I'm thinking of switching the execution device for the current test_text_encoder_lora_monkey_patch() to cpu. What do you think?

@takuma104
Contributor Author

I showed the code to Professor GPT-4 and got some advice, which makes sense to me:

The issue with the code is related to how Python's closures work, combined with how the Python for loop and the def keyword operate.

In the monkey_patch function, you are creating a new function new_forward which calls old_forward(x). Here old_forward is a variable defined in the loop that creates the new_forward function, and it is stored in the function's closure when the function is defined.

In Python, variables in a closure are not copies, they are references to the original variable. When the loop continues to the next module, the old_forward variable is overwritten with the new module's forward function. This means that all of the new_forward functions created in the loop will call the forward function of the last module in the loop, because they all reference the same old_forward variable that has been overwritten.
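
(As a minimal standalone illustration of this late-binding behavior, separate from the diffusers code:)

# Every closure captures the variable i itself, not its value at definition time.
funcs = [lambda: i for i in range(3)]
print([f() for f in funcs])  # [2, 2, 2]

# Passing the current value through a factory call binds it per closure.
funcs = [(lambda j: (lambda: j))(i) for i in range(3)]
print([f() for f in funcs])  # [0, 1, 2]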

Here's how you could solve this problem: you could create a new scope that locks in the old_forward value for each new_forward function. A convenient way to create a new scope in Python is by using a function. Here is an example of how you could modify the monkey_patch function:

def test_monkey_patch_vanilla():
    def monkey_patch(target):
        for name, module in target.named_modules():
            if isinstance(module, torch.nn.Linear):
                print(f'monkey patching to {name}')
                old_forward = module.forward
                
                def make_new_forward(old_forward):
                    def new_forward(x):
                        return old_forward(x)
                    return new_forward

                module.forward = make_new_forward(old_forward)

In this version, make_new_forward creates a new scope for each new_forward function. The old_forward passed to make_new_forward becomes part of the new_forward function's closure, and this old_forward is not overwritten by the loop because it's a parameter to the function, not a variable in the loop scope.

@takuma104
Contributor Author

I revised it to the closure version because it is closer to the original code, and adjusted the tests accordingly. In addition, the test now runs on the default (cpu) device.

@sayakpaul
Member

@takuma104, brilliant stuff!

Using GPT-4 for that explanation is also quite clever IMO. May we know the prompt you used?

Comment on lines -950 to +949
return old_forward(x) + lora_layer(x)
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
Member

Actually, here I wouldn't mind referring users to this issue comment you made:
#3490 (comment)

prepared_inputs["input_ids"] = inputs
return prepared_inputs

def get_text_lora_attn_procs(self, text_encoder: nn.Module, randn_weight=False):
Member

A few things:

  • Could we reuse this function (making changes to it is completely fine)?

def create_text_encoder_lora_layers(text_encoder: nn.Module):

Contributor Author

Could we reuse this function (making changes to it is completely fine)?

I had missed the existence of this function. I've made changes to reuse part of it in this commit: 1da772b

Also, from our discussions in #3437 (particularly this #3437 (comment)), it seems we also need to change the target modules for which we're applying LoRA, no?

This modification might break compatibility with already-serialized files and might also require changes to the training code, so it might be better to do it in a separate PR. I'm thinking of opening another draft PR for that.

Member

This modification might break compatibility with already-serialized files and might also require changes to the training code, so it might be better to do it in a separate PR. I'm thinking of opening another draft PR for that.

From what I can tell, the LoRA checkpoints on the Hub (the most useful ones) produced by our training script do not include the text encoder, so I think it's fine as is. But if we want to do it in a separate PR with changes to the training script, I'm fine with that.

Contributor Author

@takuma104 takuma104 May 22, 2023

Ok! I just opened #3505

        return text_lora_attn_procs

    def test_text_encoder_lora_monkey_patch(self):
        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
Member

Could we maybe use a smaller pipeline like the following?

sd_pipe = StableDiffusionPipeline(**pipeline_components)

It helps the tests run faster while still doing the job of proper testing.

Contributor Author

Fixed in 1da772b.

Comment on lines 259 to 261
# verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor.
del text_lora_attn_procs
gc.collect()
Member

Very important check!

Comment on lines 234 to 239
if layer_name.endswith("_lora"):
    weight = (
        torch.randn_like(layer_module.up.weight)
        if randn_weight
        else torch.zeros_like(layer_module.up.weight)
    )
Member

Clever!
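
(The idea behind the randn_weight flag, as I read it: with the LoRA up-projection initialized to zeros the LoRA branch contributes nothing, so outputs must match the un-patched model, while random weights must change the outputs. A tiny standalone check of that property:)

import torch

lora_down = torch.nn.Linear(8, 4, bias=False)
lora_up = torch.nn.Linear(4, 8, bias=False)
torch.nn.init.zeros_(lora_up.weight)  # zero up-weights => the LoRA branch outputs zeros

x = torch.randn(2, 8)
assert torch.allclose(lora_up(lora_down(x)), torch.zeros(2, 8))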

Member

@sayakpaul sayakpaul left a comment

LGTM! Thanks so much for your help and work here, @takuma104!

I just left a clarification question around the application of the LoRA layers to the text encoder attention blocks. Once that is sorted, there are very minor things to fix! I can take care of them :)

@rvorias

rvorias commented May 22, 2023

Have you checked that the added LoRA layers actually get updated during training? I've added the monkey-patch correction but still seem to be missing something.

for layer_name, layer_module in attn_proc.named_modules():
    if layer_name.endswith("_lora"):
        weight = (
            torch.randn_like(layer_module.up.weight)
Member

For test purposes, that is not needed.

@sayakpaul
Member

Have you checked that the added LoRA layers actually get updated during training? I've added the monkey-patch correction but still seem to be missing something.

Not yet, since there's an open question we're addressing. After that, we'll proceed to training and check whether things work as expected.

Meanwhile, it would be great if you could share the setup for the test you conducted.

@rvorias

rvorias commented May 22, 2023

Meanwhile, it would be great if you could share the setup for the test you conducted.

  • train_text_encoder=True
  • removed unet.parameters() from the optimizer
  • changed with accelerate.accumulate(unet) to with accelerate.accumulate(text_encoder), just to be sure
  • added the closure fix for monkey-patching
  • added a custom lr specifically for the text_encoder

-> no updates happening

This is what I get from inspecting the produced weights .bin (last two LoRA layers):

'text_encoder.text_model.encoder.layers.22.self_attn.out_proj.to_v_lora.down.weight'
array([[ 0.21, -0.02,  0.39, ..., -0.42,  0.05,  0.41],
       [ 0.09, -0.06,  0.27, ..., -0.26, -0.09,  0.04],
       [ 0.19, -0.08, -0.21, ..., -0.19, -0.1 , -0.22],
       [ 0.17,  0.26,  0.02, ...,  0.04,  0.03, -0.06]], dtype=float32)
'text_encoder.text_model.encoder.layers.22.self_attn.out_proj.to_v_lora.up.weight'
array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       ...,
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
'text_encoder.text_model.encoder.layers.22.self_attn.out_proj.to_out_lora.down.weight'
array([[-0.07,  0.16, -0.55, ..., -0.24,  0.05, -0.3 ],
       [ 0.22,  0.28, -0.02, ...,  0.37, -0.  ,  0.36],
       [ 0.15,  0.29, -0.15, ..., -0.06,  0.07, -0.16],
       [ 0.01, -0.14,  0.4 , ..., -0.08, -0.06,  0.35]], dtype=float32)
'text_encoder.text_model.encoder.layers.22.self_attn.out_proj.to_out_lora.up.weight'
array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       ...,
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)

I might not fully understand the implementation, but it seems that each of k, v, q, and out_proj gets a full set of k, v, q, and out_proj LoRA layers.

Also, LoRAAttnProcessor is a full implementation of the original attention + LoRA, so shouldn't the new function just be:

def create_new_forward():
    def new_forward(x):
        return lora_layer(x)
    return new_forward

@sayakpaul
Member

Thanks for providing the additional context!

Which script are you using? You should be looking at the train_dreambooth_lora.py script; there's no use of unet.parameters() there.

I might not fully understand the implementation, but it seems that each of k, v, q, and out_proj gets a full set of k, v, q, and out_proj LoRA layers.

Also, LoRAAttnProcessor is a full implementation of the original attention + LoRA, so shouldn't the new function just be:

From here, we merge the learned LoRA parameters with the corresponding attention modules, so we need to be able to perform that merging the way we're currently doing it. Hope this clarifies your doubt.
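
(For intuition, a minimal sketch of the standard LoRA formulation this relies on; W0, A, B, and scale are generic names here, not the diffusers attributes:)

import torch

def lora_linear_forward(x, W0, A, B, scale=1.0):
    # Frozen pretrained projection plus the low-rank update: x @ W0.T + scale * (x @ A.T) @ B.T.
    # Returning only the LoRA term would drop the pretrained projection entirely,
    # which is why the patched forward adds lora_layer(x) to old_forward(x).
    return x @ W0.T + scale * (x @ A.T) @ B.T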

@rvorias

rvorias commented May 22, 2023

You're right, I meant unet_lora_layers.parameters().

@patrickvonplaten
Contributor

@sayakpaul, feel free to merge once you think things are looking good

@takuma104
Contributor Author

@sayakpaul Thanks!

Using GPT-4 for that explanation is also quite clever IMO. May we know the prompt you used?

I just simplified the test code a bit and copied and pasted it as is. Right now GPT-4 is translating this sentence into English for me; I can't imagine life without GPT-4 anymore ;)

prompt:

The following code stops at the last line due to an assert. Where is the issue?

import torch
import torch.nn

device = 'cuda' # Changing this to 'cpu' still results in the vanilla case failing.

class TargetModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(2, 2),
            torch.nn.Linear(2, 2), # If you comment out this line, the vanilla case will succeed.
        ])
    def forward(self, x):
        for module in self.linears:
            x = module(x)
        return x

def test_monkey_patch_vanilla():
    def monkey_patch(target):
        for name, module in target.named_modules():
            if isinstance(module, torch.nn.Linear):
                print(f'monkey patching to {name}')
                old_forward = module.forward
                def new_forward(x):
                    return old_forward(x)
                module.forward = new_forward

    torch.manual_seed(0)
    x = torch.randn((2, 2)).to(device)
    target = TargetModule().to(device)
    with torch.no_grad():
        print('')
        print('*' * 80)
        print('vanilla:')

        y = target(x)
        print(y)
        assert y.shape == (2, 2)

        monkey_patch(target)

        yy = target(x)
        print(yy)
        assert torch.allclose(yy, y), "vanilla: monkey patching failed"

@sayakpaul
Member

sayakpaul commented May 22, 2023

#3490 (comment)

@takuma104 that was a simpler prompt than I would have expected. Damn!

@sayakpaul
Member

I will review this PR again tomorrow in detail.

@sayakpaul
Member

#3505 (comment) :)

# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
def make_new_forward(old_forward, lora_layer):
    def new_forward(x):
        return old_forward(x) + lora_layer(x)

Small comment: if you load LoRA weights with e.g. pipeline.load_lora_weights("experiments/base_experiment") more than once, this monkey-patch becomes recursive!

Contributor Author

Thanks for your suggestion! Indeed, it seems that might be the case. I wonder if it might be better to create a mechanism to remove the monkey-patch. @sayakpaul WDYT?

The best solution is indeed to do something other than monkey-patching, but a flag like override_forward=False as an arg would also be helpful to disable the monkey-patching.

Member

We're currently discussing this internally, and will keep y'all posted.

Member

Meanwhile,

The best solution is indeed to do something other than monkey-patching, but a flag like override_forward=False as an arg would also be helpful to disable the monkey-patching.

@rvorias could you elaborate on what you mean here?

Have load_lora_weights accept a kwarg override_text_encoder_forward: bool and pass it to _modify_text_encoder, then condition line 957 on that flag.

This is useful in contexts where you want to load arbitrary LoRA weights on the fly in a long-running SD inference engine. Right now, calling load_lora_weights multiple times overrides the forward function each time, so the LoRA addition terms get nested.

With the flag + condition you can still load the new LoRA weights, but you don't override the forward again and again.
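
A rough sketch of the kind of guard described above; the function name, the lora_layers mapping, and the _lora_forward_installed attribute are all hypothetical, not the diffusers API:

import torch

def patch_text_encoder_lora(text_encoder, lora_layers, override_forward=True):
    # lora_layers: hypothetical dict mapping module names to LoRA modules.
    for name, module in list(text_encoder.named_modules()):
        if name not in lora_layers:
            continue
        module.lora_layer = lora_layers[name]  # new weights can be swapped in freely
        if not override_forward or getattr(module, "_lora_forward_installed", False):
            continue  # don't wrap an already-wrapped forward on repeated loads
        old_forward = module.forward

        def make_new_forward(old_forward, module):
            def new_forward(x):
                return old_forward(x) + module.lora_layer(x)
            return new_forward

        module.forward = make_new_forward(old_forward, module)
        module._lora_forward_installed = True

With a guard like this, repeated load_lora_weights calls stay idempotent with respect to the forward override while still picking up the newly loaded weights.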

Member

Would you be willing to open a PR for this? We're more than happy to help you with that :-)

@sayakpaul
Member

Closing in favor of #3437.
