diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 04ead2adcf6e..864b042c245a 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -61,6 +61,7 @@ def __init__(
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
         scale_qk: bool = True,
+        only_cross_attention: bool = False,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -79,6 +80,12 @@ def __init__(
         self.sliceable_head_dim = heads
 
         self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
 
         if norm_num_groups is not None:
             self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
@@ -89,8 +96,14 @@ def __init__(
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
-        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
-        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
 
         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
@@ -408,18 +421,21 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         query = attn.to_q(hidden_states)
         query = attn.head_to_batch_dim(query)
 
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
 
         attention_probs = attn.get_attention_scores(query, key, attention_mask)
         hidden_states = torch.bmm(attention_probs, value)
@@ -637,18 +653,22 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
 
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
 
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
 
         batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 0aeca6f508d0..540059b10713 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -125,6 +125,7 @@ def get_down_block(
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
+            only_cross_attention=only_cross_attention,
        )
    elif down_block_type == "SkipDownBlock2D":
        return SkipDownBlock2D(
@@ -291,6 +292,7 @@ def get_up_block(
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
+            only_cross_attention=only_cross_attention,
        )
    elif up_block_type == "AttnUpBlock2D":
        return AttnUpBlock2D(
@@ -575,6 +577,7 @@ def __init__(
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        skip_time_act=False,
+        only_cross_attention=False,
    ):
        super().__init__()
 
@@ -614,6 +617,7 @@ def __init__(
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                    processor=AttnAddedKVProcessor(),
                )
            )
@@ -1356,6 +1360,7 @@ def __init__(
        output_scale_factor=1.0,
        add_downsample=True,
        skip_time_act=False,
+        only_cross_attention=False,
    ):
        super().__init__()
 
@@ -1394,6 +1399,7 @@ def __init__(
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                    processor=AttnAddedKVProcessor(),
                )
            )
@@ -2354,6 +2360,7 @@ def __init__(
        output_scale_factor=1.0,
        add_upsample=True,
        skip_time_act=False,
+        only_cross_attention=False,
    ):
        super().__init__()
        resnets = []
@@ -2393,6 +2400,7 @@ def __init__(
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                    processor=AttnAddedKVProcessor(),
                )
            )
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 3610231d19e6..3fb4202ed119 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -110,7 +110,12 @@ class conditioning with `class_embed_type` equal to `None`.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
-            embeddings with the class embeddings.
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether the mid block should use only cross attention when `UNetMidBlock2DSimpleCrossAttn` is used. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used for `mid_block_only_cross_attention`. Otherwise, it defaults to
+            `False`.
    """
 
    _supports_gradient_checkpointing = True
@@ -158,6 +163,7 @@ def __init__(
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
    ):
        super().__init__()
 
@@ -265,8 +271,14 @@ def __init__(
        self.up_blocks = nn.ModuleList([])
 
        if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
            only_cross_attention = [only_cross_attention] * len(down_block_types)
 
+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
@@ -342,6 +354,7 @@ def __init__(
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
            )
        elif mid_block_type is None:
            self.mid_block = None
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 6a3635613104..51d1c62c926b 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -191,7 +191,12 @@ class conditioning with `class_embed_type` equal to `None`.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
-            embeddings with the class embeddings.
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether the mid block should use only cross attention when `UNetMidBlockFlatSimpleCrossAttn` is used. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used for `mid_block_only_cross_attention`. Otherwise, it defaults to
+            `False`.
    """
 
    _supports_gradient_checkpointing = True
@@ -244,6 +249,7 @@ def __init__(
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
    ):
        super().__init__()
 
@@ -357,8 +363,14 @@ def __init__(
        self.up_blocks = nn.ModuleList([])
 
        if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
            only_cross_attention = [only_cross_attention] * len(down_block_types)
 
+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
@@ -434,6 +446,7 @@ def __init__(
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
            )
        elif mid_block_type is None:
            self.mid_block = None
@@ -1476,6 +1489,7 @@ def __init__(
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        skip_time_act=False,
+        only_cross_attention=False,
    ):
        super().__init__()
 
@@ -1515,6 +1529,7 @@ def __init__(
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                    processor=AttnAddedKVProcessor(),
                )
            )
diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py
new file mode 100644
index 000000000000..172d6d4d91fc
--- /dev/null
+++ b/tests/models/test_attention_processor.py
@@ -0,0 +1,75 @@
+import unittest
+
+import torch
+
+from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
+
+
+class AttnAddedKVProcessorTests(unittest.TestCase):
+    def get_constructor_arguments(self, only_cross_attention: bool = False):
+        query_dim = 10
+
+        if only_cross_attention:
+            cross_attention_dim = 12
+        else:
+            # when only cross attention is not set, the cross attention dim must be the same as the query dim
+            cross_attention_dim = query_dim
+
+        return {
+            "query_dim": query_dim,
+            "cross_attention_dim": cross_attention_dim,
+            "heads": 2,
+            "dim_head": 4,
+            "added_kv_proj_dim": 6,
+            "norm_num_groups": 1,
+            "only_cross_attention": only_cross_attention,
+            "processor": AttnAddedKVProcessor(),
+        }
+
+    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
+        batch_size = 2
+
+        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
+        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
+        attention_mask = None
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "attention_mask": attention_mask,
+        }
+
+    def test_only_cross_attention(self):
+        # self and cross attention
+
+        torch.manual_seed(0)
+
+        constructor_args = self.get_constructor_arguments(only_cross_attention=False)
+        attn = Attention(**constructor_args)
+
+        self.assertTrue(attn.to_k is not None)
+        self.assertTrue(attn.to_v is not None)
+
+        forward_args = self.get_forward_arguments(
+            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
+        )
+
+        self_and_cross_attn_out = attn(**forward_args)
+
+        # only cross attention
+
+        torch.manual_seed(0)
+
+        constructor_args = self.get_constructor_arguments(only_cross_attention=True)
+        attn = Attention(**constructor_args)
+
+        self.assertTrue(attn.to_k is None)
+        self.assertTrue(attn.to_v is None)
+
+        forward_args = self.get_forward_arguments(
+            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
+        )
+
+        only_cross_attn_out = attn(**forward_args)
+
+        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
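
A minimal usage sketch of the new flag, mirroring `test_only_cross_attention` above. The dimensions are the test's illustrative values, not requirements; with `only_cross_attention=True`, `to_k`/`to_v` are never created and keys/values come solely from the added KV projections of `encoder_hidden_states`.

import torch

from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

attn = Attention(
    query_dim=10,
    cross_attention_dim=12,
    heads=2,
    dim_head=4,
    added_kv_proj_dim=6,
    norm_num_groups=1,
    only_cross_attention=True,
    processor=AttnAddedKVProcessor(),
)

# the self-attention projections are skipped entirely in this configuration
assert attn.to_k is None and attn.to_v is None

hidden_states = torch.rand(2, 10, 3, 2)      # (batch, query_dim channels, height, width)
encoder_hidden_states = torch.rand(2, 4, 6)  # (batch, tokens, added_kv_proj_dim)

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
assert out.shape == hidden_states.shape

# requesting only_cross_attention=True without added_kv_proj_dim raises the new ValueError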
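The `mid_block_only_cross_attention` resolution added to `UNet2DConditionModel.__init__` (and mirrored in the Flat variant) can be read as the standalone sketch below. The helper name is hypothetical and only restates the diff's logic: a single `only_cross_attention` bool seeds the mid block default, while a per-block list leaves the mid block to fall back to `False`.

from typing import List, Optional, Tuple, Union


def resolve_only_cross_attention(  # hypothetical helper, not part of diffusers
    only_cross_attention: Union[bool, List[bool]],
    mid_block_only_cross_attention: Optional[bool],
    num_down_blocks: int,
) -> Tuple[List[bool], bool]:
    if isinstance(only_cross_attention, bool):
        # a single bool is broadcast to every down block and, if the mid block
        # value was left unset, also seeds `mid_block_only_cross_attention`
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = only_cross_attention
        only_cross_attention = [only_cross_attention] * num_down_blocks

    # with a per-block list, an unset mid block value falls back to False
    if mid_block_only_cross_attention is None:
        mid_block_only_cross_attention = False

    return only_cross_attention, mid_block_only_cross_attention


assert resolve_only_cross_attention(True, None, 4) == ([True] * 4, True)
assert resolve_only_cross_attention([True, False], None, 2) == ([True, False], False)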