src/diffusers/loaders.py (16 additions, 1 deletion)

@@ -955,6 +955,15 @@ def text_encoder_lora_attn_procs(self):
             return self._text_encoder_lora_attn_procs
         return
 
+    def _remove_text_encoder_monkey_patch(self):
+        for name, _ in self.text_encoder.named_modules():
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
Member:
Shouldn't this be TEXT_ENCODER_ATTN_MODULE? I see it isn't reflected here either:

if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):

I am asking this because when initializing the LoRA layers for the text encoder, we use TEXT_ENCODER_ATTN_MODULE:

if name.endswith(TEXT_ENCODER_ATTN_MODULE):

Doesn't this create a disparity?

Contributor Author:
This code is correct: the target of the monkey-patch is not the CLIPAttention module itself but its q/k/v/out_proj attributes. That said, as you pointed out, I thought it would be easier to understand if it were aligned with train_dreambooth_lora.py, so I refactored to remove TEXT_ENCODER_TARGET_MODULES: 356a46a WDYT?
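To make the distinction in this thread concrete, here is a minimal sketch; the constant values below are assumptions modeled on a CLIP text encoder's module names, not the exact definitions in diffusers. TEXT_ENCODER_ATTN_MODULE is used to select the attention block when the LoRA processors are created, while TEXT_ENCODER_TARGET_MODULES matches the q/k/v/out projections inside it, which are the modules that actually get monkey-patched.

```python
# Illustrative only: these constant values are assumptions based on a CLIP
# text encoder's module names, not copied from diffusers.
TEXT_ENCODER_ATTN_MODULE = ".self_attn"
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "out_proj"]

names = [
    "text_model.encoder.layers.0.self_attn",         # the attention block
    "text_model.encoder.layers.0.self_attn.q_proj",  # one of its projections
]

for name in names:
    is_attn_module = name.endswith(TEXT_ENCODER_ATTN_MODULE)                # check used when creating LoRA processors
    is_target_module = any(x in name for x in TEXT_ENCODER_TARGET_MODULES)  # check used when monkey-patching forwards
    print(f"{name}: attn_module={is_attn_module}, target_module={is_target_module}")

# Expected output:
#   text_model.encoder.layers.0.self_attn: attn_module=True, target_module=False
#   text_model.encoder.layers.0.self_attn.q_proj: attn_module=False, target_module=True
```

Both checks end up addressing the same attention layers, just at different levels of the module tree, which matches the reply above.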

+                module = self.text_encoder.get_submodule(name)
+                if hasattr(module, "old_forward"):
+                    # restore original `forward` to remove monkey-patch
+                    module.forward = module.old_forward
+                    delattr(module, "old_forward")
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -963,6 +972,10 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
             attn_processors: Dict[str, `LoRAAttnProcessor`]:
                 A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
         """
+
+        # First, remove any monkey-patch that might have been applied before
+        self._remove_text_encoder_monkey_patch()
+
         # Loop over the original attention modules.
         for name, _ in self.text_encoder.named_modules():
             if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
@@ -972,7 +985,9 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
                 # this forward pass.
                 attn_processor_name = ".".join(name.split(".")[:-1])
                 lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
-                old_forward = module.forward
+
+                # save old_forward to module that can be used to remove monkey-patch
+                old_forward = module.old_forward = module.forward
 
                 # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
                 # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
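The hunks above implement a save/patch/restore cycle: the original forward is stashed on the module as old_forward, a closure adds the LoRA output on top of it, and _remove_text_encoder_monkey_patch puts the original back. Below is a hedged, self-contained sketch of the same pattern on a toy nn.Linear rather than the real text encoder; the names DummyLoRA, apply_patch, and remove_patch are illustrative and not part of diffusers.

```python
import torch
import torch.nn as nn


class DummyLoRA(nn.Module):
    """Stand-in for a LoRA layer: down-projection followed by a zero-initialized up-projection."""

    def __init__(self, dim: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)
        self.up = nn.Linear(rank, dim, bias=False)
        nn.init.zeros_(self.up.weight)  # zero-init up weight => the patch starts as a no-op

    def forward(self, x):
        return self.up(self.down(x))


def apply_patch(module: nn.Module, lora_layer: nn.Module) -> None:
    # Stash the original forward on the module itself so the patch can be undone later.
    old_forward = module.old_forward = module.forward

    # Each call to apply_patch is its own scope, so old_forward / lora_layer are bound
    # per module -- the same reason the diff builds new_forward inside a fresh scope.
    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    module.forward = new_forward


def remove_patch(module: nn.Module) -> None:
    if hasattr(module, "old_forward"):
        # restore the original `forward` to remove the monkey-patch
        module.forward = module.old_forward
        delattr(module, "old_forward")


proj = nn.Linear(8, 8)
x = torch.randn(1, 8)
baseline = proj(x)

apply_patch(proj, DummyLoRA(8))
assert torch.allclose(proj(x), baseline)  # zeroed up weight: output unchanged while patched

remove_patch(proj)
assert torch.allclose(proj(x), baseline)  # original forward restored
```

Because the closure holds its own references to old_forward and the LoRA layer, the patch also survives the caller dropping its reference to the processors, which is what the del text_attn_procs / gc.collect() step in the new test below checks.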
tests/models/test_lora_layers.py (40 additions, 0 deletions)

@@ -299,6 +299,46 @@ def test_text_encoder_lora_monkey_patch(self):
             outputs_without_lora, outputs_with_lora
         ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"
 
+    def test_text_encoder_lora_remove_monkey_patch(self):
+        pipeline_components, _ = self.get_dummy_components()
+        pipe = StableDiffusionPipeline(**pipeline_components)
+
+        dummy_tokens = self.get_dummy_tokens()
+
+        # inference without lora
+        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora.shape == (1, 77, 32)
+
+        # create lora_attn_procs with randn up.weights
+        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+        set_lora_up_weights(text_attn_procs, randn_weight=True)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_attn_procs)
+
+        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+        del text_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 32)
+
+        assert not torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora outputs should be different to without lora outputs"
+
+        # remove monkey patch
+        pipe._remove_text_encoder_monkey_patch()
+
+        # inference with removed lora
+        outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora_removed.shape == (1, 77, 32)
+
+        assert torch.allclose(
+            outputs_without_lora, outputs_without_lora_removed
+        ), "remove lora monkey patch should restore the original outputs"
+
     def create_lora_weight_file(self, tmpdirname):
Member:
While we're at it, could we rename this test to test_create_lora_weight_file()?

Contributor Author:
This function is a utility, not test code. To avoid confusion, I moved it to where the other utility functions live at the top of the file. 5d7939d

         _, lora_components = self.get_dummy_components()
         LoraLoaderMixin.save_lora_weights(