
Commit 8e5921c
fix a bug of prompt embeds in sdxl (#4099)
* fix bug in sdxl
* Update pipeline_stable_diffusion_xl_img2img.py
* Update pipeline_stable_diffusion_xl.py
* Update pipeline_stable_diffusion_xl_img2img.py
* Update pipeline_stable_diffusion_xl_inpaint.py
* Update pipeline_stable_diffusion_xl.py
* Update pipeline_stable_diffusion_xl_img2img.py
* Update pipeline_stable_diffusion_xl_inpaint.py
* Update pipeline_stable_diffusion_xl_img2img.py
* Update pipeline_controlnet_sd_xl.py
* Update pipeline_controlnet_sd_xl.py
* Update pipeline_stable_diffusion_xl.py
* Update pipeline_stable_diffusion_xl_img2img.py
* Update pipeline_stable_diffusion_xl_inpaint.py
* Update test_stable_diffusion_xl.py
* Update test_stable_diffusion_xl.py
* Update test_stable_diffusion_xl.py (add test on prompt_embeds)
* add test on prompt_embeds

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 8e8954b commit 8e5921c

File tree

5 files changed: +90 -84 lines

* src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
* src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
* src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
* src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
* tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
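Judging from the diffs that follow, the bug was one of placement: the repeat/view that expands text embeddings for num_images_per_prompt sat inside the per-tokenizer encoding loop, so it only ran when the pipeline encoded the prompt itself; prompt_embeds passed in by the caller were never expanded, while pooled_prompt_embeds were, producing mismatched batch sizes. The fix hoists the duplication out of the encoding branch so both paths share it. Here is a minimal sketch of that "mps friendly" repeat/view pattern, with 77/2048 as hypothetical stand-ins for SDXL's concatenated encoder states:

import torch

num_images_per_prompt = 2
prompt_embeds = torch.randn(1, 77, 2048)  # (batch, seq_len, dim); hypothetical shapes

# the "mps friendly method" named in the code comments: repeat along the
# sequence axis, then view the copies back onto the batch axis
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

assert prompt_embeds.shape == (2, 77, 2048)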

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 13 additions & 21 deletions
@@ -304,11 +304,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -361,26 +356,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
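The last hunk's context shows the matching duplication for the 2-D pooled embeddings; the same trick, sketched here with 1280 as a hypothetical pooled width:

import torch

num_images_per_prompt = 2
pooled_prompt_embeds = torch.randn(1, 1280)  # (batch, dim); hypothetical width

# same repeat-then-view duplication, specialized to 2-D pooled embeddings
bs_embed = pooled_prompt_embeds.shape[0]
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
    bs_embed * num_images_per_prompt, -1
)

assert pooled_prompt_embeds.shape == (2, 1280)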

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 13 additions & 21 deletions
@@ -319,11 +319,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -376,26 +371,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 13 additions & 21 deletions
@@ -327,11 +327,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -384,26 +379,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 13 additions & 21 deletions
@@ -433,11 +433,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -490,26 +485,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py

Lines changed: 38 additions & 0 deletions
@@ -148,6 +148,44 @@ def test_stable_diffusion_xl_euler(self):
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_xl_prompt_embeds(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt"] = 2 * [inputs["prompt"]]
+        inputs["num_images_per_prompt"] = 2
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        prompt = 2 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = sd_pipe.encode_prompt(prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
     def test_stable_diffusion_xl_negative_prompt_embeds(self):
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**components)
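For reference, a usage sketch of the path the new test exercises: precompute the four embedding tensors with encode_prompt, then call the pipeline without a text prompt. The checkpoint ID, prompt, fp16, and CUDA here are illustrative assumptions, not part of the commit:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# precompute embeddings once, with the same call pattern as the test above
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt("an astronaut riding a horse")

# reuse the embeddings in place of a text prompt
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
).images[0]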

0 commit comments