[LoRA] Repercussions of using monkey-patching for text encoder LoRA #3621
Can't we just copy & store the old function somewhere and re-use it? E.g. as follows:

```python
import copy

class A:
    def __init__(self):
        pass

    def a_func(self):
        return "Hello"

a = A()
old_func = copy.copy(a.a_func)

def new_func():
    return old_func() + " and bye"

a.a_func = new_func

# new function
print(a.a_func())

def next_new_func():
    return old_func() + " and new bye"

a.a_func = next_new_func
print(a.a_func())
```
Based on @patrickvonplaten's idea, I tried writing, in a separate branch, a version that temporarily saves the original function to a member attribute: takuma104/diffusers@kohya-lora-loader...undoable-monkeypatch. A slightly simplified version did pass. It's difficult...
@takuma104 my understanding is that when we undo the monkey-patch, it should be `assert torch.allclose(yyy, y)` as opposed to `assert torch.allclose(yyy, y * 2.0)`, no? This is what I did:

```python
def test_monkey_patch_fix_closure():
    def monkey_patch(target):
        for name, module in target.named_modules():
            if isinstance(module, torch.nn.Linear):
                print(f'monkey patching to {name}')
                if hasattr(module, 'old_forward'):
                    print('undo monkey-patch')
                    module.forward = module.old_forward
                    delattr(module, 'old_forward')
                else:
                    old_forward = module.old_forward = module.forward

                    def make_new_forward(old_forward):
                        def new_forward(x):
                            return old_forward(x) * 2.0
                        return new_forward

                    module.forward = make_new_forward(old_forward)

    torch.manual_seed(0)
    x = torch.randn((2, 2)).to(device)
    target = TargetModule().to(device)

    with torch.no_grad():
        print('')
        print('*' * 80)
        print('vanilla:')
        y = target(x)
        print(y)
        assert y.shape == (2, 2)

        monkey_patch(target)
        yy = target(x)
        print(yy)
        assert torch.allclose(yy, y * 2.0), "fix closure: monkey patching failed"

        monkey_patch(target)
        yyy = target(x)
        print(yyy)
        assert torch.allclose(yyy, y), "results should match when monkey-patching is removed"
```

Anything I am missing out on? Looking into: takuma104/diffusers@kohya-lora-loader...undoable-monkeypatch
Here, I think a better test might be to check the outputs directly when monkey-patching is undone, no? I am struggling to understand why there's a need to check with seeds.
@sayakpaul I had the idea of using the fact that applying monkey_patch twice without correct handling would cause recursive processing and produce different results, but it wasn't a very straightforward approach. I prepared a function called `_remove_text_encoder_monkey_patch()` and tried to test it. This one passes. If there are no issues with this, shall we merge it into #3437? takuma104/diffusers@kohya-lora-loader...undoable-monkeypatch
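For reference, a removal helper along these lines could look roughly like the sketch below. This is only a minimal sketch based on the `old_forward` pattern from the test above, not necessarily what the linked branch implements; the `_sketch` function name and restricting the restore to `torch.nn.Linear` modules are assumptions.

```python
import torch

def _remove_text_encoder_monkey_patch_sketch(text_encoder):
    # Walk the text encoder and restore any forward that was previously replaced.
    # Assumes the patching step stashed the original on `module.old_forward`,
    # as in the test above.
    for module in text_encoder.modules():
        if isinstance(module, torch.nn.Linear) and hasattr(module, "old_forward"):
            module.forward = module.old_forward
            delattr(module, "old_forward")
```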
Let's maybe do this in a separate PR to streamline the process. We can do that immediately after merging #3437. Also, just a nit that we need to be very cautious about not using … While we're at it, here are some thoughts on takuma104/diffusers@kohya-lora-loader...undoable-monkeypatch:
I just opened #3649 for this issue. The link to the diff I posted earlier has been invalidated by the merge, but the changes themselves have not been modified.
We can close this since #3649 addresses it.
We use monkey-patching to support LoRA fine-tuning of the text encoder and loading of the LoRA params corresponding to the text encoder. While it helps us prevent additional dependencies like `peft`, it comes with baggage.

As pointed out by @rvorias in #3490 (comment), if we call `pipe.load_lora_weights()` twice (with a checkpoint containing text encoder LoRA parameters), the monkey-patching becomes recursive. This should be prevented at all costs.
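To illustrate the failure mode, here is a standalone toy example (not the actual `load_lora_weights()` code): each call wraps whatever `forward` is currently installed, so a second call wraps the already-patched forward and the effect compounds.

```python
import torch

linear = torch.nn.Linear(2, 2)

def patch(module):
    # Naive patching with no guard: wraps whatever `forward` currently is.
    old_forward = module.forward

    def new_forward(x):
        return old_forward(x) * 2.0  # stand-in for "apply the LoRA delta"

    module.forward = new_forward

x = torch.randn(1, 2)
y = linear(x)

patch(linear)
patch(linear)  # second call wraps the already-patched forward

# The patch is now applied twice: the output is 4x the original instead of 2x.
assert torch.allclose(linear(x), y * 4.0)
```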
@rvorias also proposed a fix in takuma104#1 which tackles the issue by introducing a flag, `text_encoder_modify_forwards`, that controls whether LoRA should be applied to the text encoder or not. Once LoRA has been applied to the text encoder once, the flag is set to `False`. If one needs to swap the text encoder back to its original state, they could do:

Then,

I would prefer this approach.
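In rough terms, the guard could look like the sketch below. This is only an illustration of the flag idea, not the actual code from takuma104#1; the class, the method name, and the `lora_layers` mapping (module -> LoRA layer) are hypothetical.

```python
class LoraTextEncoderPatcherSketch:
    """Sketch of the flag-guarded patching idea (illustrative only)."""

    def __init__(self):
        self.text_encoder_modify_forwards = True

    def modify_text_encoder(self, lora_layers):
        if not self.text_encoder_modify_forwards:
            # Forwards were already patched once; skip to avoid recursive wrapping.
            return
        self.text_encoder_modify_forwards = False

        for module, lora_layer in lora_layers.items():
            old_forward = module.forward

            def make_new_forward(old_forward, lora_layer):
                def new_forward(x):
                    # Original output plus the LoRA contribution.
                    return old_forward(x) + lora_layer(x)
                return new_forward

            module.forward = make_new_forward(old_forward, lora_layer)
```

Presumably, swapping the text encoder back to its original state would also involve flipping the flag back to `True`, so that a later load can patch the fresh modules again.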
Calling @patrickvonplaten @pcuenca @takuma104 @rvorias to chime in too :)