
Commit c559479

yiyixuxu and yiyixuxu authored
Postprocessing refactor all others (#3337)
* add text2img
* fix-copies
* add
* add all other pipelines
* add
* add
* add
* add
* add
* make style
* style + fix copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent a757b2d commit c559479
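
This commit rolls the postprocessing refactor already applied to the text-to-image pipeline out to the remaining pipelines: instead of decode_latents / numpy_to_pil, each pipeline decodes the latents with the VAE and hands the result to VaeImageProcessor.postprocess. A minimal standalone sketch of that processor API (the toy tensor and vae_scale_factor=8 are assumptions for illustration; the pipelines compute the scale factor from the VAE config):

import torch

from diffusers.image_processor import VaeImageProcessor

# pipelines use 2 ** (len(vae.config.block_out_channels) - 1); 8 is assumed here
image_processor = VaeImageProcessor(vae_scale_factor=8)

decoded = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in for the VAE decoder output in [-1, 1]

pil_images = image_processor.postprocess(decoded, output_type="pil")  # list of PIL.Image, denormalized to [0, 1]
np_images = image_processor.postprocess(decoded, output_type="np")    # numpy array of shape (1, 512, 512, 3)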

File tree

48 files changed: +669 -302 lines changed


src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

+26 -17
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union

 import torch
@@ -22,6 +23,7 @@
 from diffusers.utils import is_accelerate_available, is_accelerate_version

 from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
@@ -174,6 +176,7 @@ def __init__(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     def enable_vae_slicing(self):
@@ -426,16 +429,27 @@ def _encode_prompt(
         return prompt_embeds

     def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept

     def decode_latents(self, latents):
+        warnings.warn(
+            (
+                "The decode_latents method is deprecated and will be removed in a future version. Please"
+                " use VaeImageProcessor instead"
+            ),
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -700,24 +714,19 @@ def __call__(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)

-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
             image = latents
             has_nsfw_concept = None
-        elif output_type == "pil":
-            # 8. Post-processing
-            image = self.decode_latents(latents)

-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
         else:
-            # 8. Post-processing
-            image = self.decode_latents(latents)
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
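
From the caller's side, all output types now go through the image processor. A usage sketch for the pipeline diffed above (the checkpoint id "BAAI/AltDiffusion" and the CUDA device are assumptions for illustration):

import torch

from diffusers import AltDiffusionPipeline

pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "an astronaut riding a horse on the moon"

pil_image = pipe(prompt).images[0]                   # decoded, safety-checked, returned as PIL
np_image = pipe(prompt, output_type="np").images[0]  # same path, returned as a numpy array
latents = pipe(prompt, output_type="latent").images  # raw latents; VAE decode and safety checker are skipped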

src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py

+27 -11
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Callable, List, Optional, Union

 import numpy as np
@@ -22,6 +23,7 @@

 from diffusers.utils import is_accelerate_available

+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import logging, randn_tensor
@@ -184,6 +186,7 @@ def __init__(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -226,13 +229,17 @@ def _execution_device(self):

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -255,6 +262,11 @@ def prepare_extra_step_kwargs(self, generator, eta):

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -560,15 +572,19 @@ def __call__(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)

-        # 11. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None

-        # 12. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

-        # 13. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

         if not return_dict:
             return (image, has_nsfw_concept)

src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py

+33 -11
@@ -1,10 +1,12 @@
 import inspect
+import warnings
 from itertools import repeat
 from typing import Callable, List, Optional, Union

 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -129,10 +131,31 @@ def __init__(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -681,20 +704,19 @@ def __call__(
                     callback(i, t, latents)

         # 8. Post-processing
-        image = self.decode_latents(latents)
-
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
-            )
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
         else:
+            image = latents
             has_nsfw_concept = None

-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

         if not return_dict:
             return (image, has_nsfw_concept)

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

+28 -10
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Callable, List, Optional, Union

 import numpy as np
@@ -24,6 +25,7 @@
 from diffusers.utils import is_accelerate_available, is_accelerate_version

 from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler
@@ -220,6 +222,8 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
@@ -504,17 +508,26 @@ def prepare_extra_step_kwargs(self, generator, eta):

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -770,14 +783,19 @@ def __call__(
                     callback(i, t, latents)

         # 9. Post-processing
-        image = self.decode_latents(latents)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None

-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

         if not return_dict:
             return (image, has_nsfw_concept)
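
A detail shared by all of the __call__ changes above: do_denormalize is derived per-image from the safety checker result, so images the checker has zeroed out are not shifted toward mid-gray by the [-1, 1] to [0, 1] denormalization. A toy illustration (the flag values are made up):

has_nsfw_concept = [False, True]  # in the pipelines this comes from run_safety_checker
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
print(do_denormalize)  # [True, False]: only the first image gets the (image / 2 + 0.5) shift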
