add only cross attention to simple attention blocks #3011

Merged
50 changes: 35 additions & 15 deletions src/diffusers/models/attention_processor.py
@@ -61,6 +61,7 @@ def __init__(
norm_num_groups: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
processor: Optional["AttnProcessor"] = None,
):
super().__init__()
@@ -79,6 +80,12 @@ def __init__(
self.sliceable_head_dim = heads

self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention

if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)

if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
@@ -89,8 +96,14 @@ def __init__(
self.norm_cross = nn.LayerNorm(cross_attention_dim)

self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

if not self.only_cross_attention:
# only relevant for the `AttnAddedKVProcessor` classes
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
else:
self.to_k = None
self.to_v = None

if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
@@ -408,18 +421,21 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)

key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj

attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
@@ -637,18 +653,22 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)

key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj

batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
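For readers skimming the diff, here is a minimal standalone sketch (not part of the PR) of the key/value construction that `AttnAddedKVProcessor` performs in the two modes. The tensor sizes (batch of 2, 5 image tokens, 4 text tokens, inner dimension 8) are made up for illustration; only the concatenate-vs-projection-only branching mirrors the change above.

import torch

batch, image_tokens, text_tokens, inner_dim = 2, 5, 4, 8
only_cross_attention = True  # the new flag

# stand-ins for attn.to_k(hidden_states) / attn.to_v(hidden_states) (self-attention path)
key_self = torch.rand(batch, image_tokens, inner_dim)
value_self = torch.rand(batch, image_tokens, inner_dim)

# stand-ins for attn.add_k_proj(encoder_hidden_states) / attn.add_v_proj(encoder_hidden_states)
key_cross = torch.rand(batch, text_tokens, inner_dim)
value_cross = torch.rand(batch, text_tokens, inner_dim)

if not only_cross_attention:
    # default: attention runs over the text tokens *and* the image tokens
    key = torch.cat([key_cross, key_self], dim=1)    # (batch, text_tokens + image_tokens, inner_dim)
    value = torch.cat([value_cross, value_self], dim=1)
else:
    # new behavior: attention runs over the text tokens only; to_k / to_v are never created
    key = key_cross                                  # (batch, text_tokens, inner_dim)
    value = value_cross

print(key.shape, value.shape)

This is also why the constructor now raises when `only_cross_attention=True` but `added_kv_proj_dim` is `None`: without the added projections there would be nothing left to attend over.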
8 changes: 8 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
@@ -125,6 +125,7 @@ def get_down_block(
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
only_cross_attention=only_cross_attention,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -291,6 +292,7 @@ def get_up_block(
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
only_cross_attention=only_cross_attention,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -575,6 +577,7 @@ def __init__(
output_scale_factor=1.0,
cross_attention_dim=1280,
skip_time_act=False,
only_cross_attention=False,
):
super().__init__()

@@ -614,6 +617,7 @@ def __init__(
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
processor=AttnAddedKVProcessor(),
)
)
@@ -1356,6 +1360,7 @@ def __init__(
output_scale_factor=1.0,
add_downsample=True,
skip_time_act=False,
only_cross_attention=False,
):
super().__init__()

@@ -1394,6 +1399,7 @@ def __init__(
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
processor=AttnAddedKVProcessor(),
)
)
@@ -2354,6 +2360,7 @@ def __init__(
output_scale_factor=1.0,
add_upsample=True,
skip_time_act=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -2393,6 +2400,7 @@ def __init__(
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
processor=AttnAddedKVProcessor(),
)
)
15 changes: 14 additions & 1 deletion src/diffusers/models/unet_2d_condition.py
@@ -110,7 +110,12 @@ class conditioning with `class_embed_type` equal to `None`.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
Whether the mid block should use only cross attention when `UNetMidBlock2DSimpleCrossAttn` is used as the
mid block. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is
`None`, the `only_cross_attention` value is used for the mid block as well. Otherwise, it defaults to
`False`.
"""

_supports_gradient_checkpointing = True
@@ -158,6 +163,7 @@ def __init__(
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
):
super().__init__()

@@ -265,8 +271,14 @@ def __init__(
self.up_blocks = nn.ModuleList([])

if isinstance(only_cross_attention, bool):
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = only_cross_attention

only_cross_attention = [only_cross_attention] * len(down_block_types)

if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False

if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -342,6 +354,7 @@ def __init__(
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
only_cross_attention=mid_block_only_cross_attention,
Contributor Author:
@patrickvonplaten you had used only_cross_attention[-1] for this argument, but I think it makes sense to use a separate config for the mid block. The other mid block arguments that re-use the config of the last encoder block make sense because they are dimensionality based and have to match. But this constraint doesn't necessarily hold for the only_cross_attention flag.

Contributor Author:
Actually, with commit 7329ead, we can do better: default to the value of only_cross_attention when it's given as a single boolean.

Contributor:
Ok for me

)
elif mid_block_type is None:
self.mid_block = None
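To make the default described in the docstring easier to follow, here is an illustrative sketch of the resolution order (the helper name resolve_mid_block_only_cross_attention is hypothetical; the logic mirrors the isinstance checks added to UNet2DConditionModel.__init__ above). The same handling is mirrored in the Flat UNet counterpart whose diff follows.

from typing import List, Optional, Union

def resolve_mid_block_only_cross_attention(
    only_cross_attention: Union[bool, List[bool]],
    mid_block_only_cross_attention: Optional[bool],
) -> bool:
    # an explicitly passed value always wins
    if mid_block_only_cross_attention is not None:
        return mid_block_only_cross_attention
    # a single boolean for the down blocks is inherited by the mid block
    if isinstance(only_cross_attention, bool):
        return only_cross_attention
    # a per-block list gives no obvious choice for the mid block, so fall back to False
    return False

assert resolve_mid_block_only_cross_attention(True, None) is True
assert resolve_mid_block_only_cross_attention([True, False], None) is False
assert resolve_mid_block_only_cross_attention([True, False], True) is True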
@@ -191,7 +191,12 @@ class conditioning with `class_embed_type` equal to `None`.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
Whether the mid block should use only cross attention when `UNetMidBlockFlatSimpleCrossAttn` is used as
the mid block. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention`
is `None`, the `only_cross_attention` value is used for the mid block as well. Otherwise, it defaults to
`False`.
"""

_supports_gradient_checkpointing = True
Expand Down Expand Up @@ -244,6 +249,7 @@ def __init__(
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
):
super().__init__()

@@ -357,8 +363,14 @@ def __init__(
self.up_blocks = nn.ModuleList([])

if isinstance(only_cross_attention, bool):
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = only_cross_attention

only_cross_attention = [only_cross_attention] * len(down_block_types)

if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False

if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -434,6 +446,7 @@ def __init__(
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
only_cross_attention=mid_block_only_cross_attention,
)
elif mid_block_type is None:
self.mid_block = None
@@ -1476,6 +1489,7 @@ def __init__(
output_scale_factor=1.0,
cross_attention_dim=1280,
skip_time_act=False,
only_cross_attention=False,
):
super().__init__()

@@ -1515,6 +1529,7 @@ def __init__(
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
processor=AttnAddedKVProcessor(),
)
)
75 changes: 75 additions & 0 deletions tests/models/test_attention_processor.py
@@ -0,0 +1,75 @@
import unittest

import torch

from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor


class AttnAddedKVProcessorTests(unittest.TestCase):
def get_constructor_arguments(self, only_cross_attention: bool = False):
query_dim = 10

if only_cross_attention:
cross_attention_dim = 12
else:
# when only cross attention is not set, the cross attention dim must be the same as the query dim
cross_attention_dim = query_dim

return {
"query_dim": query_dim,
"cross_attention_dim": cross_attention_dim,
"heads": 2,
"dim_head": 4,
"added_kv_proj_dim": 6,
"norm_num_groups": 1,
"only_cross_attention": only_cross_attention,
"processor": AttnAddedKVProcessor(),
}

def get_forward_arguments(self, query_dim, added_kv_proj_dim):
batch_size = 2

hidden_states = torch.rand(batch_size, query_dim, 3, 2)
encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
attention_mask = None

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"attention_mask": attention_mask,
}

def test_only_cross_attention(self):
# self and cross attention

torch.manual_seed(0)

constructor_args = self.get_constructor_arguments(only_cross_attention=False)
attn = Attention(**constructor_args)

self.assertTrue(attn.to_k is not None)
self.assertTrue(attn.to_v is not None)

forward_args = self.get_forward_arguments(
query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
)

self_and_cross_attn_out = attn(**forward_args)

# only cross attention

torch.manual_seed(0)

constructor_args = self.get_constructor_arguments(only_cross_attention=True)
attn = Attention(**constructor_args)

self.assertTrue(attn.to_k is None)
self.assertTrue(attn.to_v is None)

forward_args = self.get_forward_arguments(
query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
)

only_cross_attn_out = attn(**forward_args)

self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
Member:
Sufficient for merging IMO.
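One practical consequence worth noting, shown here as a small sketch rather than PR code (constructor arguments copied from the test above): with only_cross_attention=True the to_k/to_v projections are never created, so their weights also never appear in the module's state dict, and checkpoints for such blocks simply omit them.

from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

attn = Attention(
    query_dim=10,
    cross_attention_dim=12,
    heads=2,
    dim_head=4,
    added_kv_proj_dim=6,
    norm_num_groups=1,
    only_cross_attention=True,
    processor=AttnAddedKVProcessor(),
)

state_dict = attn.state_dict()
# no self-attention key/value weights ...
assert not any(name.startswith(("to_k", "to_v")) for name in state_dict)
# ... but the added key/value projections for the conditioning are present
assert any(name.startswith("add_k_proj") for name in state_dict)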