Skip to content

Commit c2fbe8e

Browse files
committed
add only cross attention to simple attention blocks
1 parent: 8c6b47c · commit: c2fbe8e

File tree

4 files changed

+46
-15
lines changed

4 files changed

+46
-15
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
norm_num_groups: Optional[int] = None,
6262
out_bias: bool = True,
6363
scale_qk: bool = True,
64+
only_cross_attention: bool = False,
6465
processor: Optional["AttnProcessor"] = None,
6566
):
6667
super().__init__()
@@ -79,6 +80,12 @@ def __init__(
7980
self.sliceable_head_dim = heads
8081

8182
self.added_kv_proj_dim = added_kv_proj_dim
83+
self.only_cross_attention = only_cross_attention
84+
85+
if self.added_kv_proj_dim is None and self.only_cross_attention:
86+
raise ValueError(
87+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
88+
)
8289

8390
if norm_num_groups is not None:
8491
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
@@ -89,8 +96,11 @@ def __init__(
8996
self.norm_cross = nn.LayerNorm(cross_attention_dim)
9097

9198
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
92-
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
93-
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
99+
100+
if not self.only_cross_attention:
101+
# only relevant for the `AddedKVProcessor` classes
102+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
103+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
94104

95105
if self.added_kv_proj_dim is not None:
96106
self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
@@ -408,18 +418,21 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
408418
query = attn.to_q(hidden_states)
409419
query = attn.head_to_batch_dim(query)
410420

411-
key = attn.to_k(hidden_states)
412-
value = attn.to_v(hidden_states)
413-
key = attn.head_to_batch_dim(key)
414-
value = attn.head_to_batch_dim(value)
415-
416421
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
417422
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
418423
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
419424
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
420425

421-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
422-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
426+
if not attn.only_cross_attention:
427+
key = attn.to_k(hidden_states)
428+
value = attn.to_v(hidden_states)
429+
key = attn.head_to_batch_dim(key)
430+
value = attn.head_to_batch_dim(value)
431+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
432+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
433+
else:
434+
key = encoder_hidden_states_key_proj
435+
value = encoder_hidden_states_value_proj
423436

424437
attention_probs = attn.get_attention_scores(query, key, attention_mask)
425438
hidden_states = torch.bmm(attention_probs, value)
@@ -637,18 +650,22 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
637650
dim = query.shape[-1]
638651
query = attn.head_to_batch_dim(query)
639652

640-
key = attn.to_k(hidden_states)
641-
value = attn.to_v(hidden_states)
642653
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
643654
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
644655

645-
key = attn.head_to_batch_dim(key)
646-
value = attn.head_to_batch_dim(value)
647656
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
648657
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
649658

650-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
651-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
659+
if not attn.only_cross_attention:
660+
key = attn.to_k(hidden_states)
661+
value = attn.to_v(hidden_states)
662+
key = attn.head_to_batch_dim(key)
663+
value = attn.head_to_batch_dim(value)
664+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
665+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
666+
else:
667+
key = encoder_hidden_states_key_proj
668+
value = encoder_hidden_states_value_proj
652669

653670
batch_size_attention, query_tokens, _ = query.shape
654671
hidden_states = torch.zeros(

src/diffusers/models/unet_2d_blocks.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def get_down_block(
125125
resnet_time_scale_shift=resnet_time_scale_shift,
126126
skip_time_act=resnet_skip_time_act,
127127
output_scale_factor=resnet_out_scale_factor,
128+
only_cross_attention=only_cross_attention,
128129
)
129130
elif down_block_type == "SkipDownBlock2D":
130131
return SkipDownBlock2D(
@@ -291,6 +292,7 @@ def get_up_block(
291292
resnet_time_scale_shift=resnet_time_scale_shift,
292293
skip_time_act=resnet_skip_time_act,
293294
output_scale_factor=resnet_out_scale_factor,
295+
only_cross_attention=only_cross_attention,
294296
)
295297
elif up_block_type == "AttnUpBlock2D":
296298
return AttnUpBlock2D(
@@ -575,6 +577,7 @@ def __init__(
575577
output_scale_factor=1.0,
576578
cross_attention_dim=1280,
577579
skip_time_act=False,
580+
only_cross_attention=False,
578581
):
579582
super().__init__()
580583

@@ -614,6 +617,7 @@ def __init__(
614617
norm_num_groups=resnet_groups,
615618
bias=True,
616619
upcast_softmax=True,
620+
only_cross_attention=only_cross_attention,
617621
processor=AttnAddedKVProcessor(),
618622
)
619623
)
@@ -1356,6 +1360,7 @@ def __init__(
13561360
output_scale_factor=1.0,
13571361
add_downsample=True,
13581362
skip_time_act=False,
1363+
only_cross_attention=False,
13591364
):
13601365
super().__init__()
13611366

@@ -1394,6 +1399,7 @@ def __init__(
13941399
norm_num_groups=resnet_groups,
13951400
bias=True,
13961401
upcast_softmax=True,
1402+
only_cross_attention=only_cross_attention,
13971403
processor=AttnAddedKVProcessor(),
13981404
)
13991405
)
@@ -2354,6 +2360,7 @@ def __init__(
23542360
output_scale_factor=1.0,
23552361
add_upsample=True,
23562362
skip_time_act=False,
2363+
only_cross_attention=False,
23572364
):
23582365
super().__init__()
23592366
resnets = []
@@ -2393,6 +2400,7 @@ def __init__(
23932400
norm_num_groups=resnet_groups,
23942401
bias=True,
23952402
upcast_softmax=True,
2403+
only_cross_attention=only_cross_attention,
23962404
processor=AttnAddedKVProcessor(),
23972405
)
23982406
)

src/diffusers/models/unet_2d_condition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def __init__(
158158
conv_out_kernel: int = 3,
159159
projection_class_embeddings_input_dim: Optional[int] = None,
160160
class_embeddings_concat: bool = False,
161+
mid_block_only_cross_attention: bool = False,
161162
):
162163
super().__init__()
163164

@@ -342,6 +343,7 @@ def __init__(
342343
resnet_groups=norm_num_groups,
343344
resnet_time_scale_shift=resnet_time_scale_shift,
344345
skip_time_act=resnet_skip_time_act,
346+
only_cross_attention=mid_block_only_cross_attention,
345347
)
346348
elif mid_block_type is None:
347349
self.mid_block = None

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def __init__(
244244
conv_out_kernel: int = 3,
245245
projection_class_embeddings_input_dim: Optional[int] = None,
246246
class_embeddings_concat: bool = False,
247+
mid_block_only_cross_attention: bool = False,
247248
):
248249
super().__init__()
249250

@@ -434,6 +435,7 @@ def __init__(
434435
resnet_groups=norm_num_groups,
435436
resnet_time_scale_shift=resnet_time_scale_shift,
436437
skip_time_act=resnet_skip_time_act,
438+
only_cross_attention=mid_block_only_cross_attention,
437439
)
438440
elif mid_block_type is None:
439441
self.mid_block = None
@@ -1476,6 +1478,7 @@ def __init__(
14761478
output_scale_factor=1.0,
14771479
cross_attention_dim=1280,
14781480
skip_time_act=False,
1481+
only_cross_attention=False,
14791482
):
14801483
super().__init__()
14811484

@@ -1515,6 +1518,7 @@ def __init__(
15151518
norm_num_groups=resnet_groups,
15161519
bias=True,
15171520
upcast_softmax=True,
1521+
only_cross_attention=only_cross_attention,
15181522
processor=AttnAddedKVProcessor(),
15191523
)
15201524
)

0 commit comments

Comments (0)