
Commit 9160e51

committed
prepare_encoder_hidden_states -> norm_encoder_hidden_states re: @patrickvonplaten
1 parent 96e8671 commit 9160e51

File tree: 3 files changed, +43 / -14 lines changed
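In short, the helper no longer takes hidden_states and no longer handles the self-attention fallback: each processor now resolves that case itself and only calls the renamed norm_encoder_hidden_states for cross-attention. A minimal sketch of the calling convention before and after, mirroring the hunks below:

    # before
    encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

    # after: the processor picks the attention source itself
    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states  # self-attention
    else:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)  # cross-attention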

src/diffusers/models/attention_processor.py

Lines changed: 35 additions & 12 deletions
@@ -311,10 +311,7 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
         attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask
 
-    def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
-        if encoder_hidden_states is None:
-            return hidden_states
-
+    def norm_encoder_hidden_states(self, encoder_hidden_states):
         if self.norm_cross is None:
             return encoder_hidden_states
 
@@ -349,7 +346,10 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)
 
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -415,7 +415,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -442,7 +445,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         batch_size, sequence_length, _ = hidden_states.shape
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
@@ -490,7 +497,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         query = attn.to_q(hidden_states)
 
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -531,7 +541,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         query = attn.to_q(hidden_states)
 
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -580,7 +593,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()
 
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -615,7 +631,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -662,7 +681,11 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         batch_size, sequence_length, _ = hidden_states.shape
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
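For reference, norm_encoder_hidden_states itself keeps only the normalization step; the diff shows just its first lines, so the continuation below is a hedged sketch, assuming self.norm_cross (when configured via cross_attention_norm) is simply applied to the encoder states:

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        # No cross-attention norm configured: return the encoder states unchanged.
        if self.norm_cross is None:
            return encoder_hidden_states
        # Assumed continuation (not shown in this diff): apply the configured norm.
        return self.norm_cross(encoder_hidden_states)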

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 4 additions & 1 deletion
@@ -241,7 +241,10 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)
 
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

Lines changed: 4 additions & 1 deletion
@@ -63,7 +63,10 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)
 
-        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
