@@ -305,28 +305,22 @@ def __call__(
         encoder_hidden_states=None,
         attention_mask=None,
     ):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim

+        if input_ndim == 4:
             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -353,7 +347,7 @@ def __call__(
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

@@ -402,28 +396,22 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim

+        if input_ndim == 4:
             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -447,7 +435,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

@@ -506,28 +494,23 @@ def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:

             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -556,7 +539,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

@@ -574,28 +557,22 @@ def __init__(self):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim

+        if input_ndim == 4:
             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         inner_dim = hidden_states.shape[-1]

@@ -634,7 +611,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

@@ -661,28 +638,22 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim

+        if input_ndim == 4:
             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -707,7 +678,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

@@ -724,28 +695,22 @@ def __init__(self, slice_size):
         self.slice_size = slice_size

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim

+        if input_ndim == 4:
             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -789,7 +754,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

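Each processor above receives the same two simplifications: the boolean `reshaped_input` bookkeeping is replaced by recording `input_ndim = hidden_states.ndim` once and testing `input_ndim == 4` both where a 4D feature map is flattened into a sequence before attention and where it is restored afterwards, and the four-line `sequence_length` branch collapses into a single conditional expression. Below is a minimal standalone sketch of that round trip; it is illustrative only (the function name `reshape_roundtrip` is made up and the attention computation itself is elided), not the diffusers implementation.

import torch


def reshape_roundtrip(hidden_states: torch.Tensor, encoder_hidden_states=None) -> torch.Tensor:
    # Record the rank of the incoming tensor once; this replaces the old
    # reshaped_input = True / False flag from the pre-change code.
    batch_size = hidden_states.shape[0]
    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        # (batch, channel, height, width) -> (batch, height * width, channel)
        _, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width)
        hidden_states = hidden_states.transpose(1, 2)

    # Collapsed sequence_length computation, as in the diff; the real processors
    # pass this to attn.prepare_attention_mask().
    sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

    # ... attention over (batch, sequence_length, channel) would run here ...

    if input_ndim == 4:
        # Undo the flattening so the caller gets back the original 4D layout.
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = hidden_states.reshape(batch_size, channel, height, width)

    return hidden_states


x = torch.randn(2, 8, 16, 16)                 # a 4D feature map
assert reshape_roundtrip(x).shape == x.shape  # the original layout is preserved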