diff --git a/docs/source/en/api/attnprocessor.mdx b/docs/source/en/api/attnprocessor.mdx
index ead639feffe0..7a4812e0961e 100644
--- a/docs/source/en/api/attnprocessor.mdx
+++ b/docs/source/en/api/attnprocessor.mdx
@@ -11,6 +11,9 @@ An attention processor is a class for applying different types of attention mech
 ## LoRAAttnProcessor
 [[autodoc]] models.attention_processor.LoRAAttnProcessor
 
+## LoRAAttnProcessor2_0
+[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0
+
 ## CustomDiffusionAttnProcessor
 [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
 
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index ca25152fcb1c..3accc4265787 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -55,6 +55,7 @@
     AttnAddedKVProcessor2_0,
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
     SlicedAttnAddedKVProcessor,
 )
 from diffusers.optimization import get_scheduler
@@ -844,8 +845,9 @@ def main(args):
         if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
             lora_attn_processor_class = LoRAAttnAddedKVProcessor
         else:
-            lora_attn_processor_class = LoRAAttnProcessor
-
+            lora_attn_processor_class = (
+                LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+            )
         unet_lora_attn_procs[name] = lora_attn_processor_class(
             hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
         )
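
The hunk above picks the LoRA attention processor class at runtime, treating the presence of `F.scaled_dot_product_attention` as a feature test for PyTorch 2.0. A minimal sketch of that dispatch; the hidden/cross-attention sizes are illustrative values, not taken from this diff:

# Sketch of the runtime dispatch used in the training script above.
import torch.nn.functional as F

from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

# PyTorch 2.0 is the first release that ships F.scaled_dot_product_attention,
# so hasattr() is a cheap and reliable version check.
lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
proc = lora_attn_processor_class(hidden_size=320, cross_attention_dim=768, rank=4)
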
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index ab0f1418e615..684a2ba710b9 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -18,6 +18,7 @@
 from typing import Callable, Dict, List, Optional, Union
 
 import torch
+import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 
 from .models.attention_processor import (
@@ -27,6 +28,7 @@
     CustomDiffusionXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     SlicedAttnAddedKVProcessor,
     XFormersAttnProcessor,
@@ -287,7 +289,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
                 attn_processor_class = LoRAXFormersAttnProcessor
             else:
-                attn_processor_class = LoRAAttnProcessor
+                attn_processor_class = (
+                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                )
 
             attn_processors[key] = attn_processor_class(
                 hidden_size=hidden_size,
@@ -927,11 +931,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
 
         # Load the layers corresponding to text encoder and make necessary adjustments.
         text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
-        logger.info(f"Loading {self.text_encoder_name}.")
         text_encoder_lora_state_dict = {
             k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
         }
 
         if len(text_encoder_lora_state_dict) > 0:
+            logger.info(f"Loading {self.text_encoder_name}.")
            attn_procs_text_encoder = self._load_text_encoder_attn_procs(
                 text_encoder_lora_state_dict, network_alpha=network_alpha
             )
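
With this change, `load_attn_procs` (and `load_lora_weights`, which builds on it) transparently instantiates the PyTorch 2.0 LoRA processor when available. A hedged usage sketch; the LoRA weights path below is a placeholder, and the class check is only meaningful for the attention modules the checkpoint actually covers:

import torch
import torch.nn.functional as F

from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.load_lora_weights("path/to/lora_weights")  # placeholder path, substitute real LoRA weights

# On PyTorch 2.0 the loaded UNet processors should be LoRAAttnProcessor2_0.
expected = LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
assert any(isinstance(p, expected) for p in pipe.unet.attn_processors.values())
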
""" def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): @@ -843,6 +842,7 @@ class LoRAAttnAddedKVProcessor(nn.Module): The number of channels in the `encoder_hidden_states`. rank (`int`, defaults to 4): The dimension of the LoRA update matrices. + """ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): @@ -1162,6 +1162,9 @@ class LoRAXFormersAttnProcessor(nn.Module): [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + """ def __init__( @@ -1236,6 +1239,97 @@ def __call__( return hidden_states +class LoRAAttnProcessor2_0(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product + attention. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * 
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index 52826fc0c736..2b10955d23f2 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -19,6 +19,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
@@ -28,6 +29,7 @@
     AttnProcessor,
     AttnProcessor2_0,
     LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
@@ -46,16 +48,24 @@ def create_unet_lora_layers(unet: nn.Module):
         elif name.startswith("down_blocks"):
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        lora_attn_processor_class = (
+            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+        )
+        lora_attn_procs[name] = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
     unet_lora_layers = AttnProcsLayers(lora_attn_procs)
     return lora_attn_procs, unet_lora_layers
 
 
 def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
     text_lora_attn_procs = {}
+    lora_attn_processor_class = (
+        LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+    )
     for name, module in text_encoder.named_modules():
         if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-            text_lora_attn_procs[name] = LoRAAttnProcessor(
+            text_lora_attn_procs[name] = lora_attn_processor_class(
                 hidden_size=module.out_proj.out_features, cross_attention_dim=None
             )
     return text_lora_attn_procs
@@ -368,7 +378,10 @@ def test_lora_unet_attn_processors(self):
         # check if lora attention processors are used
         for _, module in sd_pipe.unet.named_modules():
             if isinstance(module, Attention):
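
The test helpers above attach one LoRA processor per attention module. Underneath, each `LoRALinearLayer` adds a low-rank correction on top of the frozen projection. A toy sketch of that update rule (not diffusers' actual class, just the same math, assuming the up-projection is zero-initialized as in `LoRALinearLayer`):

import torch
import torch.nn as nn

hidden, rank = 64, 4
base = nn.Linear(hidden, hidden, bias=False)  # stands in for the frozen pretrained projection
down = nn.Linear(hidden, rank, bias=False)    # LoRA "A" matrix (rank-r)
up = nn.Linear(rank, hidden, bias=False)      # LoRA "B" matrix (rank-r)
nn.init.zeros_(up.weight)                     # zero init, so the update starts as a no-op

x = torch.randn(1, hidden)
scale = 1.0
out = base(x) + scale * up(down(x))           # frozen output plus scaled low-rank update
assert torch.allclose(out, base(x))           # zero-initialized B => no change before training
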
-                self.assertIsInstance(module.processor, LoRAAttnProcessor)
+                attn_proc_class = (
+                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                )
+                self.assertIsInstance(module.processor, attn_proc_class)
 
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_lora_unet_attn_processors_with_xformers(self):
diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py
index 928f6bcbe960..762c4975da51 100644
--- a/tests/models/test_models_unet_3d_condition.py
+++ b/tests/models/test_models_unet_3d_condition.py
@@ -261,7 +261,7 @@ def test_lora_save_load(self):
         with torch.no_grad():
             new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
 
-        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - new_sample).abs().max() < 5e-4
 
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
@@ -295,7 +295,7 @@ def test_lora_save_load_safetensors(self):
         with torch.no_grad():
             new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
 
-        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - new_sample).abs().max() < 3e-4
 
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
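
The loosened thresholds presumably absorb small numerical differences from the SDPA attention path, since different kernels can change floating-point accumulation order, so a LoRA save/load round-trip matches closely but not bit-exactly. The comparison pattern the tests rely on, in isolation with simulated noise:

import torch

sample = torch.randn(1, 4, 8, 16, 16)                  # shaped loosely like a 3D UNet output
new_sample = sample + 1e-5 * torch.randn_like(sample)  # simulated kernel-level noise

# Same max-absolute-difference check as the tests: close, but not bit-exact.
assert (sample - new_sample).abs().max() < 5e-4
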