Skip to content

Commit bb9c61e

Browse files
committed
fix monkey-patch for text_encoder
1 parent 8858ebb commit bb9c61e

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/diffusers/loaders.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -946,14 +946,16 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
946946
module = self.text_encoder.get_submodule(name)
947947
# Construct a new function that performs the LoRA merging. We will monkey patch
948948
# this forward pass.
949-
lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
950-
old_forward = module.forward
951949

952-
def new_forward(x):
953-
return old_forward(x) + lora_layer(x)
950+
if name in attn_processors:
951+
module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
952+
module.old_forward = module.forward
954953

955-
# Monkey-patch.
956-
module.forward = new_forward
954+
def new_forward(self, x):
955+
return self.old_forward(x) + self.lora_layer(x)
956+
957+
# Monkey-patch.
958+
module.forward = new_forward.__get__(module)
957959

958960
def _get_lora_layer_attribute(self, name: str) -> str:
959961
if "q_proj" in name:

0 commit comments

Comments
 (0)