Commit fdffad4 (1 parent: 280285f)

src/diffusers/models/attention_processor.py

Lines changed: 31 additions & 66 deletions
@@ -305,28 +305,22 @@ def __call__(
         encoder_hidden_states=None,
         attention_mask=None,
     ):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim

+        if input_ndim == 4:
             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -353,7 +347,7 @@ def __call__(
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

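The two hunks above are the whole shape of this commit: drop the `reshaped_input` boolean in favor of remembering `input_ndim`, fold the `sequence_length` branch into a conditional expression, and reuse `input_ndim == 4` to decide whether to restore the spatial layout at the end. The same change is then repeated in every other processor below. As a standalone sketch of that pattern (not diffusers API: `toy_processor` and `fake_attention` are hypothetical stand-ins, and the actual attention math is elided):

```python
import torch


def fake_attention(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the to_q/to_k/to_v + softmax computation; shape-preserving.
    return x


def toy_processor(hidden_states: torch.Tensor) -> torch.Tensor:
    batch_size = hidden_states.shape[0]

    # Remember the incoming rank instead of carrying a separate `reshaped_input` flag.
    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        _, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width)
        hidden_states = hidden_states.transpose(1, 2)  # (batch, height*width, channel)

    hidden_states = fake_attention(hidden_states)

    # Restore the original spatial layout only if the input was flattened above.
    if input_ndim == 4:
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = hidden_states.reshape(batch_size, channel, height, width)

    return hidden_states


print(toy_processor(torch.randn(2, 77, 320)).shape)    # torch.Size([2, 77, 320])
print(toy_processor(torch.randn(2, 320, 8, 8)).shape)  # torch.Size([2, 320, 8, 8])
```

Both the 3-D `(batch, seq, channel)` and 4-D `(batch, channel, height, width)` inputs round-trip through the same path, which is what makes the boolean flag redundant.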
@@ -402,28 +396,22 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim

+        if input_ndim == 4:
             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -447,7 +435,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

@@ -506,28 +494,23 @@ def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:

             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -556,7 +539,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

@@ -574,28 +557,22 @@ def __init__(self):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim

+        if input_ndim == 4:
             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         inner_dim = hidden_states.shape[-1]

@@ -634,7 +611,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

@@ -661,28 +638,22 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim

+        if input_ndim == 4:
             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -707,7 +678,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

@@ -724,28 +695,22 @@ def __init__(self, slice_size):
         self.slice_size = slice_size

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        if attn.residual_connection:
-            residual = hidden_states
+        residual = hidden_states

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states)

         batch_size = hidden_states.shape[0]

-        if hidden_states.ndim == 4:
-            reshaped_input = True
+        input_ndim = hidden_states.ndim

+        if input_ndim == 4:
             _, channel, height, width = hidden_states.shape

             hidden_states = hidden_states.view(batch_size, channel, height * width)
             hidden_states = hidden_states.transpose(1, 2)
-        else:
-            reshaped_input = False

-        if encoder_hidden_states is None:
-            sequence_length = hidden_states.shape[1]
-        else:
-            sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1]

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -789,7 +754,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if reshaped_input:
+        if input_ndim == 4:
             hidden_states = hidden_states.transpose(1, 2)
             hidden_states = hidden_states.reshape(batch_size, channel, height, width)

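The other recurring change in these hunks is removing the `attn.residual_connection` guard around the initial `residual = hidden_states` assignment. A minimal sketch of why capturing the residual unconditionally is behavior-preserving, assuming the residual add later in each processor stays gated by `attn.residual_connection` (that epilogue is outside the hunks shown here); `ToyAttn` and `toy_call` are hypothetical stand-ins, not the library's classes:

```python
import torch


class ToyAttn:
    """Tiny stand-in for `Attention`, carrying only the flag used below."""

    def __init__(self, residual_connection: bool):
        self.residual_connection = residual_connection


def toy_call(attn: ToyAttn, hidden_states: torch.Tensor) -> torch.Tensor:
    # After this commit: always capture the residual...
    residual = hidden_states

    hidden_states = hidden_states * 2.0  # stand-in for the attention computation

    # ...and let the guard on the add (unchanged, assumed to live later in each
    # processor) decide whether it is actually used.
    if attn.residual_connection:
        hidden_states = hidden_states + residual
    return hidden_states


x = torch.ones(1, 4, 8)
print(toy_call(ToyAttn(residual_connection=False), x)[0, 0, 0].item())  # 2.0
print(toy_call(ToyAttn(residual_connection=True), x)[0, 0, 0].item())   # 3.0
```

The unconditional assignment costs nothing (it binds a name, it does not copy the tensor) and avoids relying on a conditionally defined local further down.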