Skip to content

Commit ad99a33

Browse files
committed
add only cross attention to simple attention blocks
1 parent ce144d6 commit ad99a33

File tree

4 files changed

+46
-15
lines changed

4 files changed

+46
-15
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
norm_num_groups: Optional[int] = None,
6262
out_bias: bool = True,
6363
scale_qk: bool = True,
64+
only_cross_attention: bool = False,
6465
processor: Optional["AttnProcessor"] = None,
6566
):
6667
super().__init__()
@@ -79,6 +80,12 @@ def __init__(
7980
self.sliceable_head_dim = heads
8081

8182
self.added_kv_proj_dim = added_kv_proj_dim
83+
self.only_cross_attention = only_cross_attention
84+
85+
if self.added_kv_proj_dim is None and self.only_cross_attention:
86+
raise ValueError(
87+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
88+
)
8289

8390
if norm_num_groups is not None:
8491
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
@@ -89,8 +96,11 @@ def __init__(
8996
self.norm_cross = nn.LayerNorm(cross_attention_dim)
9097

9198
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
92-
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
93-
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
99+
100+
if not self.only_cross_attention:
101+
# only relevant for the `AddedKVProcessor` classes
102+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
103+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
94104

95105
if self.added_kv_proj_dim is not None:
96106
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
@@ -409,18 +419,21 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
409419
query = attn.to_q(hidden_states)
410420
query = attn.head_to_batch_dim(query)
411421

412-
key = attn.to_k(hidden_states)
413-
value = attn.to_v(hidden_states)
414-
key = attn.head_to_batch_dim(key)
415-
value = attn.head_to_batch_dim(value)
416-
417422
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
418423
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
419424
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
420425
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
421426

422-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
423-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
427+
if not attn.only_cross_attention:
428+
key = attn.to_k(hidden_states)
429+
value = attn.to_v(hidden_states)
430+
key = attn.head_to_batch_dim(key)
431+
value = attn.head_to_batch_dim(value)
432+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
433+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
434+
else:
435+
key = encoder_hidden_states_key_proj
436+
value = encoder_hidden_states_value_proj
424437

425438
attention_probs = attn.get_attention_scores(query, key, attention_mask)
426439
hidden_states = torch.bmm(attention_probs, value)
@@ -639,18 +652,22 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
639652
dim = query.shape[-1]
640653
query = attn.head_to_batch_dim(query)
641654

642-
key = attn.to_k(hidden_states)
643-
value = attn.to_v(hidden_states)
644655
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
645656
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
646657

647-
key = attn.head_to_batch_dim(key)
648-
value = attn.head_to_batch_dim(value)
649658
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
650659
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
651660

652-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
653-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
661+
if not attn.only_cross_attention:
662+
key = attn.to_k(hidden_states)
663+
value = attn.to_v(hidden_states)
664+
key = attn.head_to_batch_dim(key)
665+
value = attn.head_to_batch_dim(value)
666+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
667+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
668+
else:
669+
key = encoder_hidden_states_key_proj
670+
value = encoder_hidden_states_value_proj
654671

655672
batch_size_attention, query_tokens, _ = query.shape
656673
hidden_states = torch.zeros(

src/diffusers/models/unet_2d_blocks.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def get_down_block(
119119
cross_attention_dim=cross_attention_dim,
120120
attn_num_head_channels=attn_num_head_channels,
121121
resnet_time_scale_shift=resnet_time_scale_shift,
122+
only_cross_attention=only_cross_attention,
122123
)
123124
elif down_block_type == "SkipDownBlock2D":
124125
return SkipDownBlock2D(
@@ -279,6 +280,7 @@ def get_up_block(
279280
cross_attention_dim=cross_attention_dim,
280281
attn_num_head_channels=attn_num_head_channels,
281282
resnet_time_scale_shift=resnet_time_scale_shift,
283+
only_cross_attention=only_cross_attention,
282284
)
283285
elif up_block_type == "AttnUpBlock2D":
284286
return AttnUpBlock2D(
@@ -562,6 +564,7 @@ def __init__(
562564
attn_num_head_channels=1,
563565
output_scale_factor=1.0,
564566
cross_attention_dim=1280,
567+
only_cross_attention=False,
565568
):
566569
super().__init__()
567570

@@ -600,6 +603,7 @@ def __init__(
600603
norm_num_groups=resnet_groups,
601604
bias=True,
602605
upcast_softmax=True,
606+
only_cross_attention=only_cross_attention,
603607
processor=AttnAddedKVProcessor(),
604608
)
605609
)
@@ -1337,6 +1341,7 @@ def __init__(
13371341
cross_attention_dim=1280,
13381342
output_scale_factor=1.0,
13391343
add_downsample=True,
1344+
only_cross_attention=False,
13401345
):
13411346
super().__init__()
13421347

@@ -1374,6 +1379,7 @@ def __init__(
13741379
norm_num_groups=resnet_groups,
13751380
bias=True,
13761381
upcast_softmax=True,
1382+
only_cross_attention=only_cross_attention,
13771383
processor=AttnAddedKVProcessor(),
13781384
)
13791385
)
@@ -2329,6 +2335,7 @@ def __init__(
23292335
cross_attention_dim=1280,
23302336
output_scale_factor=1.0,
23312337
add_upsample=True,
2338+
only_cross_attention=False,
23322339
):
23332340
super().__init__()
23342341
resnets = []
@@ -2367,6 +2374,7 @@ def __init__(
23672374
norm_num_groups=resnet_groups,
23682375
bias=True,
23692376
upcast_softmax=True,
2377+
only_cross_attention=only_cross_attention,
23702378
processor=AttnAddedKVProcessor(),
23712379
)
23722380
)

src/diffusers/models/unet_2d_condition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def __init__(
153153
conv_out_kernel: int = 3,
154154
projection_class_embeddings_input_dim: Optional[int] = None,
155155
class_embeddings_concat: bool = False,
156+
mid_block_only_cross_attention: bool = False,
156157
):
157158
super().__init__()
158159

@@ -321,6 +322,7 @@ def __init__(
321322
attn_num_head_channels=attention_head_dim[-1],
322323
resnet_groups=norm_num_groups,
323324
resnet_time_scale_shift=resnet_time_scale_shift,
325+
only_cross_attention=mid_block_only_cross_attention,
324326
)
325327
elif mid_block_type is None:
326328
self.mid_block = None

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def __init__(
239239
conv_out_kernel: int = 3,
240240
projection_class_embeddings_input_dim: Optional[int] = None,
241241
class_embeddings_concat: bool = False,
242+
mid_block_only_cross_attention: bool = False,
242243
):
243244
super().__init__()
244245

@@ -412,6 +413,7 @@ def __init__(
412413
attn_num_head_channels=attention_head_dim[-1],
413414
resnet_groups=norm_num_groups,
414415
resnet_time_scale_shift=resnet_time_scale_shift,
416+
only_cross_attention=mid_block_only_cross_attention,
415417
)
416418
elif mid_block_type is None:
417419
self.mid_block = None
@@ -1434,6 +1436,7 @@ def __init__(
14341436
attn_num_head_channels=1,
14351437
output_scale_factor=1.0,
14361438
cross_attention_dim=1280,
1439+
only_cross_attention=False,
14371440
):
14381441
super().__init__()
14391442

@@ -1472,6 +1475,7 @@ def __init__(
14721475
norm_num_groups=resnet_groups,
14731476
bias=True,
14741477
upcast_softmax=True,
1478+
only_cross_attention=only_cross_attention,
14751479
processor=AttnAddedKVProcessor(),
14761480
)
14771481
)

0 commit comments

Comments (0)