diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index cadf793953c6..87bb9e07dc75 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -677,6 +677,21 @@ def fuse_projections(self, fuse=True):
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                 self.to_kv.bias.copy_(concatenated_bias)
 
+        # handle added projections for SD3 and others.
+        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+            concatenated_weights = torch.cat(
+                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
+            self.to_added_qkv.weight.copy_(concatenated_weights)
+            concatenated_bias = torch.cat(
+                [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+            )
+            self.to_added_qkv.bias.copy_(concatenated_bias)
+
         self.fused_projections = fuse
@@ -1708,6 +1723,109 @@ def __call__(
         return hidden_states
 
 
+class FusedHunyuanAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with
+    fused projection layers. This is used in the HunyuanDiT model. It applies a normalization layer and rotary
+    embedding on the query and key vectors.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        if encoder_hidden_states is None:
+            qkv = attn.to_qkv(hidden_states)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
+        else:
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            query = attn.to_q(hidden_states)
+
+            kv = attn.to_kv(encoder_hidden_states)
+            split_size = kv.shape[-1] // 2
+            key, value = torch.split(kv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class LuminaAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 6538522756f9..161770c67cf8 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -26,6 +26,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
@@ -492,6 +493,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py
index 25eb6384c68c..305401164b2f 100644
--- a/src/diffusers/models/controlnet_sd3.py
+++ b/src/diffusers/models/controlnet_sd3.py
@@ -22,7 +22,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ..models.attention import JointTransformerBlock
-from ..models.attention_processor import Attention, AttentionProcessor
+from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..models.modeling_utils import ModelMixin
 from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -196,7 +196,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -220,6 +220,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedJointAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py
index 354acfebe0a2..0fa21755f09c 100644
--- a/src/diffusers/models/controlnet_xs.py
+++ b/src/diffusers/models/controlnet_xs.py
@@ -29,6 +29,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from .controlnet import ControlNetConditioningEmbedding
 from .embeddings import TimestepEmbedding, Timesteps
@@ -1001,6 +1002,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
index cc0dcbd79e9f..7f3dab220aaa 100644
--- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -20,7 +20,7 @@
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
 from ..embeddings import (
     HunyuanCombinedTimestepTextSizeStyleEmbedding,
     PatchEmbed,
@@ -317,7 +317,7 @@ def __init__(
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -341,6 +341,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedHunyuanAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
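The `fuse_projections` change at the top of this diff is a pure re-parameterization: concatenating the per-projection weight matrices along the output dimension and splitting the fused output recovers the separate Q/K/V results. A minimal, self-contained sketch of that equivalence (the dimensions are illustrative, not the models' real sizes):

```python
import torch
import torch.nn as nn

dim = 64
to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

# Build the fused projection the same way fuse_projections() does:
# weights and biases are concatenated along the output dimension.
to_qkv = nn.Linear(dim, 3 * dim)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias]))

x = torch.randn(2, 16, dim)
query, key, value = torch.split(to_qkv(x), dim, dim=-1)

# One matmul through the fused layer matches three separate projections.
assert torch.allclose(query, to_q(x), atol=1e-6)
assert torch.allclose(key, to_k(x), atol=1e-6)
assert torch.allclose(value, to_v(x), atol=1e-6)
```

The fused processors above (`FusedHunyuanAttnProcessor2_0`, `FusedJointAttnProcessor2_0`, `FusedAttnProcessor2_0`) are the consumers of these fused layers: they read `to_qkv`/`to_kv` (and `to_added_qkv` for SD3's added projections) instead of the individual projections, trading three small GEMMs for one larger one.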
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index cb1165e29b1c..a02c7a471f3a 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -23,7 +23,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor
+from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -211,7 +211,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -235,6 +235,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedJointAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 2b9122799bf3..611ac6087e4a 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -30,6 +30,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..embeddings import (
     GaussianFourierProjection,
@@ -890,6 +891,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index 40b3b92427ce..3081fdc4700c 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -31,6 +31,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
@@ -532,6 +533,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
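At the user level, the wiring added here means one call swaps every attention processor on a model for its fused variant. A hedged usage sketch (requires PyTorch 2.0+; the checkpoint id is illustrative, substitute any SD3 checkpoint you have access to):

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Installs FusedJointAttnProcessor2_0 on every attention module of the transformer.
pipe.transformer.fuse_qkv_projections()
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]

# The pre-fusion processors are kept around and can be restored.
pipe.transformer.unfuse_qkv_projections()
```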
diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index b650f0e21af0..6ab3a577b892 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -29,6 +29,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
@@ -498,6 +499,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 7e5209afba0c..196f947d599b 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -29,6 +29,7 @@
     AttnAddedKVProcessor,
     AttnProcessor,
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
 )
@@ -929,6 +930,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
diff --git a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
index ad5f5f3ef2ca..653cb41e4bc4 100644
--- a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
+++ b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
@@ -36,7 +36,12 @@
 )
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+    to_np,
+)
 
 
 enable_full_determinism()
@@ -261,6 +266,15 @@ def test_fused_qkv_projections(self):
         original_image_slice = image[0, -3:, -3:, -1]
 
         pipe.transformer.fuse_qkv_projections()
+        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+        # to the pipeline level.
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
         inputs = self.get_dummy_inputs(device)
         inputs["return_dict"] = False
         image_fused = pipe(**inputs)[0]
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 93e740145477..75a7d88ea4f2 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -13,7 +13,11 @@
     torch_device,
 )
 
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+)
 
 
 class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -191,7 +195,16 @@ def test_fused_qkv_projections(self):
         image = pipe(**inputs).images
         original_image_slice = image[0, -3:, -3:, -1]
 
+        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+        # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
         image_slice_fused = image[0, -3:, -3:, -1]
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 1898149e43f4..06fcc1c90b71 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -13,6 +13,7 @@
 import numpy as np
 import PIL.Image
 import torch
+import torch.nn as nn
 from huggingface_hub import ModelCard, delete_repo
 from huggingface_hub.utils import is_jinja_available
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -40,7 +41,12 @@
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
-from diffusers.utils.testing_utils import CaptureLogger, require_torch, skip_mps, torch_device
+from diffusers.utils.testing_utils import (
+    CaptureLogger,
+    require_torch,
+    skip_mps,
+    torch_device,
+)
 
 from ..models.autoencoders.test_models_vae import (
     get_asym_autoencoder_kl_config,
@@ -67,6 +73,17 @@ def check_same_shape(tensor_list):
     return all(shape == shapes[0] for shape in shapes[1:])
 
 
+def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors):
+    current_attn_processors = model.attn_processors
+    return len(current_attn_processors) == len(original_attn_processors)
+
+
+def check_qkv_fusion_processors_exist(model):
+    current_attn_processors = model.attn_processors
+    proc_names = [v.__class__.__name__ for _, v in current_attn_processors.items()]
+    return all(p.startswith("Fused") for p in proc_names)
+
+
 class SDFunctionTesterMixin:
     """
     This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
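The two helpers introduced just above encode simple invariants: after fusion, every processor's class name should start with `Fused`, and fusion should neither drop nor duplicate processor entries. A toy illustration of their semantics (`FusedProc`, `PlainProc`, and `FakeModel` are hypothetical stand-ins, not diffusers classes):

```python
class FusedProc:
    pass

class PlainProc:
    pass

class FakeModel:
    def __init__(self, procs):
        self._procs = procs

    @property
    def attn_processors(self):
        return self._procs

fused = FakeModel({"down.attn1.processor": FusedProc(), "up.attn1.processor": FusedProc()})
mixed = FakeModel({"down.attn1.processor": FusedProc(), "up.attn1.processor": PlainProc()})

assert check_qkv_fusion_processors_exist(fused)      # every class name starts with "Fused"
assert not check_qkv_fusion_processors_exist(mixed)  # one unfused processor fails the check
# The length check guards against fusion dropping or duplicating processor entries.
assert check_qkv_fusion_matches_attn_procs_length(fused, fused.attn_processors)
```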
@@ -196,6 +213,19 @@ def test_fused_qkv_projections(self):
         original_image_slice = image[0, -3:, -3:, -1]
 
         pipe.fuse_qkv_projections()
+        for _, component in pipe.components.items():
+            if (
+                isinstance(component, nn.Module)
+                and hasattr(component, "original_attn_processors")
+                and component.original_attn_processors is not None
+            ):
+                assert check_qkv_fusion_processors_exist(
+                    component
+                ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+                assert check_qkv_fusion_matches_attn_procs_length(
+                    component, component.original_attn_processors
+                ), "Something wrong with the attention processors concerning the fused QKV projections."
+
         inputs = self.get_dummy_inputs(device)
         inputs["return_dict"] = False
         image_fused = pipe(**inputs)[0]
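Taken together, these changes let the shared tester walk a pipeline's components and verify each fusable module. A hedged end-to-end sketch at the model level, using the small dummy UNet configuration common in the diffusers test suite (requires PyTorch 2.0+ for the fused processor):

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

unet.fuse_qkv_projections()  # fuses the projections and installs FusedAttnProcessor2_0
assert all(proc.__class__.__name__.startswith("Fused") for proc in unet.attn_processors.values())

# A forward pass through the fused model; shapes match the dummy config above.
out = unet(
    torch.randn(1, 4, 32, 32),
    timestep=torch.tensor([1]),
    encoder_hidden_states=torch.randn(1, 8, 32),
).sample

unet.unfuse_qkv_projections()  # restores the processors captured in original_attn_processors
```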