
[Attention processor] Better warning message when shifting to AttnProcessor2_0 #3457


Merged (4 commits) on May 21, 2023
9 changes: 8 additions & 1 deletion src/diffusers/models/attention_processor.py
@@ -191,7 +191,10 @@ def set_use_memory_efficient_attention_xformers(
             elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
                 warnings.warn(
                     "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
-                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                    "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
+                    "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
+                    "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
+                    "native efficient flash attention."
Member:

Would it be too much to say that you can use xFormers with LoRA?

Member Author:

@pcuenca does the following sound good?

"You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
"We will default to PyTorch's native efficient flash attention implementation (F.scaled_dot_product_attention) "
"introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
"back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
"native efficient flash attention. In this case, addtionally, if you enable efficient attention via xformers "
"we will support that too."

Member:

It sounds a bit confusing because in the first line we essentially say we are ignoring xFormers. Let's keep your original version, I think :) Sorry!

Member Author:

@patrickvonplaten WDYT? I really want to ensure the users don't feel confused here.

                 )
             else:
                 try:
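For context on when this branch fires, here is a minimal usage sketch. It is not part of the diff: the pipeline class, the model id, and `enable_xformers_memory_efficient_attention()` as the entry point are assumptions about how users typically reach `set_use_memory_efficient_attention_xformers`.

```python
# Minimal sketch (assumed usage, not part of this PR): with xFormers and
# PyTorch 2.0 both installed (and CUDA available), requesting xFormers lands
# in the `elif` above, emits the reworded warning, and the model keeps the
# native `F.scaled_dot_product_attention` path instead of xFormers.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Propagates set_use_memory_efficient_attention_xformers(True) to the
# attention modules, which is where the warning above is raised.
pipe.enable_xformers_memory_efficient_attention()
```

With LoRA or Custom Diffusion layers loaded, the same call keeps the xFormers variants of those processors instead, which is what the reworded message spells out.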
@@ -213,6 +216,9 @@ def set_use_memory_efficient_attention_xformers(
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+                print(
+                    f"is_lora is set to {is_lora}, type: LoRAXFormersAttnProcessor: {isinstance(processor, LoRAXFormersAttnProcessor)}"
+                )
             elif is_custom_diffusion:
                 processor = CustomDiffusionXFormersAttnProcessor(
                     train_kv=self.processor.train_kv,
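A hedged sketch of the LoRA branch above, the one the temporary `is_lora` print reports on. The LoRA checkpoint path is a placeholder and `load_attn_procs` / `enable_xformers_memory_efficient_attention` are assumed as the usual entry points; the diff itself only shows the processor swap.

```python
# Sketch (placeholder path, assumed entry points): with LoRA attention
# processors loaded, enabling xFormers swaps in LoRAXFormersAttnProcessor
# rather than the native PyTorch 2.0 processor.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.unet.load_attn_procs("path/to/lora_weights")  # placeholder LoRA checkpoint
pipe.enable_xformers_memory_efficient_attention()

# Inspect the resulting processors: with `is_lora` True they should be
# LoRAXFormersAttnProcessor instances, matching the debug print above.
print({type(p).__name__ for p in pipe.unet.attn_processors.values()})
```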
@@ -250,6 +256,7 @@ def set_use_memory_efficient_attention_xformers(
                 # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
                 # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
                 # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+                print("Still defaulting to: AttnProcessor2_0 :O")
                 processor = (
                     AttnProcessor2_0()
                     if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
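Finally, a standalone restatement of the default selection in the last hunk, so the condition can be checked outside the class. `scale_qk` is assumed to be left at its default of `True`, and the `else AttnProcessor()` fallback, which the diff view cuts off, is taken from the surrounding code in this file.

```python
# Restates the selection above: AttnProcessor2_0 (backed by PyTorch 2.0's
# torch.nn.functional.scaled_dot_product_attention) is chosen whenever it is
# available and scale_qk is True; otherwise the plain AttnProcessor is used.
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

scale_qk = True  # default of Attention(scale_qk=...)
processor = (
    AttnProcessor2_0()
    if hasattr(F, "scaled_dot_product_attention") and scale_qk
    else AttnProcessor()
)
print(type(processor).__name__)  # "AttnProcessor2_0" on PyTorch >= 2.0
```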