
fix a bug of prompt embeds in sdxl #4099

Merged: 24 commits into huggingface:main from xiaohu2015:patch-3 on Jul 24, 2023.
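
What the PR does, as reconstructed from the diff: in the SDXL pipelines' `encode_prompt`, the `repeat`/`view` duplication for `num_images_per_prompt` previously ran only inside the loop over text encoders, i.e. only when embeddings were computed from a text `prompt`. Embeddings passed in directly via `prompt_embeds` skipped that path and were never duplicated, so calls with precomputed embeds and `num_images_per_prompt > 1` produced mismatched batch sizes. The PR hoists the duplication (plus the dtype/device cast) out of the loop so it runs on the final concatenated embeddings in both paths, and adds a test asserting that the prompt and `prompt_embeds` paths produce identical images. A minimal sketch of the hoisted duplication, with illustrative shapes (2 prompts, 77 tokens, 2048-dim concatenated embeddings; not the PR's exact code):

```python
import torch

num_images_per_prompt = 2
prompt_embeds = torch.randn(2, 77, 2048)  # stands in for user-supplied prompt_embeds

# Duplicate text embeddings for each generation per prompt, using the
# mps-friendly repeat + view (as in the diffs below). After the fix this
# runs whether the embeds were encoded from a prompt or passed in directly.
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

assert prompt_embeds.shape == (4, 77, 2048)
```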
Commits
4cfd644 fix bug in sdxl (xiaohu2015, Jul 14, 2023)
cb902a7 Update pipeline_stable_diffusion_xl_img2img.py (xiaohu2015, Jul 14, 2023)
af1b559 Update pipeline_stable_diffusion_xl.py (xiaohu2015, Jul 14, 2023)
408e221 Update pipeline_stable_diffusion_xl_img2img.py (xiaohu2015, Jul 14, 2023)
95735d6 Merge branch 'main' into patch-3 (patrickvonplaten, Jul 18, 2023)
462f3a6 Merge branch 'huggingface:main' into patch-3 (xiaohu2015, Jul 19, 2023)
2b228b3 Update pipeline_stable_diffusion_xl_inpaint.py (xiaohu2015, Jul 19, 2023)
0367506 Update pipeline_stable_diffusion_xl.py (xiaohu2015, Jul 21, 2023)
17a40d0 Update pipeline_stable_diffusion_xl_img2img.py (xiaohu2015, Jul 21, 2023)
dfc8072 Update pipeline_stable_diffusion_xl_inpaint.py (xiaohu2015, Jul 21, 2023)
a124687 Update pipeline_stable_diffusion_xl_img2img.py (xiaohu2015, Jul 21, 2023)
6bad0ca Merge branch 'huggingface:main' into patch-3 (xiaohu2015, Jul 21, 2023)
49c690b Update pipeline_controlnet_sd_xl.py (xiaohu2015, Jul 21, 2023)
fdd3488 Update pipeline_controlnet_sd_xl.py (xiaohu2015, Jul 21, 2023)
fa514a8 Update pipeline_stable_diffusion_xl.py (xiaohu2015, Jul 21, 2023)
c1889a8 Update pipeline_stable_diffusion_xl_img2img.py (xiaohu2015, Jul 21, 2023)
9066020 Update pipeline_stable_diffusion_xl_inpaint.py (xiaohu2015, Jul 21, 2023)
63093d8 Merge branch 'main' into patch-3 (xiaohu2015, Jul 21, 2023)
c8afd87 Update test_stable_diffusion_xl.py (xiaohu2015, Jul 21, 2023)
110431b Merge branch 'huggingface:main' into patch-3 (xiaohu2015, Jul 22, 2023)
36884fd Update test_stable_diffusion_xl.py (xiaohu2015, Jul 22, 2023)
26d0501 Update test_stable_diffusion_xl.py (xiaohu2015, Jul 22, 2023)
94a23d1 add test on prompt_embeds (xiaohu2015, Jul 22, 2023)
d23ae77 Merge branch 'main' into patch-3 (xiaohu2015, Jul 24, 2023)
Files changed
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py (13 additions, 21 deletions)

```diff
@@ -304,11 +304,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -361,26 +356,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
```
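
A note on the pooled path kept at the end of the hunk: pooled embeddings are 2-D `(batch, dim)`, so the duplication uses `repeat(1, n)` followed by a `view`, which leaves each prompt's pooled vector repeated `n` times consecutively, matching the 3-D case. A quick sketch with illustrative sizes (not from the PR):

```python
import torch

n = 2                          # num_images_per_prompt
pooled = torch.randn(2, 1280)  # (bs_embed, dim); 1280 is illustrative

# repeat along dim 1, then reshape: row i becomes rows i*n .. i*n + n - 1
pooled = pooled.repeat(1, n).view(2 * n, -1)

assert pooled.shape == (4, 1280)
```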
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py (13 additions, 21 deletions)

```diff
@@ -319,11 +319,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -376,26 +371,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
```
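
The comments call `repeat` + `view` the "mps friendly method"; to my understanding this avoids `torch.repeat_interleave`, which was unreliable on Apple's MPS backend at the time, while producing the same row ordering. A self-contained check of that equivalence (my example, not from the PR):

```python
import torch

x = torch.arange(6.0).view(2, 3, 1)  # (bs, seq, dim)
n = 2

# mps-friendly duplication, as used throughout the diff
a = x.repeat(1, n, 1).view(2 * n, 3, -1)
# the more direct (but historically MPS-problematic) alternative
b = x.repeat_interleave(n, dim=0)

assert torch.equal(a, b)
```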
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py (13 additions, 21 deletions)

```diff
@@ -327,11 +327,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -384,26 +379,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
```
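
One more detail the hoisted block adds: the positive embeddings are now explicitly cast with `.to(dtype=self.text_encoder.dtype, device=device)` before duplication, so user-supplied `prompt_embeds` end up on the expected dtype and device even when they were created elsewhere (previously only the negative embeddings were cast, inside the loop). A tiny sketch, with the dtype assumed for illustration:

```python
import torch

target_dtype = torch.float16          # assume the pipeline's text encoder runs in fp16
prompt_embeds = torch.randn(1, 77, 2048)  # user-supplied fp32 embeds

prompt_embeds = prompt_embeds.to(dtype=target_dtype)  # device arg omitted so this runs anywhere

assert prompt_embeds.dtype == torch.float16
```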
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py (13 additions, 21 deletions)

```diff
@@ -433,11 +433,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -490,26 +485,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
```
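
Under classifier-free guidance, the hoisted block gives the negative embeddings the same treatment, sized by `batch_size * num_images_per_prompt` so the unconditional batch lines up with the conditional one when the two are concatenated for a single forward pass. A sketch with assumed shapes (not the PR's code):

```python
import torch

batch_size, num_images_per_prompt = 2, 2
negative_prompt_embeds = torch.randn(batch_size, 77, 2048)

# duplicate unconditional embeddings per generation, mirroring the diff
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

# conditional and unconditional batches now match for one concatenated CFG pass
assert negative_prompt_embeds.shape == (4, 77, 2048)
```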
tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py (38 additions, 0 deletions)

```diff
@@ -148,6 +148,44 @@ def test_stable_diffusion_xl_euler(self):
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_xl_prompt_embeds(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt"] = 2 * [inputs["prompt"]]
+        inputs["num_images_per_prompt"] = 2
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        prompt = 2 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = sd_pipe.encode_prompt(prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
     def test_stable_diffusion_xl_negative_prompt_embeds(self):
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**components)
```
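
For reference, a hedged usage sketch mirroring the new test: after this fix, precomputed embeddings passed to the pipeline are duplicated for `num_images_per_prompt` just like a plain text prompt. The checkpoint id is the usual SDXL base model and the call pattern is assumed from the diffusers API of this period, not taken from the PR itself:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Precompute the four SDXL embedding tensors once, then reuse them across calls.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt("an astronaut riding a horse")

# With the fix, num_images_per_prompt > 1 also works with precomputed embeds.
images = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    num_images_per_prompt=2,
).images
```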