@@ -311,10 +311,7 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
        attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        return attention_mask

-    def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
-        if encoder_hidden_states is None:
-            return hidden_states
-
+    def norm_encoder_hidden_states(self, encoder_hidden_states):
        if self.norm_cross is None:
            return encoder_hidden_states
@@ -349,7 +346,10 @@ def __call__(
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -415,7 +415,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -442,7 +445,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
448
455
@@ -490,7 +497,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

        query = attn.to_q(hidden_states)

-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -531,7 +541,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

        query = attn.to_q(hidden_states)

-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -580,7 +593,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -615,7 +631,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -662,7 +681,11 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
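
For context, the change replaces the combined prepare_encoder_hidden_states helper with a norm_encoder_hidden_states method that only applies the cross-attention norm, while each processor now handles the encoder_hidden_states-is-None fallback itself. The following is a minimal standalone sketch of that calling pattern; the StubAttention class and the LayerNorm choice for norm_cross are illustrative assumptions, not the library's actual Attention implementation.

# Illustrative sketch only -- not the diffusers implementation.
import torch
import torch.nn as nn

class StubAttention(nn.Module):
    """Hypothetical stub with just enough surface for the new pattern."""

    def __init__(self, cross_attention_dim=None):
        super().__init__()
        # Assumption: norm_cross is a LayerNorm when cross-attention norm is enabled.
        self.norm_cross = nn.LayerNorm(cross_attention_dim) if cross_attention_dim else None

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        # Mirrors the refactored method: normalization only, no None handling.
        if self.norm_cross is None:
            return encoder_hidden_states
        return self.norm_cross(encoder_hidden_states)

def resolve_encoder_hidden_states(attn, hidden_states, encoder_hidden_states=None):
    # The caller now owns the self-attention fallback, as in each processor above.
    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    else:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
    return encoder_hidden_states

# Usage: self-attention path falls back to hidden_states, cross-attention path is normalized.
attn = StubAttention(cross_attention_dim=64)
hidden = torch.randn(2, 16, 64)
context = torch.randn(2, 8, 64)
assert resolve_encoder_hidden_states(attn, hidden) is hidden
assert resolve_encoder_hidden_states(attn, hidden, context).shape == context.shape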