[LoRA] feat: add lora attention processor for pt 2.0. #3594

Merged: 22 commits merged into main on Jun 6, 2023

Conversation

sayakpaul
Member

@sayakpaul sayakpaul commented May 29, 2023

Part of #3464.

This PR adds a PT 2.0 variant of the LoRAAttnProcessor that utilizes memory-efficient scaled dot-product attention (SDPA).

If the design looks good, then the following TODOs remain:

However, we should be aware of the following observation.

First, I ran a benchmark with the regular LoRA:

from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline
import torch
import time

lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights(lora_model_id)


start_time = time.time_ns()
for _ in range(10):
    image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
end_time = time.time_ns()
print(f"Execution time -- {(end_time - start_time) / 1e6:.1f} ms")

Prints:

Execution time -- 21498.7 ms

When I utilized the LoRAAttnProcessor2_0 like so:

from diffusers.models.attention_processor import LoRAAttnProcessor2_0
from diffusers import StableDiffusionPipeline
import torch
import time

model_id = "runwayml/stable-diffusion-v1-5"

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
unet = pipe.unet

# Set correct lora layers
unet_lora_attn_procs = {}
for name, attn_processor in unet.attn_processors.items():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    unet_lora_attn_procs[name] = LoRAAttnProcessor2_0(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    ).to(unet.device)

unet.set_attn_processor(unet_lora_attn_procs)
pipe.unet = unet


start_time = time.time_ns()
for _ in range(10):
    image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
end_time = time.time_ns()
print(f"Execution time -- {(end_time - start_time) / 1e6:.1f} ms")

leads to:

Execution time -- 23202.8 ms

Why is the execution time higher for the PT 2.0 LoRA attention processor? I did manually verify that the attention processor classes of the UNet were indeed LoRAAttnProcessor2_0 after the assignment pipe.unet = unet. Is there something I am missing in the tests?
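
For reference, a quick way to do that check (assuming the pipe object from the snippet above):

# Confirm every attention processor on the UNet is the PT 2.0 LoRA variant.
print({type(proc).__name__ for proc in pipe.unet.attn_processors.values()})
# Expected: {'LoRAAttnProcessor2_0'}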

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 29, 2023

The documentation is not available anymore as the PR was closed or merged.

@williamberman
Contributor

For benchmarking: I'd benchmark just the attention blocks across batch sizes and resolutions (maybe forcing a particular scaled dot product impl if you think it's worthwhile) using the PyTorch timer: https://pytorch.org/tutorials/recipes/recipes/benchmark.html

Though I do think it's OK to merge without benchmarking.
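
For illustration, a minimal sketch of that kind of attention micro-benchmark with torch.utils.benchmark (the shapes and dtype are assumptions, not numbers from this PR):

import torch
import torch.nn.functional as F
from torch.utils import benchmark

# Sweep a few hypothetical batch sizes and sequence lengths.
for batch, seq_len in [(2, 4096), (4, 1024), (8, 256)]:
    q = torch.randn(batch, 8, seq_len, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    timer = benchmark.Timer(
        stmt="F.scaled_dot_product_attention(q, k, v)",
        globals={"F": F, "q": q, "k": k, "v": v},
    )
    print(f"batch={batch}, seq_len={seq_len}:", timer.timeit(100))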

@sayakpaul
Member Author

sayakpaul commented May 30, 2023

Thanks, @williamberman!

maybe forcing a particular scaled dot product impl

Is it possible to specify that? Could you provide a reference?

Edit: I think Will was referring to the different implementations of SDPA as shown here: https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html.

@sayakpaul
Member Author

Some interesting findings in trying to match the performance of SDPA LoRA processor implementation with that of xformers.

If we try to explicitly ask the context dispatcher to use flash attention (xformers uses memory-efficient attention, which is different), then there's a problem:

/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:1305: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:547.)
  hidden_states = F.scaled_dot_product_attention(
/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:1305: UserWarning: Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 128. Got Query.size(-1): 160, Key.size(-1): 160, Value.size(-1): 160 instead. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:253.)
  hidden_states = F.scaled_dot_product_attention(

As a result, it fails:

RuntimeError: No available kernel.  Aborting execution.

However, with SDPBackend.EFFICIENT_ATTENTION, it runs as expected and matches the performance of xformers.

In practice, we don't have to explicitly specify which implementation to dispatch to. It's automatically determined. See this for more details.

Refer to this Colab Notebook, which consolidates these findings.

Earlier, the dispatcher was defaulting to the vanilla MATH implementation because the data shapes did not conform to the requirements of efficient attention. Once that was fixed, the performance issue reported here was resolved.

Cc: @patrickvonplaten @pcuenca -- I think you might find these interesting :)
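
For reference, a minimal sketch of forcing a particular SDPA backend with the PyTorch 2.0 context manager (shapes and dtype are assumptions; later PyTorch releases expose a similar context manager under torch.nn.attention):

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

# Hypothetical shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Allow only the memory-efficient kernel. Forcing flash attention instead
# would fail for head dims larger than 128 (e.g. the 160-dim heads in the warning above).
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v)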

@patrickvonplaten
Contributor

Design-wise this looks good to me

@patrickvonplaten
Contributor

Some interesting findings in trying to match the performance of SDPA LoRA processor implementation with that of xformers. […]

This sounds like a bug within PyTorch; maybe let's contact them. Also, does the same error happen on nightly?

@sayakpaul
Member Author

This sounds like a bug within PyTorch; maybe let's contact them. Also, does the same error happen on nightly?

@patrickvonplaten, which part sounds like a bug to you? I have made them aware of this thread (see internal link).

@sayakpaul sayakpaul marked this pull request as ready for review June 2, 2023 06:16
@sayakpaul
Member Author

@patrickvonplaten @williamberman this PR is ready to review.

Results seem okay to me: https://wandb.ai/sayakpaul/dreambooth-lora/runs/8u5ys158

@takuma104 would be great if you could also take a look :)

@@ -261,7 +261,7 @@ def test_lora_save_load(self):
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

- assert (sample - new_sample).abs().max() < 1e-4
+ assert (sample - new_sample).abs().max() < 5e-4
Member Author


Because of PyTorch SDPA.

Contributor


ok for me!

@@ -295,7 +295,7 @@ def test_lora_save_load_safetensors(self):
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

- assert (sample - new_sample).abs().max() < 1e-4
+ assert (sample - new_sample).abs().max() < 2e-4
Member Author


Because of PyTorch SDPA.

Comment on lines -203 to -210
elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
warnings.warn(
"You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
"We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
"introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
"back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
"native efficient flash attention."
)
Member Author


Decided to remove this rather confusing warning message. But LMK if you think otherwise.

We still want our users to take advantage of xformers for LoRA, Custom Diffusion, etc. even when the rest of the attention processors run with SDPA.
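
For context, opting into xformers remains a one-liner on the pipeline, while a plain PyTorch 2.0 install dispatches to SDPA by default (a sketch, assuming a pipe set up as in the snippets above):

# Keep using xformers everywhere, including for the LoRA attention processors:
pipe.enable_xformers_memory_efficient_attention()

# Or do nothing: with PyTorch 2.0 installed, diffusers defaults to
# F.scaled_dot_product_attention for attention computation.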

Contributor


Agree! In my experiments, xformers is sometimes still faster and more memory-efficient.

@@ -220,6 +212,8 @@ def set_use_memory_efficient_attention_xformers(
raise e

if is_lora:
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
Contributor


Actually, I think for now let's give the user full freedom over what to use.

Contributor

@patrickvonplaten patrickvonplaten left a comment


Clean!

@sayakpaul
Member Author

Don't worry about the test failures; they are easy to fix. I will get them fixed.

@takuma104
Contributor

@sayakpaul Looks nice to me! I tried writing and running a simple integration test with Kohya-ss LoRA, but it resulted in a network_alpha argument error. Could you add this argument when merging the main branch?

TypeError: LoRAAttnProcessor2_0.__init__() got an unexpected keyword argument 'network_alpha'

@sayakpaul
Member Author

Thanks @takuma104! I will take care of it.

Contributor

@williamberman williamberman left a comment


nice!

@sayakpaul
Member Author

@takuma104 I ran your script and here's the output:

{"width": 512, "height": 768, "batch": 4, "xformers": "OFF", "lora": "OFF", "mem_MB": 5402}
{"width": 512, "height": 768, "batch": 4, "xformers": "OFF", "lora": "ON", "mem_MB": 5471}
{"width": 512, "height": 768, "batch": 4, "xformers": "ON", "lora": "OFF", "mem_MB": 5401}
{"width": 512, "height": 768, "batch": 4, "xformers": "ON", "lora": "ON", "mem_MB": 5471}

@sayakpaul sayakpaul merged commit 8669e83 into main Jun 6, 2023
@sayakpaul sayakpaul deleted the feat/lora-attn-pt2 branch June 6, 2023 09:26
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* feat: add lora attention processor for pt 2.0.

* explicit context manager for SDPA.

* switch to flash attention

* make shapes compatible to work optimally with SDPA.

* fix: circular import problem.

* explicitly specify the flash attention kernel in sdpa

* fall back to efficient attention context manager.

* remove explicit dispatch.

* fix: removed processor.

* fix: remove optional from type annotation.

* feat: make changes regarding LoRAAttnProcessor2_0.

* remove confusing warning.

* formatting.

* relax tolerance for PT 2.0

* fix: loading message.

* remove unnecessary logging.

* add: entry to the docs.

* add: network_alpha argument.

* relax tolerance.
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024