From c34001103db574ebcf60719c112c93b400e7c0f8 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 17 May 2023 09:58:43 +0530
Subject: [PATCH 1/2] add: debugging to enabling memory efficient processing

---
 src/diffusers/models/attention_processor.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index f88400da0333..88e045805624 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -213,6 +213,9 @@ def set_use_memory_efficient_attention_xformers(
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+                print(
+                    f"is_lora is set to {is_lora}, type: LoRAXFormersAttnProcessor: {isinstance(processor, LoRAXFormersAttnProcessor)}"
+                )
             elif is_custom_diffusion:
                 processor = CustomDiffusionXFormersAttnProcessor(
                     train_kv=self.processor.train_kv,
@@ -250,6 +253,7 @@ def set_use_memory_efficient_attention_xformers(
             # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
             # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
             # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            print("Still defaulting to: AttnProcessor2_0 :O")
             processor = (
                 AttnProcessor2_0()
                 if hasattr(F, "scaled_dot_product_attention") and self.scale_qk

From 3e3bd819178af15bdf9784ae5323cf7cc7382de9 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 17 May 2023 10:09:24 +0530
Subject: [PATCH 2/2] add: better warning message.

---
 src/diffusers/models/attention_processor.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 88e045805624..2f58176cabc5 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -191,7 +191,10 @@ def set_use_memory_efficient_attention_xformers(
         elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
             warnings.warn(
                 "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
-                "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
+                "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
+                "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
+                "native efficient flash attention."
             )
         else:
             try:
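
Note (not part of the patches above): a minimal sketch of how the question these debug prints target -- "which attention processor did my model actually end up with?" -- can also be checked from user code, assuming the public diffusers API of that period (UNet2DConditionModel.attn_processors and enable_xformers_memory_efficient_attention()); the checkpoint name is only an example.

    from collections import Counter

    import torch
    from diffusers import UNet2DConditionModel

    # Example checkpoint only; any Stable Diffusion UNet behaves the same way.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
    ).to("cuda")

    # Request xFormers memory-efficient attention, the scenario the debug prints probe.
    # This assumes a CUDA GPU with xformers installed; otherwise diffusers raises.
    unet.enable_xformers_memory_efficient_attention()

    # `attn_processors` maps each attention module name to its processor instance, so
    # counting the processor classes shows whether XFormersAttnProcessor, AttnProcessor2_0,
    # or a LoRA/Custom Diffusion variant was actually selected.
    print(Counter(type(proc).__name__ for proc in unet.attn_processors.values()))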