
Commit a52fd6b

add test for only_cross_attention re: @patrickvonplaten

1 parent c2fbe8e

2 files changed (+78, −0 lines)

src/diffusers/models/attention_processor.py (3 additions, 0 deletions)

```diff
@@ -101,6 +101,9 @@ def __init__(
             # only relevant for the `AddedKVProcessor` classes
             self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
             self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
 
         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
```
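For context on what the `None` fallback enables: in the added-KV processors, the `to_k`/`to_v` projections of the hidden states are only used when self-attention is active; with `only_cross_attention=True` the added key/value projections are used alone. The helper below is a hypothetical illustration of that branch (`select_key_value` is not a diffusers API, and the details may differ from the actual processor source):

```python
import torch
from diffusers.models.attention_processor import Attention


def select_key_value(attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
    """Illustrative only: mirrors how an added-KV processor picks keys/values.

    `hidden_states` is assumed to already be flattened to (batch, seq, query_dim).
    """
    # The added projections always come from the encoder states.
    encoder_key = attn.head_to_batch_dim(attn.add_k_proj(encoder_hidden_states))
    encoder_value = attn.head_to_batch_dim(attn.add_v_proj(encoder_hidden_states))

    if not attn.only_cross_attention:
        # Self + cross attention: also project the hidden states and concatenate.
        key = attn.head_to_batch_dim(attn.to_k(hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(hidden_states))
        key = torch.cat([encoder_key, key], dim=1)
        value = torch.cat([encoder_value, value], dim=1)
    else:
        # Only cross attention: to_k/to_v are None after this commit, so the
        # keys/values come solely from the added projections.
        key, value = encoder_key, encoder_value
    return key, value
```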
New test file (75 additions, 0 deletions)

```python
import unittest

import torch

from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor


class AttnAddedKVProcessorTests(unittest.TestCase):
    def get_constructor_arguments(self, only_cross_attention: bool = False):
        query_dim = 10

        if only_cross_attention:
            cross_attention_dim = 12
        else:
            # when only_cross_attention is not set, the cross-attention dim must match the query dim
            cross_attention_dim = query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        batch_size = 2

        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
        attention_mask = None

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
        }

    def test_only_cross_attention(self):
        # self and cross attention

        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=False)
        attn = Attention(**constructor_args)

        self.assertTrue(attn.to_k is not None)
        self.assertTrue(attn.to_v is not None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        self_and_cross_attn_out = attn(**forward_args)

        # only cross attention

        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=True)
        attn = Attention(**constructor_args)

        self.assertTrue(attn.to_k is None)
        self.assertTrue(attn.to_v is None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        only_cross_attn_out = attn(**forward_args)

        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
```
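As a quick standalone sanity check (not part of the commit), the only-cross-attention configuration from the test can be exercised directly; the printed shape is an expectation, assuming the processor reshapes its output back to the input layout:

```python
import torch
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

# Same configuration the test uses for the only-cross-attention case.
attn = Attention(
    query_dim=10,
    cross_attention_dim=12,
    heads=2,
    dim_head=4,
    added_kv_proj_dim=6,
    norm_num_groups=1,
    only_cross_attention=True,
    processor=AttnAddedKVProcessor(),
)

# The commit's change: no self-attention key/value projections are created.
assert attn.to_k is None and attn.to_v is None

hidden_states = torch.rand(2, 10, 3, 2)      # (batch, query_dim, h, w)
encoder_hidden_states = torch.rand(2, 4, 6)  # (batch, seq, added_kv_proj_dim)

out = attn(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # expected to match hidden_states: torch.Size([2, 10, 3, 2])
```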
