diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 784eaaa62c55..01432ca37c21 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -227,6 +227,7 @@ def __init__(
self.to_k = None
self.to_v = None
+ self.added_proj_bias = added_proj_bias
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
@@ -698,12 +699,15 @@ def fuse_projections(self, fuse=True):
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
- self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
- self.to_added_qkv.weight.copy_(concatenated_weights)
- concatenated_bias = torch.cat(
- [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
- self.to_added_qkv.bias.copy_(concatenated_bias)
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
self.fused_projections = fuse
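
# A minimal standalone sketch of why the fused added-KV projection must respect `added_proj_bias`:
# with bias=False the per-projection `.bias` attributes are None, so only the weights can be
# concatenated. Plain PyTorch; the layer names mirror the `Attention` attributes, but nothing here
# touches the real class.
import torch
import torch.nn as nn

features, added_proj_bias = 8, False
add_q_proj = nn.Linear(features, features, bias=added_proj_bias)
add_k_proj = nn.Linear(features, features, bias=added_proj_bias)
add_v_proj = nn.Linear(features, features, bias=added_proj_bias)

concatenated_weights = torch.cat([add_q_proj.weight.data, add_k_proj.weight.data, add_v_proj.weight.data])
to_added_qkv = nn.Linear(features, 3 * features, bias=added_proj_bias)
with torch.no_grad():
    to_added_qkv.weight.copy_(concatenated_weights)
    if added_proj_bias:  # only concatenate biases when the source projections actually have them
        to_added_qkv.bias.copy_(
            torch.cat([add_q_proj.bias.data, add_k_proj.bias.data, add_v_proj.bias.data])
        )

# The fused projection reproduces the three separate projections.
x = torch.randn(2, 4, features)
q, k, v = torch.split(to_added_qkv(x), features, dim=-1)
assert torch.allclose(q, add_q_proj(x), atol=1e-6)
assert torch.allclose(k, add_k_proj(x), atol=1e-6)
assert torch.allclose(v, add_v_proj(x), atol=1e-6)
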
@@ -1274,6 +1278,103 @@ def __call__(
return hidden_states
+class FusedAuraFlowAttnProcessor2_0:
+ """Attention processor used typically in processing Aura Flow with fused projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+ raise ImportError(
+ "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ # Reshape.
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # Apply QK norm.
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Concatenate the projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # Attention.
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
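
# A rough sketch of the joint-attention layout implemented by both the fused and unfused AuraFlow
# processors: context (encoder) tokens are prepended to the sample tokens before
# `scaled_dot_product_attention` and split off again afterwards. All shapes here are made up.
import torch
import torch.nn.functional as F

batch, ctx_len, img_len, heads, head_dim = 2, 7, 16, 4, 8

def to_heads(t):
    return t.view(batch, -1, heads, head_dim).transpose(1, 2)

sample = torch.randn(batch, img_len, heads * head_dim)
context = torch.randn(batch, ctx_len, heads * head_dim)

# Stand-ins for the projected q/k/v of the concatenated sequence.
q = to_heads(torch.cat([context, sample], dim=1))
k, v = q.clone(), q.clone()

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

# Same split as in `__call__` above: context stream first, sample stream after it.
sample_out, context_out = out[:, ctx_len:], out[:, :ctx_len]
assert sample_out.shape == sample.shape and context_out.shape == context.shape
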
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index 89d51969aeaa..f685e690cf81 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -22,7 +22,12 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention_processor import Attention, AuraFlowAttnProcessor2_0
+from ..attention_processor import (
+ Attention,
+ AttentionProcessor,
+ AuraFlowAttnProcessor2_0,
+ FusedAuraFlowAttnProcessor2_0,
+)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -320,6 +325,106 @@ def __init__(
self.gradient_checkpointing = False
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
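
# A usage sketch for the new model-level API, assuming the "fal/AuraFlow" checkpoint id; fusion is
# an inference-time optimization, so outputs should match the unfused model within numerical
# tolerance.
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.transformer.fuse_qkv_projections()
image = pipe("a photo of a corgi astronaut", num_inference_steps=25).images[0]
pipe.transformer.unfuse_qkv_projections()  # restores the original (unfused) attention processors
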
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index b1ab0ad2b657..5c9c61243c07 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -19,7 +19,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ..attention import BasicTransformerBlock
-from ..attention_processor import AttentionProcessor
+from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -247,6 +247,46 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
def forward(
self,
hidden_states: torch.Tensor,
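
# The same API on PixArt, sketched under the assumption of the "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
# checkpoint id. Because the PixArt transformer has no added KV projections, it reuses the generic
# `FusedAttnProcessor2_0` rather than a model-specific fused processor.
import torch
from diffusers import PixArtSigmaPipeline

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")
pipe.transformer.fuse_qkv_projections()
image = pipe("a small cactus wearing a straw hat", num_inference_steps=20).images[0]
pipe.transformer.unfuse_qkv_projections()
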
diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
index 9a2f1846f1d9..3694a733163c 100644
--- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
+++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
@@ -9,7 +9,11 @@
torch_device,
)
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fusion_matches_attn_procs_length,
+ check_qkv_fusion_processors_exist,
+)
class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -119,3 +123,43 @@ def test_attention_slicing_forward_pass(self):
# Attention slicing needs to be implemented differently for this because of how single DiT and MMDiT
# blocks interfere with each other.
return
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py
index a4cc60d12588..a92e99366ee3 100644
--- a/tests/pipelines/pixart_sigma/test_pixart.py
+++ b/tests/pipelines/pixart_sigma/test_pixart.py
@@ -36,7 +36,12 @@
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fusion_matches_attn_procs_length,
+ check_qkv_fusion_processors_exist,
+ to_np,
+)
enable_full_determinism()
@@ -308,6 +313,46 @@ def test_inference_with_multiple_images_per_prompt(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
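
# The fast tests above only assert numerical equivalence; a rough, hardware-dependent way to check the
# latency benefit of fusion could look like the following (checkpoint id assumed, numbers vary by GPU).
import time
import torch
from diffusers import PixArtSigmaPipeline

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

def timed(steps=20):
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipe("a small cactus wearing a straw hat", num_inference_steps=steps)
    torch.cuda.synchronize()
    return time.perf_counter() - start

baseline = timed()
pipe.transformer.fuse_qkv_projections()
fused = timed()
print(f"unfused: {baseline:.2f}s  fused: {fused:.2f}s")
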
@slow
@require_torch_gpu