Skip to content

Commit 78922ed

Browse files
Add sdxl prompt embeddings (#3995)
* Add sdxl prompt embeddings
* Fix more
* Fix some slow tests
1 parent 6fde5a6 commit 78922ed

File tree

5 files changed: 180 additions and 69 deletions.

src/diffusers/pipelines/stable_diffusion_xl/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88

99

1010
@dataclass
11-
# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with StableDiffusion->StableDiffusionXL
1211
class StableDiffusionXLPipelineOutput(BaseOutput):
1312
"""
1413
Output class for Stable Diffusion pipelines.
@@ -17,13 +16,9 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
1716
images (`List[PIL.Image.Image]` or `np.ndarray`)
1817
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
1918
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
20-
nsfw_content_detected (`List[bool]`)
21-
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
22-
(nsfw) content, or `None` if safety checking could not be performed.
2319
"""
2420

2521
images: Union[List[PIL.Image.Image], np.ndarray]
26-
nsfw_content_detected: Optional[List[bool]]
2722

2823

2924
if is_transformers_available() and is_torch_available() and is_invisible_watermark_available():

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 58 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,7 @@ def __init__(
132132

133133
self.watermark = StableDiffusionXLWatermarker()
134134

135+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
135136
def enable_vae_slicing(self):
136137
r"""
137138
Enable sliced VAE decoding.
@@ -141,13 +142,15 @@ def enable_vae_slicing(self):
141142
"""
142143
self.vae.enable_slicing()
143144

145+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
144146
def disable_vae_slicing(self):
145147
r"""
146148
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
147149
computing decoding in one step.
148150
"""
149151
self.vae.disable_slicing()
150152

153+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
151154
def enable_vae_tiling(self):
152155
r"""
153156
Enable tiled VAE decoding.
@@ -157,6 +160,7 @@ def enable_vae_tiling(self):
157160
"""
158161
self.vae.enable_tiling()
159162

163+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
160164
def disable_vae_tiling(self):
161165
r"""
162166
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
@@ -217,6 +221,7 @@ def enable_model_cpu_offload(self, gpu_id=0):
217221
self.final_offload_hook = hook
218222

219223
@property
224+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
220225
def _execution_device(self):
221226
r"""
222227
Returns the device on which the pipeline's models will be executed. After calling
@@ -237,12 +242,14 @@ def _execution_device(self):
237242
def encode_prompt(
238243
self,
239244
prompt,
240-
device,
241-
num_images_per_prompt,
242-
do_classifier_free_guidance,
245+
device: Optional[torch.device] = None,
246+
num_images_per_prompt: int = 1,
247+
do_classifier_free_guidance: bool = True,
243248
negative_prompt=None,
244249
prompt_embeds: Optional[torch.FloatTensor] = None,
245250
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
251+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
252+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
246253
lora_scale: Optional[float] = None,
247254
):
248255
r"""
@@ -268,9 +275,18 @@ def encode_prompt(
268275
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
269276
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
270277
argument.
278+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
279+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
280+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
281+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
282+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
283+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
284+
input argument.
271285
lora_scale (`float`, *optional*):
272286
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
273287
"""
288+
device = device or self._execution_device
289+
274290
# set lora scale so that monkey patched LoRA
275291
# function of text encoder can correctly access it
276292
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
@@ -399,6 +415,7 @@ def encode_prompt(
399415

400416
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
401417

418+
bs_embed = pooled_prompt_embeds.shape[0]
402419
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
403420
bs_embed * num_images_per_prompt, -1
404421
)
@@ -408,20 +425,7 @@ def encode_prompt(
408425

409426
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
410427

411-
def run_safety_checker(self, image, device, dtype):
412-
if self.safety_checker is None:
413-
has_nsfw_concept = None
414-
else:
415-
if torch.is_tensor(image):
416-
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
417-
else:
418-
feature_extractor_input = self.image_processor.numpy_to_pil(image)
419-
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
420-
image, has_nsfw_concept = self.safety_checker(
421-
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
422-
)
423-
return image, has_nsfw_concept
424-
428+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
425429
def prepare_extra_step_kwargs(self, generator, eta):
426430
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
427431
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -448,6 +452,8 @@ def check_inputs(
448452
negative_prompt=None,
449453
prompt_embeds=None,
450454
negative_prompt_embeds=None,
455+
pooled_prompt_embeds=None,
456+
negative_pooled_prompt_embeds=None,
451457
):
452458
if height % 8 != 0 or width % 8 != 0:
453459
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -486,6 +492,17 @@ def check_inputs(
486492
f" {negative_prompt_embeds.shape}."
487493
)
488494

495+
if prompt_embeds is not None and pooled_prompt_embeds is None:
496+
raise ValueError(
497+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
498+
)
499+
500+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
501+
raise ValueError(
502+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
503+
)
504+
505+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
489506
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
490507
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
491508
if isinstance(generator, list) and len(generator) != batch_size:
@@ -535,6 +552,8 @@ def __call__(
535552
latents: Optional[torch.FloatTensor] = None,
536553
prompt_embeds: Optional[torch.FloatTensor] = None,
537554
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
555+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
556+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
538557
output_type: Optional[str] = "pil",
539558
return_dict: bool = True,
540559
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -588,6 +607,13 @@ def __call__(
588607
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
589608
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
590609
argument.
610+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
611+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
612+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
613+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
614+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
615+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
616+
input argument.
591617
output_type (`str`, *optional*, defaults to `"pil"`):
592618
The output format of the generate image. Choose between
593619
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -634,7 +660,15 @@ def __call__(
634660

635661
# 1. Check inputs. Raise error if not correct
636662
self.check_inputs(
637-
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
663+
prompt,
664+
height,
665+
width,
666+
callback_steps,
667+
negative_prompt,
668+
prompt_embeds,
669+
negative_prompt_embeds,
670+
pooled_prompt_embeds,
671+
negative_pooled_prompt_embeds,
638672
)
639673

640674
# 2. Define call parameters
@@ -669,6 +703,8 @@ def __call__(
669703
negative_prompt,
670704
prompt_embeds=prompt_embeds,
671705
negative_prompt_embeds=negative_prompt_embeds,
706+
pooled_prompt_embeds=pooled_prompt_embeds,
707+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
672708
lora_scale=text_encoder_lora_scale,
673709
)
674710

@@ -765,27 +801,19 @@ def __call__(
765801
latents = latents.float()
766802

767803
if not output_type == "latent":
768-
# CHECK there is problem here (PVP)
769804
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
770-
has_nsfw_concept = None
771805
else:
772806
image = latents
773-
has_nsfw_concept = None
774-
return StableDiffusionXLPipelineOutput(images=image, nsfw_content_detected=None)
775-
776-
if has_nsfw_concept is None:
777-
do_denormalize = [True] * image.shape[0]
778-
else:
779-
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
807+
return StableDiffusionXLPipelineOutput(images=image)
780808

781809
image = self.watermark.apply_watermark(image)
782-
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
810+
image = self.image_processor.postprocess(image, output_type=output_type)
783811

784812
# Offload last model to CPU
785813
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
786814
self.final_offload_hook.offload()
787815

788816
if not return_dict:
789-
return (image, has_nsfw_concept)
817+
return (image,)
790818

791-
return StableDiffusionXLPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
819+
return StableDiffusionXLPipelineOutput(images=image)

0 commit comments

Comments
 (0)