Skip to content

Commit b8d88b8

Browse files
committed
fix simple attention processor encoder hidden states ordering
1 parent ce144d6 commit b8d88b8

File tree

3 files changed: +3 −4 lines changed

src/diffusers/models/attention_processor.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -400,7 +400,6 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
-        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -627,7 +626,6 @@ def __init__(self, slice_size):
     def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
-        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)

         batch_size, sequence_length, _ = hidden_states.shape

src/diffusers/pipelines/unclip/text_proj.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -77,10 +77,10 @@ def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states
         # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
         clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
         clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)
+        clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1)

         text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
         text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
-        text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1)
-        text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2)
+        text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1)

         return text_encoder_hidden_states, additive_clip_time_embeddings

tests/pipelines/unclip/test_unclip_image_variation.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
         "decoder_num_inference_steps",
         "super_res_num_inference_steps",
     ]
+    test_xformers_attention = False

     @property
     def text_embedder_hidden_size(self):

0 commit comments

Comments (0)