import unittest

import torch

from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor


class AttnAddedKVProcessorTests(unittest.TestCase):
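    # AttnAddedKVProcessor always projects encoder_hidden_states through the extra
    # add_k_proj / add_v_proj layers; unless only_cross_attention is set, it also runs
    # self attention over hidden_states through to_k / to_v.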
    def get_constructor_arguments(self, only_cross_attention: bool = False):
        query_dim = 10

        if only_cross_attention:
            cross_attention_dim = 12
        else:
            # when only cross attention is not set, the cross attention dim must be the same as the query dim
            cross_attention_dim = query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        batch_size = 2

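        # AttnAddedKVProcessor treats hidden_states as a channel-first feature map
        # (batch, query_dim, height, width) and encoder_hidden_states as a sequence
        # (batch, sequence_length, added_kv_proj_dim)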
        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
        attention_mask = None

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
        }

    def test_only_cross_attention(self):
        # self and cross attention

        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=False)
        attn = Attention(**constructor_args)

        self.assertTrue(attn.to_k is not None)
        self.assertTrue(attn.to_v is not None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        self_and_cross_attn_out = attn(**forward_args)

        # only cross attention

        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=True)
        attn = Attention(**constructor_args)

        self.assertTrue(attn.to_k is None)
        self.assertTrue(attn.to_v is None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        only_cross_attn_out = attn(**forward_args)

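        # dropping the self-attention key/value path should give a different result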
        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())