Commit 8d49e5d

sayakpaul authored and Jimmy committed
[Attention processor] Better warning message when shifting to AttnProcessor2_0 (huggingface#3457)
* add: debugging to enabling memory efficient processing
* add: better warning message.
1 parent c9019e9 commit 8d49e5d

File tree: 1 file changed (+8 −1 lines changed)

src/diffusers/models/attention_processor.py

Lines changed: 8 additions & 1 deletion
@@ -191,7 +191,10 @@ def set_use_memory_efficient_attention_xformers(
             elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
                 warnings.warn(
                     "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
-                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                    "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
+                    "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
+                    "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
+                    "native efficient flash attention."
                 )
             else:
                 try:
@@ -213,6 +216,9 @@ def set_use_memory_efficient_attention_xformers(
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+                print(
+                    f"is_lora is set to {is_lora}, type: LoRAXFormersAttnProcessor: {isinstance(processor, LoRAXFormersAttnProcessor)}"
+                )
             elif is_custom_diffusion:
                 processor = CustomDiffusionXFormersAttnProcessor(
                     train_kv=self.processor.train_kv,
@@ -250,6 +256,7 @@ def set_use_memory_efficient_attention_xformers(
                 # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
                 # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
                 # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+                print("Still defaulting to: AttnProcessor2_0 :O")
                 processor = (
                     AttnProcessor2_0()
                     if hasattr(F, "scaled_dot_product_attention") and self.scale_qk

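The reworded warning points users to F.scaled_dot_product_attention, the PyTorch 2.0 kernel that AttnProcessor2_0 wraps, and clarifies that LoRA and Custom Diffusion keep their own attention processors instead. As a minimal sketch (not part of this commit; tensor shapes and the tolerance are illustrative assumptions), the native kernel matches the explicit softmax(QK^T / sqrt(d)) V computation under the default scaling that the scale_qk checks in the diff refer to:

import math

import torch
import torch.nn.functional as F

# Illustrative shapes: batch=2, heads=8, sequence length=64, head dim=40.
q = torch.randn(2, 8, 64, 40)
k = torch.randn(2, 8, 64, 40)
v = torch.randn(2, 8, 64, 40)

# Native PyTorch 2.0 attention, as used by AttnProcessor2_0.
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# Explicit attention with the default 1/sqrt(head_dim) scale.
scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
out_manual = scores.softmax(dim=-1) @ v

print(torch.allclose(out_sdpa, out_manual, atol=1e-5))  # True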
0 commit comments
