
[Core] fix QKV fusion for attention #8829

Merged: 19 commits, Jul 24, 2024 (diff shown from 13 commits)
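
A quick orientation for reviewers: QKV fusion packs the separate query/key/value projections into a single `nn.Linear` so that three matmuls collapse into one. A minimal standalone sketch of the equivalence the fusion relies on (illustrative dimensions; not code from this PR):

```python
import torch
from torch import nn

# Mirror what Attention.fuse_projections does for self-attention: concatenate
# the Q/K/V weights row-wise into one Linear, run a single matmul, then split
# the output back into Q, K, V.
hidden_dim = 64
x = torch.randn(2, 16, hidden_dim)  # (batch, seq_len, dim)

to_q, to_k, to_v = (nn.Linear(hidden_dim, hidden_dim) for _ in range(3))

fused = nn.Linear(hidden_dim, 3 * hidden_dim)
with torch.no_grad():
    fused.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
    fused.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias]))

q, k, v = torch.split(fused(x), hidden_dim, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-5)
assert torch.allclose(k, to_k(x), atol=1e-5)
assert torch.allclose(v, to_v(x), atol=1e-5)
```
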
118 changes: 118 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -677,6 +677,21 @@ def fuse_projections(self, fuse=True):
             concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
             self.to_kv.bias.copy_(concatenated_bias)
 
+        # handle added projections for SD3 and others.
+        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+            concatenated_weights = torch.cat(
+                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
+            self.to_added_qkv.weight.copy_(concatenated_weights)
+            concatenated_bias = torch.cat(
+                [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+            )
+            self.to_added_qkv.bias.copy_(concatenated_bias)
 
         self.fused_projections = fuse


@@ -1708,6 +1723,109 @@ def __call__(
         return hidden_states


+class FusedHunyuanAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
+    projection layers. This is used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on
+    the query and key vectors.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        if encoder_hidden_states is None:
+            # self-attention: a single fused projection yields Q, K, and V
+            qkv = attn.to_qkv(hidden_states)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
+        else:
+            # cross-attention: Q from hidden_states, fused K/V from the encoder states
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            query = attn.to_q(hidden_states)
+
+            kv = attn.to_kv(encoder_hidden_states)
+            split_size = kv.shape[-1] // 2
+            key, value = torch.split(kv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states


 class LuminaAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
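
The `to_added_qkv` layer created in the first hunk is consumed on the context (encoder) stream by the joint attention processors. A hedged sketch of that split, with assumed SD3-like shapes (names and dimensions are illustrative, not taken from this diff):

```python
import torch
from torch import nn

# Assumed shapes for illustration: SD3 uses joint attention where the encoder
# ("added") stream gets its own fused projection; the output of one matmul is
# split back into the three context projections.
batch, seq_len, dim = 2, 77, 1536
encoder_hidden_states = torch.randn(batch, seq_len, dim)
to_added_qkv = nn.Linear(dim, 3 * dim)  # stands in for the fused layer above

qkv = to_added_qkv(encoder_hidden_states)
split_size = qkv.shape[-1] // 3
add_q, add_k, add_v = torch.split(qkv, split_size, dim=-1)
print(add_q.shape)  # torch.Size([2, 77, 1536])
```
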
14 changes: 11 additions & 3 deletions src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -26,6 +26,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
@@ -482,9 +483,16 @@ def fuse_qkv_projections(self):

         self.original_attn_processors = self.attn_processors
 
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
+        def fuse_recursively(module):
+            for submodule in module.children():
+                if isinstance(submodule, Attention):
+                    submodule.fuse_projections(fuse=True)
+                # Recursively call this function on the submodule to handle nesting
+                fuse_recursively(submodule)
+
+        fuse_recursively(self)
+
+        self.set_attn_processor(FusedAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
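
A hypothetical usage sketch for the VAE path (the checkpoint id is an assumption). Judging from this diff, the fix is that fusing now also installs `FusedAttnProcessor2_0`, so the fused `to_qkv` weights are actually exercised at inference:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # assumed checkpoint
vae.fuse_qkv_projections()  # fuses projections and sets FusedAttnProcessor2_0

sample = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    reconstruction = vae(sample).sample
print(reconstruction.shape)  # torch.Size([1, 3, 256, 256])
```
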
17 changes: 12 additions & 5 deletions src/diffusers/models/controlnet_sd3.py
@@ -22,7 +22,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ..models.attention import JointTransformerBlock
-from ..models.attention_processor import Attention, AttentionProcessor
+from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..models.modeling_utils import ModelMixin
 from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -196,7 +196,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -216,9 +216,16 @@ def fuse_qkv_projections(self):

         self.original_attn_processors = self.attn_processors
 
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
+        def fuse_recursively(module):
+            for submodule in module.children():
+                if isinstance(submodule, Attention):
+                    submodule.fuse_projections(fuse=True)
+                # Recursively call this function on the submodule to handle nesting
+                fuse_recursively(submodule)
+
+        fuse_recursively(self)
+
+        self.set_attn_processor(FusedJointAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
14 changes: 11 additions & 3 deletions src/diffusers/models/controlnet_xs.py
@@ -29,6 +29,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from .controlnet import ControlNetConditioningEmbedding
 from .embeddings import TimestepEmbedding, Timesteps
@@ -997,9 +998,16 @@ def fuse_qkv_projections(self):

         self.original_attn_processors = self.attn_processors
 
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
+        def fuse_recursively(module):
+            for submodule in module.children():
+                if isinstance(submodule, Attention):
+                    submodule.fuse_projections(fuse=True)
+                # Recursively call this function on the submodule to handle nesting
+                fuse_recursively(submodule)
+
+        fuse_recursively(self)
+
+        self.set_attn_processor(FusedAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
17 changes: 12 additions & 5 deletions src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -20,7 +20,7 @@
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
 from ..embeddings import (
     HunyuanCombinedTimestepTextSizeStyleEmbedding,
     PatchEmbed,
@@ -317,7 +317,7 @@ def __init__(
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -337,9 +337,16 @@ def fuse_qkv_projections(self):

         self.original_attn_processors = self.attn_processors
 
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
+        def fuse_recursively(module):
+            for submodule in module.children():
+                if isinstance(submodule, Attention):
+                    submodule.fuse_projections(fuse=True)
+                # Recursively call this function on the submodule to handle nesting
+                fuse_recursively(submodule)
+
+        fuse_recursively(self)
+
+        self.set_attn_processor(FusedHunyuanAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
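
A hypothetical end-to-end sketch for HunyuanDiT (model id and hardware noted as assumptions in comments); after this change, `fuse_qkv_projections` swaps in the `FusedHunyuanAttnProcessor2_0` added above:

```python
import torch
from diffusers import HunyuanDiTPipeline

# Model id is an assumption for illustration; fp16 + CUDA keeps memory manageable.
pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")

pipe.transformer.fuse_qkv_projections()  # installs FusedHunyuanAttnProcessor2_0
image = pipe("a small cactus with a happy face in the Sahara desert").images[0]
```
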
17 changes: 12 additions & 5 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -23,7 +23,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor
+from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -211,7 +211,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -231,9 +231,16 @@ def fuse_qkv_projections(self):

         self.original_attn_processors = self.attn_processors
 
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
+        def fuse_recursively(module):
+            for submodule in module.children():
+                if isinstance(submodule, Attention):
+                    submodule.fuse_projections(fuse=True)
+                # Recursively call this function on the submodule to handle nesting
+                fuse_recursively(submodule)
+
+        fuse_recursively(self)
+
+        self.set_attn_processor(FusedJointAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
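
Likewise for SD3, a hedged usage sketch (model id is an assumption; the official weights are gated). Fusion now covers both `to_qkv` and the new `to_added_qkv`, and installs `FusedJointAttnProcessor2_0`:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Model id is an assumption for illustration.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

pipe.transformer.fuse_qkv_projections()  # fuses to_qkv and to_added_qkv
image = pipe("a photo of a cat holding a sign", num_inference_steps=28).images[0]
```
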
14 changes: 11 additions & 3 deletions src/diffusers/models/unets/unet_2d_condition.py
@@ -30,6 +30,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..embeddings import (
     GaussianFourierProjection,
@@ -886,9 +887,16 @@ def fuse_qkv_projections(self):

         self.original_attn_processors = self.attn_processors
 
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
+        def fuse_recursively(module):
+            for submodule in module.children():
+                if isinstance(submodule, Attention):
+                    submodule.fuse_projections(fuse=True)
+                # Recursively call this function on the submodule to handle nesting
+                fuse_recursively(submodule)
+
+        fuse_recursively(self)
+
+        self.set_attn_processor(FusedAttnProcessor2_0())
 
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
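
Because `original_attn_processors` is stashed before fusing, the operation stays reversible. A round-trip sketch (checkpoint id is an assumption):

```python
from diffusers import UNet2DConditionModel

# Checkpoint id is an assumption for illustration.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

unet.fuse_qkv_projections()    # concatenates Q/K/V weights, installs FusedAttnProcessor2_0
# ... run inference with fused projections ...
unet.unfuse_qkv_projections()  # restores the saved original_attn_processors
```
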
14 changes: 11 additions & 3 deletions src/diffusers/models/unets/unet_3d_condition.py
@@ -31,6 +31,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
@@ -528,9 +529,16 @@ def fuse_qkv_projections(self):

         self.original_attn_processors = self.attn_processors
 
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
+        def fuse_recursively(module):
+            for submodule in module.children():
+                if isinstance(submodule, Attention):
+                    submodule.fuse_projections(fuse=True)
+                # Recursively call this function on the submodule to handle nesting
+                fuse_recursively(submodule)
+
+        fuse_recursively(self)
+
+        self.set_attn_processor(FusedAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
14 changes: 11 additions & 3 deletions src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -29,6 +29,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
@@ -494,9 +495,16 @@ def fuse_qkv_projections(self):

         self.original_attn_processors = self.attn_processors
 
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
+        def fuse_recursively(module):
+            for submodule in module.children():
+                if isinstance(submodule, Attention):
+                    submodule.fuse_projections(fuse=True)
+                # Recursively call this function on the submodule to handle nesting
+                fuse_recursively(submodule)
+
+        fuse_recursively(self)
+
+        self.set_attn_processor(FusedAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):