@@ -61,6 +61,7 @@ def __init__(
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
         scale_qk: bool = True,
+        only_cross_attention: bool = False,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -79,6 +80,12 @@ def __init__(
         self.sliceable_head_dim = heads
 
         self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
 
         if norm_num_groups is not None:
             self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
@@ -89,8 +96,11 @@ def __init__(
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
-        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
-        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
 
         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
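Note (not part of the diff): a minimal construction sketch of how the new flag is meant to be combined with `added_kv_proj_dim`. The import path, processor name, the keyword arguments this diff does not show (`heads`, `dim_head`), and all dimensions are assumptions about the full signature, not something the diff confirms; only the `only_cross_attention` / `added_kv_proj_dim` interplay is the point here. The processor-side changes follow in the next hunks.

# Hypothetical usage sketch: import path, processor class, and dimensions are assumptions.
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

attn = Attention(
    query_dim=320,                 # channel dim of hidden_states; matches heads * dim_head here
    cross_attention_dim=320,       # add_k_proj / add_v_proj project into this dim
    heads=8,
    dim_head=40,
    added_kv_proj_dim=768,         # required whenever only_cross_attention=True
    only_cross_attention=True,     # skip to_k/to_v; keys/values come from add_k_proj/add_v_proj only
    processor=AttnAddedKVProcessor(),
)

# Requesting only_cross_attention=True without added_kv_proj_dim now fails fast
# with the ValueError introduced above.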
@@ -409,18 +419,21 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states)
         query = attn.head_to_batch_dim(query)
 
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
 
         attention_probs = attn.get_attention_scores(query, key, attention_mask)
         hidden_states = torch.bmm(attention_probs, value)
@@ -639,18 +652,22 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
 
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
 
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
 
         batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
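Note (not part of the diff): a standalone illustration of what both processor hunks now do with keys and values, written in plain PyTorch with made-up shapes rather than the diffusers API. With `only_cross_attention=False` the self-attention keys/values are concatenated after the projected encoder states along the sequence dimension; with `only_cross_attention=True` only the projected encoder states are used.

import torch

# illustrative shapes only: (batch * heads, sequence length, per-head dim)
batch_heads, self_len, cross_len, head_dim = 16, 64, 77, 40
only_cross_attention = False

# stand-ins for attn.to_k(hidden_states) and attn.add_k_proj(encoder_hidden_states)
self_key = torch.randn(batch_heads, self_len, head_dim)
encoder_key = torch.randn(batch_heads, cross_len, head_dim)

if not only_cross_attention:
    # encoder projections come first, then the self-attention keys
    key = torch.cat([encoder_key, self_key], dim=1)  # (16, 77 + 64, 40)
else:
    key = encoder_key                                # (16, 77, 40)

print(key.shape)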