
Commit c6180a3

add only cross attention to simple attention blocks (#3011)
* add only cross attention to simple attention blocks
* add test for only_cross_attention re: @patrickvonplaten
* mid_block_only_cross_attention better default: allow mid_block_only_cross_attention to default to `only_cross_attention` when `only_cross_attention` is given as a single boolean
1 parent e3095c5 commit c6180a3

File tree

5 files changed (+148 −17 lines)


src/diffusers/models/attention_processor.py

Lines changed: 35 additions & 15 deletions
@@ -61,6 +61,7 @@ def __init__(
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
         scale_qk: bool = True,
+        only_cross_attention: bool = False,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -79,6 +80,12 @@ def __init__(
         self.sliceable_head_dim = heads

         self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )

         if norm_num_groups is not None:
             self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
@@ -89,8 +96,14 @@ def __init__(
             self.norm_cross = nn.LayerNorm(cross_attention_dim)

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
-        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
-        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None

         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
@@ -408,18 +421,21 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states)
         query = attn.head_to_batch_dim(query)

-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj

         attention_probs = attn.get_attention_scores(query, key, attention_mask)
         hidden_states = torch.bmm(attention_probs, value)
@@ -637,18 +653,22 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)

-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj

         batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
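Taken together, the constructor guard and the processor changes mean that with `only_cross_attention=True` the key/value always come from the added KV projections of `encoder_hidden_states`, and `to_k`/`to_v` are simply never created. A minimal usage sketch, using the same toy dimensions as the test added in this commit (the snippet itself is illustrative, not library code):

```python
import torch

from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

attn = Attention(
    query_dim=10,
    cross_attention_dim=12,        # may differ from query_dim, since to_k/to_v are skipped
    heads=2,
    dim_head=4,
    added_kv_proj_dim=6,           # required when only_cross_attention=True (see the ValueError above)
    norm_num_groups=1,
    only_cross_attention=True,
    processor=AttnAddedKVProcessor(),
)

assert attn.to_k is None and attn.to_v is None

hidden_states = torch.rand(2, 10, 3, 2)       # (batch, query_dim, height, width)
encoder_hidden_states = torch.rand(2, 4, 6)   # (batch, sequence_length, added_kv_proj_dim)
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
```

Omitting `added_kv_proj_dim` while setting `only_cross_attention=True` raises the `ValueError` added in the constructor above.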

src/diffusers/models/unet_2d_blocks.py

Lines changed: 8 additions & 0 deletions
@@ -125,6 +125,7 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
+            only_cross_attention=only_cross_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -291,6 +292,7 @@ def get_up_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
+            only_cross_attention=only_cross_attention,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -575,6 +577,7 @@ def __init__(
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()

@@ -614,6 +617,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1356,6 +1360,7 @@ def __init__(
         output_scale_factor=1.0,
         add_downsample=True,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()

@@ -1394,6 +1399,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -2354,6 +2360,7 @@ def __init__(
         output_scale_factor=1.0,
         add_upsample=True,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -2393,6 +2400,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )
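In the simple cross-attention blocks touched above (the mid block `UNetMidBlock2DSimpleCrossAttn` and its down/up counterparts), the new keyword is simply forwarded into each `Attention` layer built with `AttnAddedKVProcessor`. A rough sketch of that inner construction, with hypothetical channel and head counts rather than the blocks' real defaults:

```python
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

in_channels = 320          # hypothetical block width
cross_attention_dim = 768  # hypothetical conditioning width

attn = Attention(
    query_dim=in_channels,
    cross_attention_dim=in_channels,
    heads=8,
    dim_head=in_channels // 8,
    added_kv_proj_dim=cross_attention_dim,
    norm_num_groups=32,
    bias=True,
    upcast_softmax=True,
    only_cross_attention=True,   # forwarded from the block constructor after this change
    processor=AttnAddedKVProcessor(),
)
```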

src/diffusers/models/unet_2d_condition.py

Lines changed: 14 additions & 1 deletion
@@ -110,7 +110,12 @@ class conditioning with `class_embed_type` equal to `None`.
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
-            embeddings with the class embeddings.
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
+            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
+            default to `False`.
     """

     _supports_gradient_checkpointing = True
@@ -158,6 +163,7 @@ def __init__(
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
     ):
         super().__init__()

@@ -265,8 +271,14 @@ def __init__(
         self.up_blocks = nn.ModuleList([])

         if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
             only_cross_attention = [only_cross_attention] * len(down_block_types)

+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -342,6 +354,7 @@ def __init__(
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
             )
         elif mid_block_type is None:
             self.mid_block = None
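The defaulting rule for the new `mid_block_only_cross_attention` config flag is easiest to see in isolation. A small standalone sketch of the logic in the constructor hunk above (the helper function and its name are ours, not part of diffusers):

```python
from typing import List, Optional, Tuple, Union


def resolve_only_cross_attention(
    only_cross_attention: Union[bool, List[bool]],
    mid_block_only_cross_attention: Optional[bool],
    num_down_blocks: int,
) -> Tuple[List[bool], bool]:
    # Mirrors the logic added to UNet2DConditionModel.__init__ above.
    if isinstance(only_cross_attention, bool):
        # a single boolean also becomes the mid-block default, unless overridden explicitly
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = only_cross_attention
        only_cross_attention = [only_cross_attention] * num_down_blocks

    # when only_cross_attention is given per block, the mid block falls back to False
    if mid_block_only_cross_attention is None:
        mid_block_only_cross_attention = False

    return only_cross_attention, mid_block_only_cross_attention


# A single bool propagates to the mid block:
assert resolve_only_cross_attention(True, None, 4) == ([True, True, True, True], True)
# A per-block list leaves the mid block at False unless set explicitly:
assert resolve_only_cross_attention([True, False, False, True], None, 4)[1] is False
```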

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 16 additions & 1 deletion
@@ -191,7 +191,12 @@ class conditioning with `class_embed_type` equal to `None`.
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
-            embeddings with the class embeddings.
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
+            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
+            default to `False`.
     """

     _supports_gradient_checkpointing = True
@@ -244,6 +249,7 @@ def __init__(
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
     ):
         super().__init__()

@@ -357,8 +363,14 @@ def __init__(
         self.up_blocks = nn.ModuleList([])

         if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
             only_cross_attention = [only_cross_attention] * len(down_block_types)

+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -434,6 +446,7 @@ def __init__(
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -1476,6 +1489,7 @@ def __init__(
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
+        only_cross_attention=False,
     ):
         super().__init__()

@@ -1515,6 +1529,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    only_cross_attention=only_cross_attention,
                     processor=AttnAddedKVProcessor(),
                 )
             )
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@ (new test file)
import unittest

import torch

from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor


class AttnAddedKVProcessorTests(unittest.TestCase):
    def get_constructor_arguments(self, only_cross_attention: bool = False):
        query_dim = 10

        if only_cross_attention:
            cross_attention_dim = 12
        else:
            # when only cross attention is not set, the cross attention dim must be the same as the query dim
            cross_attention_dim = query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        batch_size = 2

        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
        attention_mask = None

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
        }

    def test_only_cross_attention(self):
        # self and cross attention

        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=False)
        attn = Attention(**constructor_args)

        self.assertTrue(attn.to_k is not None)
        self.assertTrue(attn.to_v is not None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        self_and_cross_attn_out = attn(**forward_args)

        # only cross attention

        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=True)
        attn = Attention(**constructor_args)

        self.assertTrue(attn.to_k is None)
        self.assertTrue(attn.to_v is None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        only_cross_attn_out = attn(**forward_args)

        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
