diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index b6a93672c733..c3f19ec8b452 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -304,11 +304,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -361,26 +356,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 6f0ecf4df8d6..525f6b421c30 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -319,11 +319,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -376,26 +371,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 302e9347beeb..455fc352ff16 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -327,11 +327,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -384,26 +379,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index bc55d1322839..6f3102986792 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -433,11 +433,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -490,26 +485,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 5b830d09de1d..0029d2b55cdd 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -148,6 +148,44 @@ def test_stable_diffusion_xl_euler(self):
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_xl_prompt_embeds(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt"] = 2 * [inputs["prompt"]]
+        inputs["num_images_per_prompt"] = 2
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        prompt = 2 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = sd_pipe.encode_prompt(prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
     def test_stable_diffusion_xl_negative_prompt_embeds(self):
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**components)
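Usage sketch: the round-trip exercised by the new test can be reproduced directly against a real checkpoint. This is a minimal sketch only; the checkpoint id `stabilityai/stable-diffusion-xl-base-1.0`, the fp16 dtype, and the CUDA device are illustrative assumptions, not values fixed by this diff.

```python
# Minimal sketch of the encode_prompt round-trip the new test exercises.
# Assumptions (not pinned down by this diff): an SDXL checkpoint id,
# fp16 weights, and a CUDA device.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

prompt = 2 * ["An astronaut riding a green horse"]

# After this refactor, encode_prompt applies the num_images_per_prompt
# duplication itself, after the two text encoders' embeddings have been
# concatenated, so the returned tensors can be fed straight back in.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt)

# Passing the precomputed embeddings should match passing the raw prompt.
images = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
).images
```

Previously the duplication happened inside the per-encoder loop before concatenation, so precomputed embeddings and raw prompts could disagree on batch layout; hoisting it to run once over the concatenated embeddings keeps both paths consistent.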