@@ -61,6 +61,7 @@ def __init__(
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
         scale_qk: bool = True,
+        only_cross_attention: bool = False,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -79,6 +80,12 @@ def __init__(
         self.sliceable_head_dim = heads
 
         self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
 
         if norm_num_groups is not None:
             self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
@@ -89,8 +96,11 @@ def __init__(
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
-        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
-        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
 
         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
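
For context, a minimal usage sketch of the new flag (the module path and the concrete dimensions below are assumptions for illustration; only the constructor arguments visible in this diff are relied on):

```python
# Hypothetical usage sketch; the import path and sizes are assumptions, not part of this PR.
from diffusers.models.attention_processor import Attention

# Cross-attention-only block: keys/values come solely from the added projections,
# so `added_kv_proj_dim` must be set alongside `only_cross_attention=True`.
attn = Attention(
    query_dim=320,            # hypothetical channel width
    cross_attention_dim=768,  # hypothetical text-encoder width
    added_kv_proj_dim=768,    # required when only_cross_attention=True
    only_cross_attention=True,
)

# Requesting only_cross_attention without added_kv_proj_dim triggers the
# ValueError introduced above.
try:
    Attention(query_dim=320, only_cross_attention=True)
except ValueError as err:
    print(err)
```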
@@ -408,18 +418,21 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states)
         query = attn.head_to_batch_dim(query)
 
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
 
         attention_probs = attn.get_attention_scores(query, key, attention_mask)
         hidden_states = torch.bmm(attention_probs, value)
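
The effect of the new branch on the key/value layout can be checked in isolation with plain tensors (a sketch; all sizes below are hypothetical, with `seq_len` standing for the spatial token count and `text_len` for the encoder sequence length):

```python
import torch

batch_heads, seq_len, text_len, head_dim = 2 * 8, 64, 77, 40  # hypothetical sizes

self_key = torch.randn(batch_heads, seq_len, head_dim)    # stands in for attn.to_k(hidden_states), per head
added_key = torch.randn(batch_heads, text_len, head_dim)  # stands in for attn.add_k_proj(encoder_hidden_states), per head

# only_cross_attention=False: added projections are prepended to the self-attention keys
key_full = torch.cat([added_key, self_key], dim=1)
print(key_full.shape)        # torch.Size([16, 141, 40]) -> queries attend to text and spatial tokens

# only_cross_attention=True: keys come from the added projections alone
key_cross_only = added_key
print(key_cross_only.shape)  # torch.Size([16, 77, 40]) -> queries attend to text tokens only
```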
@@ -637,18 +650,22 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
 
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
 
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
 
         batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(