@@ -255,11 +255,15 @@ def batch_to_head_dim(self, tensor):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def head_to_batch_dim(self, tensor):
+    def head_to_batch_dim(self, tensor, out_dim=3):
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
         return tensor
 
     def get_attention_scores(self, query, key, attention_mask=None):
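For reference, this is the shape contract the new out_dim flag introduces; a minimal sketch with made-up sizes, using plain tensor ops rather than the Attention module itself:

import torch

batch_size, heads, seq_len, dim = 2, 8, 16, 320   # hypothetical sizes
x = torch.randn(batch_size, seq_len, dim)

# shared part of head_to_batch_dim: split heads out of the channel dim
t = x.reshape(batch_size, seq_len, heads, dim // heads).permute(0, 2, 1, 3)

# out_dim=3 (previous behaviour): heads folded into the batch axis
assert t.reshape(batch_size * heads, seq_len, dim // heads).shape == (16, 16, 40)

# out_dim=4 (new): keep (batch, heads, seq_len, head_dim), the layout
# expected by torch.nn.functional.scaled_dot_product_attention
assert t.shape == (2, 8, 16, 40)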
@@ -293,7 +297,7 @@ def get_attention_scores(self, query, key, attention_mask=None):
 
         return attention_probs
 
-    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
+    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
         if batch_size is None:
             deprecate(
                 "batch_size=None",
@@ -320,8 +324,13 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
             else:
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
 
-        if attention_mask.shape[0] < batch_size * head_size:
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
         return attention_mask
 
     def norm_encoder_hidden_states(self, encoder_hidden_states):
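Similarly for prepare_attention_mask, the two branches produce these layouts; a sketch with made-up sizes (batch 2, 8 heads, key length 77):

import torch

batch_size, head_size, target_length = 2, 8, 77   # hypothetical sizes
attention_mask = torch.zeros(batch_size, 1, target_length)

# out_dim=3: one mask row per (batch * head), matching the batched-head layout
mask3 = attention_mask.repeat_interleave(head_size, dim=0)               # (16, 1, 77)

# out_dim=4: broadcastable (batch, heads, 1, key_len) mask for F.scaled_dot_product_attention
mask4 = attention_mask.unsqueeze(1).repeat_interleave(head_size, dim=1)  # (2, 8, 1, 77)

assert mask3.shape == (16, 1, 77) and mask4.shape == (2, 8, 1, 77)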
@@ -499,6 +508,64 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         return hidden_states
 
 
+class AttnAddedKVProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query, out_dim=4)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key, out_dim=4)
+            value = attn.head_to_batch_dim(value, out_dim=4)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
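As a sanity check of what the fused call in the new processor computes (default 1/sqrt(head_dim) scaling, hence the TODO about attn.scale), a small comparison against the explicit softmax(QK^T/sqrt(d))V formulation with random tensors; requires PyTorch 2.0, like the processor itself:

import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 40)   # (batch, heads, query_len, head_dim)
k = torch.randn(2, 8, 77, 40)   # (batch, heads, key_len, head_dim)
v = torch.randn(2, 8, 77, 40)

fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
manual = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1]), dim=-1) @ v
assert torch.allclose(fused, manual, atol=1e-4)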
@@ -764,6 +831,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
 ]
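Once merged, the new processor can be swapped in like the existing ones; a hedged usage sketch, assuming a UNet whose attention blocks use the added key/value projections (the unet object and the import path below are illustrative and depend on the diffusers version):

import torch.nn.functional as F

from diffusers.models.attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0

# fall back to the non-fused processor on PyTorch < 2.0
processor = (
    AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)
unet.set_attn_processor(processor)  # `unet` assumed to be an already-loaded added-KV UNet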