src/diffusers/loaders.py (10 changes: 7 additions & 3 deletions)

```diff
@@ -946,11 +946,15 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
                 lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
                 old_forward = module.forward
 
-                def new_forward(x):
-                    return old_forward(x) + lora_layer(x)
+                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
```

Comment on lines -950 to +949

Member:
Here I'd not actually mind referring users to read this issue comment you made: #3490 (comment)

```diff
+                def make_new_forward(old_forward, lora_layer):
+                    def new_forward(x):
+                        return old_forward(x) + lora_layer(x)
```

Reviewer:
Small comment: if you load lora with e.g. `pipeline.load_lora_weights("experiments/base_experiment")` more than once, then this monkey patch becomes recursive!
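
For illustration, here is a minimal standalone sketch of the nesting described above; the module and layer names are made up for this example and it is not diffusers code. Each call re-wraps whatever forward is currently installed, so earlier LoRA terms stay baked in:

```python
import torch
import torch.nn as nn

def patch_with_lora(module, lora_layer):
    # Same pattern as the monkey-patch in loaders.py: wrap whatever forward is currently set.
    old_forward = module.forward

    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    module.forward = new_forward

base = nn.Linear(4, 4)
orig_forward = base.forward  # handle to the unpatched forward, for comparison
lora_a = nn.Linear(4, 4, bias=False)
lora_b = nn.Linear(4, 4, bias=False)

patch_with_lora(base, lora_a)  # forward = base + lora_a
patch_with_lora(base, lora_b)  # forward = (base + lora_a) + lora_b

x = torch.randn(2, 4)
with torch.no_grad():
    # Both LoRA terms are applied: the first patch never went away.
    print(torch.allclose(base(x), orig_forward(x) + lora_a(x) + lora_b(x)))  # True
```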

Contributor Author:
Thanks for your suggestion! Indeed, it seems that might be the case. I wonder if it might be better to create a mechanism to remove the monkey-patch. @sayakpaul WDYT?

Reviewer:
Best is indeed to do something else than monkey-patching, but a flag like override_forward=False as an arg would also be helpful to disable the monkey-patching.

Member:
We're currently discussing this internally, and will keep y'all posted.

Member:
Meanwhile,

> Best is indeed to do something else than monkey-patching, but a flag like override_forward=False as an arg would also be helpful to disable the monkey-patching.

@rvorias could you elaborate what you mean here?

Reviewer:
Have load_lora_weights parse a kwarg override_text_encoder_forward: bool, then pass that to _modify_text_encoder and condition the monkey-patching (line 957) on this flag.

This is useful in contexts where you want to load arbitrary lora weights on the fly in a long-running SD inference engine. Right now, calling load_lora_weights multiple times overrides the forward function multiple times, so the lora addition term gets nested.

With the flag + condition you can still load the new lora weights, but you don't override the forward again and again.
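
A rough sketch of what such a guard could look like; the kwarg name follows the suggestion above, while the function name and the `_lora_forward_patched` attribute are made up for this sketch and are not part of the diffusers API:

```python
def modify_text_encoder_guarded(module, lora_layer, override_forward: bool = True):
    # Skip re-patching if disabled, or if this module's forward was already wrapped,
    # so repeated load_lora_weights calls don't nest the LoRA addition term.
    if not override_forward or getattr(module, "_lora_forward_patched", False):
        return

    old_forward = module.forward

    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    module.forward = new_forward
    module._lora_forward_patched = True
```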

Member:
Would you be willing to open a PR for this? We're more than happy to help you with that :-)


```diff
+
+                    return new_forward
 
                 # Monkey-patch.
-                module.forward = new_forward
+                module.forward = make_new_forward(old_forward, lora_layer)
 
     def _get_lora_layer_attribute(self, name: str) -> str:
         if "q_proj" in name:
```

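The factory function above matters because of Python's late-binding closures: the old version defined new_forward directly inside the loop over modules, so every closure looked up old_forward and lora_layer at call time and ended up seeing the values from the last iteration. A standalone sketch of the pitfall and the fix (generic names, unrelated to diffusers):

```python
def broken():
    funcs = []
    for n in [1, 2, 3]:
        def f():
            return n  # n is looked up when f() runs, so every f sees the final n
        funcs.append(f)
    return [f() for f in funcs]

def fixed():
    def make_f(n):  # a new scope locks in the current value of n for each f
        def f():
            return n
        return f
    return [make_f(n)() for n in [1, 2, 3]]

print(broken())  # [3, 3, 3]
print(fixed())   # [1, 2, 3]
```
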
tests/models/test_lora_layers.py (73 changes: 73 additions & 0 deletions)

```diff
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import os
 import tempfile
 import unittest
@@ -212,3 +213,75 @@ def test_lora_save_load_legacy(self):
 
         # Outputs shouldn't match.
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+
+    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
+    def get_dummy_tokens(self):
+        max_seq_length = 77
+
+        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
+
+        prepared_inputs = {}
+        prepared_inputs["input_ids"] = inputs
+        return prepared_inputs
+
+    def get_text_lora_attn_procs(self, text_encoder: nn.Module, randn_weight=False):
```

Member:
A few things:

- Could we reuse this function (making changes to it is completely fine)?

```python
def create_text_encoder_lora_layers(text_encoder: nn.Module):
```

Contributor Author:

> Could we reuse this function (making changes to it is completely fine)?

I missed the existence of this function. I have made changes to reuse some of it in this commit: 1da772b

Also, from our discussions in #3437 (particularly this #3437 (comment)), it seems we also need to change the target modules to which we're applying LoRA, no?

This change might break compatibility with already serialized files and might also require changes to the training code, so it might be better to do it in a separate PR. I'm thinking about opening another draft PR for that.

Member:

> This change might break compatibility with already serialized files and might also require changes to the training code, so it might be better to do it in a separate PR. I'm thinking about opening another draft PR for that.

From what I can tell, the LoRA checkpoints on the Hub (the most useful ones) produced by our training script do not include the text encoder, so I think it's fine as is. But if we want to do it in a separate PR with changes to the training script, I am fine with that.

Contributor Author (@takuma104, May 22, 2023):
Ok! I just opened #3505

```diff
+        text_lora_attn_procs = {}
+        for name, module in text_encoder.named_modules():
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+                attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
+                # set the up.weight of each LoRA layer
+                for layer_name, layer_module in attn_proc.named_modules():
+                    if layer_name.endswith("_lora"):
+                        weight = (
+                            torch.randn_like(layer_module.up.weight)
```

Member:
For test purposes, that is not needed.

```diff
+                            if randn_weight
+                            else torch.zeros_like(layer_module.up.weight)
+                        )
```

Member:
Clever!
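
For context, a standalone sketch (not the test code) of why zeroing the up projection works as a no-op check: the LoRA branch is up(down(x)), so a zero up weight makes the whole LoRA term vanish and the patched forward reduces to the original one, which is what the first allclose assertion below relies on:

```python
import torch
import torch.nn as nn

down = nn.Linear(8, 4, bias=False)
up = nn.Linear(4, 8, bias=False)
up.weight = nn.Parameter(torch.zeros_like(up.weight))  # zeroed-out up weight

x = torch.randn(2, 8)
print(torch.allclose(up(down(x)), torch.zeros(2, 8)))  # True: the LoRA term contributes nothing
```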

```diff
+                        layer_module.up.weight = torch.nn.Parameter(weight)
+                text_lora_attn_procs[name] = attn_proc
+        return text_lora_attn_procs
+
+    def test_text_encoder_lora_monkey_patch(self):
+        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
```

Member:
Could we maybe use a smaller pipeline like the following?

```python
sd_pipe = StableDiffusionPipeline(**pipeline_components)
```

It helps us run the tests faster but does the job of proper testing at the same time.

Contributor Author:
Fixed: 1da772b


```diff
+        dummy_tokens = self.get_dummy_tokens()
+
+        # inference without lora
+        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora.shape == (1, 77, 768)
+
+        # create lora_attn_procs with zeroed out up.weights
+        text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=False)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_lora_attn_procs)
+
+        # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor.
+        del text_lora_attn_procs
+        gc.collect()
```

Member:
Very important check!
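
A standalone sketch (illustrative names, not the test itself) of why releasing the dict is safe: the new_forward closure keeps its own reference to the LoRA layer, so the patched module still works after the caller's reference is dropped:

```python
import gc

import torch
import torch.nn as nn

base = nn.Linear(4, 4)
lora = nn.Linear(4, 4, bias=False)

def make_new_forward(old_forward, lora_layer):
    def new_forward(x):
        return old_forward(x) + lora_layer(x)
    return new_forward

base.forward = make_new_forward(base.forward, lora)

del lora  # drop the caller's reference, like `del text_lora_attn_procs` in the test
gc.collect()

print(base(torch.randn(1, 4)).shape)  # torch.Size([1, 4]): the closure keeps the LoRA layer alive
```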


```diff
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 768)
+
+        assert torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora_up_weight is all zeros, so the lora outputs should be the same as the outputs without lora"
+
+        # create lora_attn_procs with randn up.weights
+        text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=True)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_lora_attn_procs)
+
+        # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor.
+        del text_lora_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 768)
+
+        assert not torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora_up_weight is not zero, so the lora outputs should be different from the outputs without lora"
```