@@ -132,6 +132,7 @@ def __init__(
132
132
133
133
self .watermark = StableDiffusionXLWatermarker ()
134
134
135
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
135
136
def enable_vae_slicing (self ):
136
137
r"""
137
138
Enable sliced VAE decoding.
@@ -141,13 +142,15 @@ def enable_vae_slicing(self):
141
142
"""
142
143
self .vae .enable_slicing ()
143
144
145
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
144
146
def disable_vae_slicing(self):
    r"""
    Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
    computing decoding in one step.
    """
    self.vae.disable_slicing()
150
152
153
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
151
154
def enable_vae_tiling (self ):
152
155
r"""
153
156
Enable tiled VAE decoding.
@@ -157,6 +160,7 @@ def enable_vae_tiling(self):
157
160
"""
158
161
self .vae .enable_tiling ()
159
162
163
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
160
164
def disable_vae_tiling (self ):
161
165
r"""
162
166
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
@@ -217,6 +221,7 @@ def enable_model_cpu_offload(self, gpu_id=0):
217
221
self .final_offload_hook = hook
218
222
219
223
@property
224
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
220
225
def _execution_device (self ):
221
226
r"""
222
227
Returns the device on which the pipeline's models will be executed. After calling
@@ -237,12 +242,14 @@ def _execution_device(self):
237
242
def encode_prompt (
238
243
self ,
239
244
prompt ,
240
- device ,
241
- num_images_per_prompt ,
242
- do_classifier_free_guidance ,
245
+ device : Optional [ torch . device ] = None ,
246
+ num_images_per_prompt : int = 1 ,
247
+ do_classifier_free_guidance : bool = True ,
243
248
negative_prompt = None ,
244
249
prompt_embeds : Optional [torch .FloatTensor ] = None ,
245
250
negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
251
+ pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
252
+ negative_pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
246
253
lora_scale : Optional [float ] = None ,
247
254
):
248
255
r"""
@@ -268,9 +275,18 @@ def encode_prompt(
268
275
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
269
276
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
270
277
argument.
278
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
279
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
280
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
281
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
282
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
283
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
284
+ input argument.
271
285
lora_scale (`float`, *optional*):
272
286
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
273
287
"""
288
+ device = device or self ._execution_device
289
+
274
290
# set lora scale so that monkey patched LoRA
275
291
# function of text encoder can correctly access it
276
292
if lora_scale is not None and isinstance (self , LoraLoaderMixin ):
@@ -399,6 +415,7 @@ def encode_prompt(
399
415
400
416
negative_prompt_embeds = torch .concat (negative_prompt_embeds_list , dim = - 1 )
401
417
418
+ bs_embed = pooled_prompt_embeds .shape [0 ]
402
419
pooled_prompt_embeds = pooled_prompt_embeds .repeat (1 , num_images_per_prompt ).view (
403
420
bs_embed * num_images_per_prompt , - 1
404
421
)
@@ -408,20 +425,7 @@ def encode_prompt(
408
425
409
426
return prompt_embeds , negative_prompt_embeds , pooled_prompt_embeds , negative_pooled_prompt_embeds
410
427
411
- def run_safety_checker (self , image , device , dtype ):
412
- if self .safety_checker is None :
413
- has_nsfw_concept = None
414
- else :
415
- if torch .is_tensor (image ):
416
- feature_extractor_input = self .image_processor .postprocess (image , output_type = "pil" )
417
- else :
418
- feature_extractor_input = self .image_processor .numpy_to_pil (image )
419
- safety_checker_input = self .feature_extractor (feature_extractor_input , return_tensors = "pt" ).to (device )
420
- image , has_nsfw_concept = self .safety_checker (
421
- images = image , clip_input = safety_checker_input .pixel_values .to (dtype )
422
- )
423
- return image , has_nsfw_concept
424
-
428
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
425
429
def prepare_extra_step_kwargs (self , generator , eta ):
426
430
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
427
431
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -448,6 +452,8 @@ def check_inputs(
448
452
negative_prompt = None ,
449
453
prompt_embeds = None ,
450
454
negative_prompt_embeds = None ,
455
+ pooled_prompt_embeds = None ,
456
+ negative_pooled_prompt_embeds = None ,
451
457
):
452
458
if height % 8 != 0 or width % 8 != 0 :
453
459
raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
@@ -486,6 +492,17 @@ def check_inputs(
486
492
f" { negative_prompt_embeds .shape } ."
487
493
)
488
494
495
+ if prompt_embeds is not None and pooled_prompt_embeds is None :
496
+ raise ValueError (
497
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
498
+ )
499
+
500
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None :
501
+ raise ValueError (
502
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
503
+ )
504
+
505
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
489
506
def prepare_latents (self , batch_size , num_channels_latents , height , width , dtype , device , generator , latents = None ):
490
507
shape = (batch_size , num_channels_latents , height // self .vae_scale_factor , width // self .vae_scale_factor )
491
508
if isinstance (generator , list ) and len (generator ) != batch_size :
@@ -535,6 +552,8 @@ def __call__(
535
552
latents : Optional [torch .FloatTensor ] = None ,
536
553
prompt_embeds : Optional [torch .FloatTensor ] = None ,
537
554
negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
555
+ pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
556
+ negative_pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
538
557
output_type : Optional [str ] = "pil" ,
539
558
return_dict : bool = True ,
540
559
callback : Optional [Callable [[int , int , torch .FloatTensor ], None ]] = None ,
@@ -588,6 +607,13 @@ def __call__(
588
607
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
589
608
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
590
609
argument.
610
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
611
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
612
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
613
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
614
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
615
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
616
+ input argument.
591
617
output_type (`str`, *optional*, defaults to `"pil"`):
592
618
The output format of the generated image. Choose between
593
619
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -634,7 +660,15 @@ def __call__(
634
660
635
661
# 1. Check inputs. Raise error if not correct
636
662
self .check_inputs (
637
- prompt , height , width , callback_steps , negative_prompt , prompt_embeds , negative_prompt_embeds
663
+ prompt ,
664
+ height ,
665
+ width ,
666
+ callback_steps ,
667
+ negative_prompt ,
668
+ prompt_embeds ,
669
+ negative_prompt_embeds ,
670
+ pooled_prompt_embeds ,
671
+ negative_pooled_prompt_embeds ,
638
672
)
639
673
640
674
# 2. Define call parameters
@@ -669,6 +703,8 @@ def __call__(
669
703
negative_prompt ,
670
704
prompt_embeds = prompt_embeds ,
671
705
negative_prompt_embeds = negative_prompt_embeds ,
706
+ pooled_prompt_embeds = pooled_prompt_embeds ,
707
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds ,
672
708
lora_scale = text_encoder_lora_scale ,
673
709
)
674
710
@@ -765,27 +801,19 @@ def __call__(
765
801
latents = latents .float ()
766
802
767
803
if not output_type == "latent" :
768
- # CHECK there is problem here (PVP)
769
804
image = self .vae .decode (latents / self .vae .config .scaling_factor , return_dict = False )[0 ]
770
- has_nsfw_concept = None
771
805
else :
772
806
image = latents
773
- has_nsfw_concept = None
774
- return StableDiffusionXLPipelineOutput (images = image , nsfw_content_detected = None )
775
-
776
- if has_nsfw_concept is None :
777
- do_denormalize = [True ] * image .shape [0 ]
778
- else :
779
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept ]
807
+ return StableDiffusionXLPipelineOutput (images = image )
780
808
781
809
image = self .watermark .apply_watermark (image )
782
- image = self .image_processor .postprocess (image , output_type = output_type , do_denormalize = do_denormalize )
810
+ image = self .image_processor .postprocess (image , output_type = output_type )
783
811
784
812
# Offload last model to CPU
785
813
if hasattr (self , "final_offload_hook" ) and self .final_offload_hook is not None :
786
814
self .final_offload_hook .offload ()
787
815
788
816
if not return_dict :
789
- return (image , has_nsfw_concept )
817
+ return (image ,)
790
818
791
- return StableDiffusionXLPipelineOutput (images = image , nsfw_content_detected = has_nsfw_concept )
819
+ return StableDiffusionXLPipelineOutput (images = image )
0 commit comments