[LoRA] feat: add lora attention processor for pt 2.0. #3594

Conversation
For benchmarking: I'd benchmark just the attention blocks across batch sizes and resolutions (maybe forcing a particular scaled dot product impl if you think it's worthwhile) using the PyTorch timer: https://pytorch.org/tutorials/recipes/recipes/benchmark.html. That said, I think it's okay to merge without benchmarking.
Thanks, @williamberman!
Is it possible to specify that? Could you provide a reference?
Edit: I think Will was referring to the different implementations of SDPA as shown here: https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html.
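To make that concrete, here is a minimal sketch of benchmarking `F.scaled_dot_product_attention` across batch sizes while forcing a particular kernel via the `torch.backends.cuda.sdp_kernel` context manager described in the tutorial above. The shapes, dtype, and number of runs are illustrative assumptions, not values from this PR.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

assert torch.cuda.is_available(), "the kernel flags below only apply to the CUDA backends"

def run_sdpa(q, k, v, enable_flash, enable_mem_efficient):
    # Restrict SDPA to one kernel; the vanilla math fallback is intentionally disabled.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=enable_flash, enable_math=False, enable_mem_efficient=enable_mem_efficient
    ):
        return F.scaled_dot_product_attention(q, k, v)

for batch in (1, 2, 4):
    # Random tensors standing in for a UNet attention block's queries/keys/values.
    q, k, v = (torch.randn(batch, 8, 4096, 64, device="cuda", dtype=torch.float16) for _ in range(3))
    for label, flags in {"flash": (True, False), "mem_efficient": (False, True)}.items():
        timer = benchmark.Timer(
            stmt="run_sdpa(q, k, v, *flags)",
            globals={"run_sdpa": run_sdpa, "q": q, "k": k, "v": v, "flags": flags},
        )
        print(batch, label, timer.timeit(20))
```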
Some interesting findings from trying to match the performance of the SDPA LoRA processor implementation with that of xformers. If we try to explicitly ask the context dispatcher to use flash attention (xformers uses memory-efficient attention, which is different), then there's a problem:
And it therefore fails:
However, with the memory-efficient attention context it works. In practice, we don't have to explicitly specify which implementation to dispatch to; it's automatically determined. See this for more details. Refer to this Colab notebook, which consolidates these findings. Earlier, the dispatcher was defaulting to the vanilla (math) implementation. Cc: @patrickvonplaten @pcuenca -- I think you might find these interesting :)
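To illustrate that conclusion (no explicit dispatch inside the processor), here is a heavily simplified sketch of the shape of a PT 2.0 LoRA attention computation: low-rank residuals on the projections, followed by a plain `F.scaled_dot_product_attention` call. This is written for this summary and is not the actual diffusers implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALinear(nn.Module):
    """Illustrative low-rank residual on top of a frozen linear projection."""
    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base, self.scale = base, scale
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # the LoRA branch starts as a no-op

    def forward(self, x):
        return self.base(x) + self.scale * self.up(self.down(x))

def lora_sdpa_self_attention(q_proj, k_proj, v_proj, hidden_states, heads=8):
    b, n, d = hidden_states.shape
    q = q_proj(hidden_states).view(b, n, heads, d // heads).transpose(1, 2)
    k = k_proj(hidden_states).view(b, n, heads, d // heads).transpose(1, 2)
    v = v_proj(hidden_states).view(b, n, heads, d // heads).transpose(1, 2)
    # No sdp_kernel context manager here: PyTorch chooses the kernel automatically.
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, n, d)
```

In other words, the processor only changes how the projections are computed; kernel selection is left entirely to PyTorch.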
Design-wise this looks good to me.
This sounds like a bug within PyTorch; let's maybe contact them. Also, does the same error happen on nightly?
@patrickvonplaten, which part sounds like a bug to you? I have made them aware of this thread (see internal link).
@patrickvonplaten @williamberman this PR is ready for review. Results seem okay to me: https://wandb.ai/sayakpaul/dreambooth-lora/runs/8u5ys158. @takuma104, it would be great if you could also take a look :)
@@ -261,7 +261,7 @@ def test_lora_save_load(self):
         with torch.no_grad():
             new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

-        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - new_sample).abs().max() < 5e-4
Because of PyTorch SDPA.
ok for me!
@@ -295,7 +295,7 @@ def test_lora_save_load_safetensors(self):
         with torch.no_grad():
             new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

-        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - new_sample).abs().max() < 2e-4
Because of PyTorch SDPA.
elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
    warnings.warn(
        "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
        "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
        "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
        "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
        "native efficient flash attention."
    )
Decided to remove this rather confusing warning message. But LMK if you think otherwise.
We still want our users to take advantage of xformers for LoRA, Custom Diffusion, etc. even when the rest of the attention processors run with SDPA.
Agree! In my experiments, xformers is sometimes still faster and more memory-efficient.
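A small illustration of that point, assuming a standard diffusers pipeline (the checkpoint id is a placeholder): on PyTorch 2.0 the SDPA-based processors are used by default, and a user can still explicitly opt into xformers.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder checkpoint
).to("cuda")

# With PyTorch 2.0 installed, diffusers already uses SDPA-based attention processors here.
# Opting into xformers remains an explicit, user-driven choice:
pipe.enable_xformers_memory_efficient_attention()
```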
@@ -220,6 +212,8 @@ def set_use_memory_efficient_attention_xformers(
                raise e

            if is_lora:
                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
Actually, I think for now let's give the user full freedom over what to use.
Clean!
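A rough sketch of what "full freedom" could look like from the user's side (this is not the diffusers-internal dispatch; the helper name and keyword passthrough are hypothetical):

```python
import torch.nn.functional as F
from diffusers.models.attention_processor import (
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
)

def pick_lora_attn_processor(prefer_xformers: bool, **kwargs):
    """Hypothetical helper: the user decides; nothing is forced by the library."""
    if prefer_xformers:
        return LoRAXFormersAttnProcessor(**kwargs)
    if hasattr(F, "scaled_dot_product_attention"):  # PyTorch 2.0+
        return LoRAAttnProcessor2_0(**kwargs)
    return LoRAAttnProcessor(**kwargs)
```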
Don't worry about the test failures; they are easy to fix. I will get them fixed.
@sayakpaul Looks nice to me! I tried writing and running a simple integration test with Kohya-ss LoRA, but it resulted in an error.
Thanks @takuma104! I will take care of it.
nice!
@takuma104 I ran your script and here's the output: {"width": 512, "height": 768, "batch": 4, "xformers": "OFF", "lora": "OFF", "mem_MB": 5402}
{"width": 512, "height": 768, "batch": 4, "xformers": "OFF", "lora": "ON", "mem_MB": 5471}
{"width": 512, "height": 768, "batch": 4, "xformers": "ON", "lora": "OFF", "mem_MB": 5401}
{"width": 512, "height": 768, "batch": 4, "xformers": "ON", "lora": "ON", "mem_MB": 5471} |
* feat: add lora attention processor for pt 2.0.
* explicit context manager for SDPA.
* switch to flash attention
* make shapes compatible to work optimally with SDPA.
* fix: circular import problem.
* explicitly specify the flash attention kernel in sdpa
* fall back to efficient attention context manager.
* remove explicit dispatch.
* fix: removed processor.
* fix: remove optional from type annotation.
* feat: make changes regarding LoRAAttnProcessor2_0.
* remove confusing warning.
* formatting.
* relax tolerance for PT 2.0
* fix: loading message.
* remove unnecessary logging.
* add: entry to the docs.
* add: network_alpha argument.
* relax tolerance.
Part of #3464.
This PR adds a PT 2.0 variant of the `LoRAAttnProcessor`, utilizing the memory-efficient scaled dot product attention. If the design looks good, then the following TODOs remain:

* Ensure `LoRAAttnProcessor2_0` is used where it should be.
* Update `loaders.py` as needed.

However, we should be aware of the following observation.
First, I ran a benchmark with the regular LoRA:
Prints:
When I utilized the `LoRAAttnProcessor2_0` like so, it leads to:
Why is the execution time higher for the PT 2.0 LoRA attention processor? I did manually verify that the attention processor classes of the `unet` were indeed `LoRAAttnProcessor2_0` after the assignment `pipe.unet = unet`. Is there something I am missing in the tests?
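For reference, the kind of check described above can be written like this, assuming the usual `unet.attn_processors` / `unet.set_attn_processor` API in diffusers; the checkpoint id and the LoRA hyperparameters are placeholders, and the `hidden_size` derivation is deliberately simplified.

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import LoRAAttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder checkpoint
).to("cuda")
unet = pipe.unet

# Swap every attention processor for the PT 2.0 LoRA variant (placeholder hyperparameters).
lora_procs = {}
for name in unet.attn_processors:
    cross_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    hidden_size = unet.config.block_out_channels[-1]  # simplified; real code derives this per block
    lora_procs[name] = LoRAAttnProcessor2_0(
        hidden_size=hidden_size, cross_attention_dim=cross_dim, rank=4
    )
unet.set_attn_processor(lora_procs)
pipe.unet = unet

# Verify that the assignment actually took effect.
assert all(isinstance(p, LoRAAttnProcessor2_0) for p in pipe.unet.attn_processors.values())
print("all processors are LoRAAttnProcessor2_0")
```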