[LoRA] Repercussions of using monkey-patching for text encoder LoRA #3621

Closed
sayakpaul opened this issue May 31, 2023 · 9 comments

Comments

@sayakpaul
Member

We use monkey-patching to support LoRA fine-tuning of the text encoder and loading of the LoRA params corresponding to the text encoder. While it helps us avoid additional dependencies like peft, it comes with baggage.

As pointed out by @rvorias in #3490 (comment), if we call pipe.load_lora_weights() twice (with a checkpoint containing text encoder LoRA parameters), monkey-patching becomes recursive.

This should be prevented at all costs.
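
To illustrate the failure mode, here is a minimal, self-contained sketch (a toy nn.Linear stands in for the text encoder, and a plain 2x scale stands in for the LoRA delta):

import torch

linear = torch.nn.Linear(4, 4)  # toy stand-in for a text encoder sub-module

def naive_patch(module):
    # wraps whatever `forward` currently is -- including a previous wrapper
    old_forward = module.forward

    def new_forward(x):
        return old_forward(x) * 2.0  # stand-in for "base output + LoRA delta"

    module.forward = new_forward

x = torch.randn(1, 4)
with torch.no_grad():
    y = linear(x)

    naive_patch(linear)  # first load_lora_weights()
    naive_patch(linear)  # second load: wraps the wrapper

    # the patch is now applied twice: 4x the original output instead of 2x
    assert torch.allclose(linear(x), y * 4.0)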

@rvorias also proposed a fix in takuma104#1, which tackles the issue by introducing a flag, text_encoder_modify_forwards, that controls whether the text encoder forwards should be modified. Once LoRA has been applied to the text encoder once, the flag is set to False.
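
Roughly, the guarded patching could look like this (class and method names here are placeholders, not the actual diff):

import torch

class TextEncoderPatcherSketch:
    # placeholder loader: only the one-shot flag matters here
    def __init__(self):
        self.text_encoder_modify_forwards = True

    def modify_forwards(self, text_encoder, scale=2.0):
        # refuse to patch a second time so wrappers never stack recursively
        if not self.text_encoder_modify_forwards:
            return
        for module in text_encoder.modules():
            if isinstance(module, torch.nn.Linear):
                def make_new_forward(old_forward):
                    def new_forward(x):
                        return old_forward(x) * scale  # stand-in for the LoRA delta
                    return new_forward

                module.forward = make_new_forward(module.forward)
        self.text_encoder_modify_forwards = False

patcher = TextEncoderPatcherSketch()
encoder = torch.nn.Sequential(torch.nn.Linear(4, 4))
patcher.modify_forwards(encoder)
patcher.modify_forwards(encoder)  # no-op: the flag was already flipped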

If one needs to swap the text encoder back to its original state, they could do:

del pipe.text_encoder

Then,

from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained(
    ckpt_id, subfolder="text_encoder", torch_dtype=pipe.dtype
).to(pipe.device)

pipe.text_encoder = text_encoder

I would prefer this approach.

Calling @patrickvonplaten @pcuenca @takuma104 @rvorias to chime in too :)

@patrickvonplaten
Contributor

Can't we just copy & store the old function somewhere and re-use it? E.g. as follows:

import copy

class A:

    def __init__(self):
        pass

    def a_func(self):
        return "Hello"


a = A()

# keep a reference to the original bound method
old_func = copy.copy(a.a_func)

def new_func():
    return old_func() + " and bye"

a.a_func = new_func

# new function
print(a.a_func())  # "Hello and bye"

def next_new_func():
    return old_func() + " and new bye"

# the second patch wraps the original again, so nothing nests recursively
a.a_func = next_new_func

print(a.a_func())  # "Hello and new bye"

@takuma104
Contributor

Based on @patrickvonplaten's idea, I tried, in a separate branch, temporarily saving the original forward to a member called old_forward. I also wrote a test, but unfortunately it doesn't pass for some reason. Hmm.

takuma104/diffusers@kohya-lora-loader...undoable-monkeypatch

A slightly simplified version did pass. It's difficult...
https://gist.github.com/takuma104/93094f989ee89e4cd61af09f9d909e26

@sayakpaul
Member Author

A slightly simplified version did pass. It's difficult...
https://gist.github.com/takuma104/93094f989ee89e4cd61af09f9d909e26

@takuma104 my understanding is that when we undo the monkey-patch, it should be:

assert torch.allclose(yyy, y)

as opposed to

assert torch.allclose(yyy, y*2.0)

no?

This is what I did:
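
(To run the snippet below standalone, I'm assuming a minimal stand-in for the gist's TargetModule and device; the gist has its own definitions.)

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

class TargetModule(torch.nn.Module):
    # minimal stand-in: a single Linear so the (2, 2) shape assertion below holds
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)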

def test_monkey_patch_fix_closure():
    def monkey_patch(target):
        for name, module in target.named_modules():
            if isinstance(module, torch.nn.Linear):
                print(f'monkey patching to {name}')

                if hasattr(module, 'old_forward'):
                    print('undo monkey-patch')
                    module.forward = module.old_forward
                    delattr(module, 'old_forward')
                else:
                    old_forward = module.old_forward = module.forward
                    def make_new_forward(old_forward):
                        def new_forward(x):
                            return old_forward(x) * 2.0
                        return new_forward
                    module.forward = make_new_forward(old_forward)

    torch.manual_seed(0)
    x = torch.randn((2, 2)).to(device)
    target = TargetModule().to(device)
    with torch.no_grad():
        print('')
        print('*' * 80)
        print('vanilla:')

        y = target(x)
        print(y)
        assert y.shape == (2, 2)

        monkey_patch(target)

        yy = target(x)
        print(yy)
        assert torch.allclose(yy, y*2.0), "fix closure: monkey patching failed"

        monkey_patch(target)

        yyy = target(x)
        print(yyy)
        assert torch.allclose(yyy, y), "results should match when monkey-patching is removed"

Anything I am missing out on?
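
(Side note: the make_new_forward factory above is there to dodge Python's late-binding closures. Without it, every wrapper created in the loop would share the same old_forward variable and end up calling the last patched module's original forward. A tiny illustration of the pitfall:)

import torch
import torch.nn.functional as F

layers = [torch.nn.Linear(2, 2) for _ in range(3)]

# buggy: all three lambdas close over the same `old` name, which after the
# loop refers to the *last* layer's original forward
for layer in layers:
    old = layer.forward
    layer.forward = lambda x: old(x) * 2.0

x = torch.randn(1, 2)
with torch.no_grad():
    # layers[0] now computes layers[2]'s original output * 2.0, not its own
    expected = F.linear(x, layers[2].weight, layers[2].bias) * 2.0
    assert torch.allclose(layers[0](x), expected)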

Looking into: takuma104/diffusers@kohya-lora-loader...undoable-monkeypatch

@sayakpaul
Member Author

Looking into: takuma104/diffusers@kohya-lora-loader...undoable-monkeypatch

Here, I think a better test might be to check the outputs directly when the monkey-patching is undone, no? I am struggling to understand why there's a need to check with seeds.

@takuma104
Contributor

@sayakpaul I had the idea to use the fact that applying monkey_patch twice without correct handling would cause recursive processing and result in different outcomes, but it wasn't a very straightforward approach. I prepared a function called _remove_text_encoder_monkey_patch() and tried to test it. This one passes. If there are no issues with this, shall we merge it into #3437?

takuma104/diffusers@kohya-lora-loader...undoable-monkeypatch
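
For reference, a bare-bones sketch of what such a removal utility could look like (illustrative only, not the actual diff in the branch):

def _remove_text_encoder_monkey_patch(text_encoder):
    # restore every patched module's original forward, if one was stored
    for module in text_encoder.modules():
        if hasattr(module, "old_forward"):
            module.forward = module.old_forward
            del module.old_forward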

@sayakpaul
Member Author

@sayakpaul I had the idea to use the fact that applying monkey_patch twice without correct handling would cause recursive processing and result in different outcomes, but it wasn't a very straightforward approach. I prepared a function called _remove_text_encoder_monkey_patch() and tried to test it. This one passes. If there are no issues with this, shall we merge it into #3437?

takuma104/diffusers@kohya-lora-loader...undoable-monkeypatch

Let's maybe do this in a separate PR to streamline the process. We can do that immediately after merging #3437.

Also, just a nit: we need to be very careful to use TEXT_ENCODER_ATTN_MODULE and not TEXT_ENCODER_TARGET_MODULES, as you had correctly figured out in #3437.

While we're at it, here are some thoughts on takuma104/diffusers@kohya-lora-loader...undoable-monkeypatch:

  • _remove_text_encoder_monkey_patch() sounds like a nice utility to me -- @patrickvonplaten WDYT?
  • We just need to figure out a nice way of letting the user know when they call load_lora_weights() twice, so they're fully aware of the consequences (a rough sketch follows below). WDYT?
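
A rough sketch of such a warning (the _lora_patched attribute and the function shape are hypothetical, just to show the idea):

import warnings

def load_lora_weights_sketch(pipe, pretrained_model_name_or_path):
    # hypothetical guard: warn if the text encoder was already patched once
    if getattr(pipe.text_encoder, "_lora_patched", False):
        warnings.warn(
            "Text encoder LoRA parameters were already loaded into this pipeline; "
            "loading again will replace (not stack) the previous patch.",
            UserWarning,
        )
    # ... load the state dict and apply the monkey-patch here ...
    pipe.text_encoder._lora_patched = True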

@takuma104
Contributor

I just opened #3649 for this issue. The link to the diff I posted earlier has been invalidated by the merge, but the changes themselves are unchanged.

@patrickvonplaten
Contributor

I like #3649! I think for this special case, though, it's important that the user doesn't have to understand what is happening under the hood. Instead, we should have certainty that new LoRAs can be loaded easily without causing any problems. To me it seems like #3649 does solve this.

@sayakpaul
Member Author

We can close this since #3649 addresses it.
