@@ -191,7 +191,10 @@ def set_use_memory_efficient_attention_xformers(
             elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
                 warnings.warn(
                     "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
-                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                    "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
+                    "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
+                    "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
+                    "native efficient flash attention."
                 )
             else:
                 try:
@@ -213,6 +216,9 @@ def set_use_memory_efficient_attention_xformers(
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+                print(
+                    f"is_lora is set to {is_lora}, type: LoRAXFormersAttnProcessor: {isinstance(processor, LoRAXFormersAttnProcessor)}"
+                )
             elif is_custom_diffusion:
                 processor = CustomDiffusionXFormersAttnProcessor(
                     train_kv=self.processor.train_kv,
@@ -250,6 +256,7 @@ def set_use_memory_efficient_attention_xformers(
             # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
             # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
             # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            print("Still defaulting to: AttnProcessor2_0 :O")
             processor = (
                 AttnProcessor2_0()
                 if hasattr(F, "scaled_dot_product_attention") and self.scale_qk