14 changes: 8 additions & 6 deletions src/diffusers/loaders.py
@@ -943,14 +943,16 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
                 module = self.text_encoder.get_submodule(name)
                 # Construct a new function that performs the LoRA merging. We will monkey patch
                 # this forward pass.
-                lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
-                old_forward = module.forward
 
-                def new_forward(x):
-                    return old_forward(x) + lora_layer(x)
+                if name in attn_processors:
+                    module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+                    module.old_forward = module.forward
 
-                # Monkey-patch.
-                module.forward = new_forward
+                    def new_forward(self, x):
+                        return self.old_forward(x) + self.lora_layer(x)
+
+                    # Monkey-patch.
+                    module.forward = new_forward.__get__(module)
 
     def _get_lora_layer_attribute(self, name: str) -> str:
         if "q_proj" in name:
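The change above moves the monkey patch's state onto the patched module itself: `lora_layer` and `old_forward` become attributes of the module, `new_forward` takes `self`, and `new_forward.__get__(module)` binds it as a method of that specific instance (the `if name in attn_processors:` guard also skips target modules with no matching processor). Unlike the closure-based version, which is prone to the classic late-binding pitfall (every `new_forward` defined in the loop closes over the same `lora_layer`/`old_forward` variables and ends up seeing the last iteration's values), each patched module now carries its own state and its own reference to the LoRA layer, which is why the test below can `del text_lora_attn_procs` and `gc.collect()` without breaking inference. A minimal sketch of the pattern, assuming only PyTorch; `ToyLoRA` and the plain `nn.Linear` modules are illustrative stand-ins, not diffusers classes:

```python
import torch
import torch.nn as nn


class ToyLoRA(nn.Module):
    """Tiny LoRA-style residual: up(down(x)), with up.weight zero-initialized."""

    def __init__(self, features, rank=4):
        super().__init__()
        self.down = nn.Linear(features, rank, bias=False)
        self.up = nn.Linear(rank, features, bias=False)
        nn.init.zeros_(self.up.weight)  # zero up-projection => patch starts as a no-op

    def forward(self, x):
        return self.up(self.down(x))


modules = {"a": nn.Linear(8, 8), "b": nn.Linear(8, 8)}
loras = {"a": ToyLoRA(8), "b": ToyLoRA(8)}

for name, module in modules.items():
    # Keep all per-module state on the module itself instead of in a closure ...
    module.lora_layer = loras[name]
    module.old_forward = module.forward

    def new_forward(self, x):
        # `self` is the patched module, so it always uses its *own* state.
        return self.old_forward(x) + self.lora_layer(x)

    # ... and bind new_forward as a method of this particular instance.
    module.forward = new_forward.__get__(module)

x = torch.randn(1, 8)
for name, module in modules.items():
    assert module.lora_layer is loras[name]  # per-module LoRA layer
    assert torch.allclose(module(x), module.old_forward(x))  # zero up.weight => unchanged output
```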
68 changes: 67 additions & 1 deletion tests/models/test_lora_layers.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import os
 import tempfile
 import unittest
@@ -22,7 +23,7 @@

 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor, LoRALinearLayer
 from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device


@@ -212,3 +213,68 @@ def test_lora_save_load_legacy(self):

         # Outputs shouldn't match.
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))

+    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
+    def get_dummy_tokens(self):
+        max_seq_length = 77
+
+        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)).to("cuda")
+
+        prepared_inputs = {}
+        prepared_inputs["input_ids"] = inputs
+        return prepared_inputs
+
+    def test_text_encoder_lora_monkey_patch(self):
+        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
+
+        dummy_tokens = self.get_dummy_tokens()
+
+        # inference without lora
+        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora.shape == (1, 77, 768)
+
+        # create lora_attn_procs with zeroed out up.weights
+        text_lora_attn_procs = {}
+        for name, module in pipe.text_encoder.named_modules():
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+                attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None).to("cuda")
+
+                # make sure that the up.weights are zeroed out
+                for layer_name, layer_module in attn_proc.named_modules():
+                    if layer_name.endswith("_lora"):
+                        assert torch.allclose(
+                            layer_module.up.weight, torch.zeros_like(layer_module.up.weight)
+                        ), "lora_up_weight should be zeroed out"
+
+                text_lora_attn_procs[name] = attn_proc
+
+        # monkey patch
+        pipe._modify_text_encoder(text_lora_attn_procs)
+
+        # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor.
+        del text_lora_attn_procs
+        gc.collect()
Review comment (Member): Very important check!
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 768)
+
+        assert torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs"
+
+        # set randn to lora_up.weights
+        for name, _ in pipe.text_encoder.named_modules():
+            if any(name.endswith(x) for x in TEXT_ENCODER_TARGET_MODULES):
+                module = pipe.text_encoder.get_submodule(name)
+                assert hasattr(module, "lora_layer"), "lora_layer should be added"
+                assert isinstance(module.lora_layer, LoRALinearLayer), "lora_layer should be LoRALinearLayer"
+                module.lora_layer.up.weight = torch.nn.Parameter(torch.randn_like(module.lora_layer.up.weight))
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 768)
+
+        assert not torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"