From 4cfd6442ac5533b30dfa3369fc9b81d29323e38d Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 14 Jul 2023 20:33:59 +0800
Subject: [PATCH 01/18] fix bug in sdxl

---
 .../pipeline_stable_diffusion_xl.py           | 35 +++++++++----------
 1 file changed, 16 insertions(+), 19 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index b3dcf1b67cda..97c66df76d4e 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -341,15 +341,16 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]

-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)

             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
         # get unconditional embeddings for classifier free guidance
         zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
         if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
@@ -397,24 +398,20 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+                negative_prompt_embeds_list.append(negative_prompt_embeds)

-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
+            if do_classifier_free_guidance:
+                # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+                seq_len = negative_prompt_embeds.shape[1]

-                negative_prompt_embeds_list.append(negative_prompt_embeds)
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)

-            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+                negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                negative_prompt_embeds = negative_prompt_embeds.view(
+                    batch_size * num_images_per_prompt, seq_len, -1
+                )

         bs_embed = pooled_prompt_embeds.shape[0]
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(

From cb902a7d353c8b9f89409208e0a00b98699e6c7c Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 14 Jul 2023 20:36:32 +0800
Subject: [PATCH 02/18] Update pipeline_stable_diffusion_xl_img2img.py

---
 .../pipeline_stable_diffusion_xl_img2img.py   | 37 +++++++++----------
 1 file changed, 17 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 7b0cdfad8c0a..3a9fa32f1b71 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -352,15 +352,16 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]

-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)

             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
         # get unconditional embeddings for classifier free guidance
         zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
         if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
@@ -408,24 +409,20 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)

             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+            if do_classifier_free_guidance:
+                # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+                seq_len = negative_prompt_embeds.shape[1]
+
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+                negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                negative_prompt_embeds = negative_prompt_embeds.view(
+                    batch_size * num_images_per_prompt, seq_len, -1
+                )

         bs_embed = pooled_prompt_embeds.shape[0]
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
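These first two patches fix the same ordering bug in the txt2img and img2img pipelines: the `num_images_per_prompt` duplication used to run inside the per-tokenizer loop, and only when the pipeline computed the embeddings itself, so caller-supplied `prompt_embeds` were never duplicated and the two encoders' outputs were repeated before the feature-dim concat. The fix resolves the final embeddings first, then duplicates once. A minimal sketch of the intended order, with made-up shapes standing in for SDXL's two text encoders:

    import torch

    bs, seq, n = 2, 77, 2
    # hypothetical per-encoder hidden states (widths chosen to mimic SDXL's two CLIP encoders)
    embeds_per_encoder = [torch.randn(bs, seq, 768), torch.randn(bs, seq, 1280)]

    # 1) resolve the final embeddings (or accept them from the caller as-is)
    prompt_embeds = torch.concat(embeds_per_encoder, dim=-1)  # (2, 77, 2048)

    # 2) duplicate once for num_images_per_prompt, mps-friendly:
    #    repeat along the sequence dim, then fold the copies into the batch dim
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, n, 1).view(bs_embed * n, seq_len, -1)
    assert prompt_embeds.shape == (bs * n, seq, 2048)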
From af1b559930f02115e77d24dc83cad6024b017ce0 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 14 Jul 2023 21:23:14 +0800
Subject: [PATCH 03/18] Update pipeline_stable_diffusion_xl.py

---
 .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py         | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 97c66df76d4e..e705d1fa6652 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -406,7 +406,7 @@ def encode_prompt(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(

From 408e22119c4ba71bdfa2711d0174962603dc06c5 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 14 Jul 2023 21:23:42 +0800
Subject: [PATCH 04/18] Update pipeline_stable_diffusion_xl_img2img.py

---
 .../stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 3a9fa32f1b71..3e2520fe5778 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -417,7 +417,7 @@ def encode_prompt(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(

From 2b228b3f6f5c5d30cda7bba62b3f96dfcb900f1a Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Wed, 19 Jul 2023 22:09:07 +0800
Subject: [PATCH 05/18] Update pipeline_stable_diffusion_xl_inpaint.py

---
 .../pipeline_stable_diffusion_xl_inpaint.py   | 35 +++++++++----------
 1 file changed, 16 insertions(+), 19 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 55f20660afc7..4282028ced03 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -424,15 +424,16 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]

-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)

             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
         # get unconditional embeddings for classifier free guidance
         zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
         if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
@@ -480,24 +481,20 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+                negative_prompt_embeds_list.append(negative_prompt_embeds)

-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
+            if do_classifier_free_guidance:
+                # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+                seq_len = negative_prompt_embeds.shape[1]

-                negative_prompt_embeds_list.append(negative_prompt_embeds)
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

-            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+                negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                negative_prompt_embeds = negative_prompt_embeds.view(
+                    batch_size * num_images_per_prompt, seq_len, -1
+                )

         bs_embed = pooled_prompt_embeds.shape[0]
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(

From 03675066d4ddcf31b41837bf55ac4c153bf6e32c Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 21 Jul 2023 17:29:38 +0800
Subject: [PATCH 06/18] Update pipeline_stable_diffusion_xl.py

---
 .../pipeline_stable_diffusion_xl.py           | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 67cdd803d819..e104700622cc 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -314,12 +314,6 @@ def encode_prompt(
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
         # get unconditional embeddings for classifier free guidance
         zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
         if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
@@ -371,16 +365,18 @@ def encode_prompt(

         negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(
-                batch_size * num_images_per_prompt, seq_len, -1
-            )
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

         bs_embed = pooled_prompt_embeds.shape[0]
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
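Patches 03 and 04 swap the out-of-scope loop variable `text_encoder` for `self.text_encoder` now that the dtype cast runs outside the encoder loop; patch 05 ports the reordering to the inpaint pipeline, and patch 06 hoists the duplication below both `torch.concat` calls. The "mps friendly method" named in the comments can be checked against `repeat_interleave` directly; a small sketch with arbitrary sizes:

    import torch

    x = torch.randn(2, 77, 32)  # (batch, seq, dim), sizes arbitrary
    n = 3

    # repeat_interleave on the batch dim has been problematic on the mps
    # backend, so the pipelines emulate it with repeat + view
    a = x.repeat_interleave(n, dim=0)
    b = x.repeat(1, n, 1).view(x.shape[0] * n, x.shape[1], -1)
    assert torch.equal(a, b)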
From 17a40d04ce538efce6842fd1ef953e0a4868cb16 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 21 Jul 2023 18:52:25 +0800
Subject: [PATCH 07/18] Update pipeline_stable_diffusion_xl_img2img.py

---
 .../pipeline_stable_diffusion_xl_img2img.py   | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index d4083f8c5442..3cccd69e4367 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -322,12 +322,6 @@ def encode_prompt(
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
         # get unconditional embeddings for classifier free guidance
         zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
         if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
@@ -379,16 +373,18 @@ def encode_prompt(

         negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(
-                batch_size * num_images_per_prompt, seq_len, -1
-            )
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

         bs_embed = pooled_prompt_embeds.shape[0]
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(

From dfc80722376162f12e74fab6aaf7337885b813f7 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 21 Jul 2023 18:52:59 +0800
Subject: [PATCH 08/18] Update pipeline_stable_diffusion_xl_inpaint.py

---
 .../pipeline_stable_diffusion_xl_inpaint.py   | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 4282028ced03..e541e4a5eece 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -428,12 +428,6 @@ def encode_prompt(
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
         # get unconditional embeddings for classifier free guidance
         zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
         if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
@@ -485,16 +479,18 @@ def encode_prompt(

         negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(
-                batch_size * num_images_per_prompt, seq_len, -1
-            )
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

         bs_embed = pooled_prompt_embeds.shape[0]
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(

From a124687ce1177be4ce539c53de67ce62eeb18a97 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 21 Jul 2023 18:55:11 +0800
Subject: [PATCH 09/18] Update pipeline_stable_diffusion_xl_img2img.py

---
 .../stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 3cccd69e4367..103177401887 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -372,7 +372,7 @@ def encode_prompt(
         negative_prompt_embeds_list.append(negative_prompt_embeds)

         negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
-
+
         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
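Patches 07 and 08 apply the same hoisting to the img2img and inpaint pipelines, and patch 09 only strips trailing whitespace left behind. The pooled embeddings visible in the surrounding context get the same treatment minus the sequence axis, since they are 2D; a sketch with placeholder sizes:

    import torch

    pooled = torch.randn(2, 1280)  # (batch, dim), sizes are placeholders
    n = 2
    bs_embed = pooled.shape[0]  # read BEFORE duplicating, as in the patches

    # 2D analogue of the repeat + view trick used for the sequence embeddings
    pooled = pooled.repeat(1, n).view(bs_embed * n, -1)
    assert pooled.shape == (4, 1280)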
From 49c690b40a5afc2db4418d241f92bde7b6b73a34 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 21 Jul 2023 19:12:32 +0800
Subject: [PATCH 10/18] Update pipeline_controlnet_sd_xl.py

---
 .../controlnet/pipeline_controlnet_sd_xl.py   | 32 +++++++------------
 1 file changed, 12 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index fd4338946e6f..392080b09644 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -295,11 +295,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]

-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)

             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -351,25 +346,22 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)

             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
         bs_embed = pooled_prompt_embeds.shape[0]
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )

From fdd3488f6265f58b35b5e1e8f504999209f7e326 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 21 Jul 2023 19:33:07 +0800
Subject: [PATCH 11/18] Update pipeline_controlnet_sd_xl.py

---
 src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 392080b09644..1ac505cb0ba4 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -362,7 +362,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        bs_embed = pooled_prompt_embeds.shape[0]
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )

From fa514a8cd9cb495ac2e776bb9fef4fa2463d78fb Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 21 Jul 2023 19:33:31 +0800
Subject: [PATCH 12/18] Update pipeline_stable_diffusion_xl.py

---
 .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py         | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index e104700622cc..bc0ef4e1af54 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -378,7 +378,6 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        bs_embed = pooled_prompt_embeds.shape[0]
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )

From c1889a8d162f481ba04037599419b266b17cc20a Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 21 Jul 2023 19:34:00 +0800
Subject: [PATCH 13/18] Update pipeline_stable_diffusion_xl_img2img.py

---
 .../stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 103177401887..a9a541ab585f 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -386,7 +386,6 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        bs_embed = pooled_prompt_embeds.shape[0]
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )

From 90660206724db44aa32110ec6aa7674b3e8ca1ec Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 21 Jul 2023 19:34:22 +0800
Subject: [PATCH 14/18] Update pipeline_stable_diffusion_xl_inpaint.py

---
 .../stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index e541e4a5eece..2c4ccc7a7817 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -492,7 +492,6 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        bs_embed = pooled_prompt_embeds.shape[0]
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
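Patches 10 and 11 carry the fix over to the SDXL ControlNet pipeline, and patches 12-14 drop the now-redundant re-read of `bs_embed` from `pooled_prompt_embeds`, since `bs_embed` already holds the pre-duplication batch size. Downstream, the comments deleted earlier still describe what happens: for classifier-free guidance the unconditional and text embeddings are concatenated into a single batch so the UNet runs one forward pass per denoising step instead of two. A sketch with placeholder tensors:

    import torch

    # placeholder embeddings: batch of 4 prompts, SDXL-like widths
    prompt_embeds = torch.randn(4, 77, 2048)
    negative_prompt_embeds = torch.zeros_like(prompt_embeds)

    # one batch, one UNet call per step instead of two
    model_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    assert model_input.shape[0] == 2 * prompt_embeds.shape[0]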
From c8afd875ae17d6e6a9acecca376fa1505919808c Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 21 Jul 2023 21:38:55 +0800
Subject: [PATCH 15/18] Update test_stable_diffusion_xl.py

---
 tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 947d57a7bee3..a1f986b9d2aa 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -160,6 +160,7 @@ def test_stable_diffusion_xl_negative_prompt_embeds(self):
         negative_prompt = 3 * ["this is a negative prompt"]
         inputs["negative_prompt"] = negative_prompt
         inputs["prompt"] = 3 * [inputs["prompt"]]
+        inputs["num_images_per_prompt"] = 2

         output = sd_pipe(**inputs)
         image_slice_1 = output.images[0, -3:, -3:, -1]

From 36884fdd1fada8f18b60d9ff59213b3971f5dbba Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Sat, 22 Jul 2023 15:06:52 +0800
Subject: [PATCH 16/18] Update test_stable_diffusion_xl.py

---
 .../test_stable_diffusion_xl.py               | 39 ++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index e9f0ff7f59b1..c09e0e5e5458 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -148,6 +148,44 @@ def test_stable_diffusion_xl_euler(self):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_xl_prompt_embeds(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt"] = 2 * [inputs["prompt"]]
+        inputs["num_images_per_prompt"] = 2
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        prompt = 3 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = sd_pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
     def test_stable_diffusion_xl_negative_prompt_embeds(self):
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**components)
@@ -160,7 +198,6 @@ def test_stable_diffusion_xl_negative_prompt_embeds(self):
         negative_prompt = 3 * ["this is a negative prompt"]
         inputs["negative_prompt"] = negative_prompt
         inputs["prompt"] = 3 * [inputs["prompt"]]
-        inputs["num_images_per_prompt"] = 2

         output = sd_pipe(**inputs)
         image_slice_1 = output.images[0, -3:, -3:, -1]

From 26d0501984a3b2a258ac55aff1b5fe2018c29cc5 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Sat, 22 Jul 2023 15:08:57 +0800
Subject: [PATCH 17/18] Update test_stable_diffusion_xl.py

add test on prompt_embeds
---
 tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index c09e0e5e5458..b1b7582cf592 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -165,7 +165,7 @@ def test_stable_diffusion_xl_prompt_embeds(self):

         # forward with prompt embeds
         inputs = self.get_dummy_inputs(torch_device)
-        prompt = 3 * [inputs.pop("prompt")]
+        prompt = 2 * [inputs.pop("prompt")]

         (
             prompt_embeds,

From 94a23d11726b2cd4c205dfdf670c15c4f9800f00 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Sat, 22 Jul 2023 15:31:22 +0800
Subject: [PATCH 18/18] add test on prompt_embeds

---
 tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index b1b7582cf592..0029d2b55cdd 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -172,7 +172,7 @@ def test_stable_diffusion_xl_prompt_embeds(self):
             negative_prompt_embeds,
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
-        ) = sd_pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
+        ) = sd_pipe.encode_prompt(prompt)

         output = sd_pipe(
             **inputs,