Skip to content

Commit 0272da9

Browse files
author
yiyixuxu
committed
add post-processing to inpaint
1 parent 28eb816 commit 0272da9

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from ...models import UNet2DConditionModel
2929
from ...pipelines import DiffusionPipeline
30+
from ...pipelines.pipeline_utils import ImagePipelineOutput
3031
from ...schedulers import UnCLIPScheduler
3132
from ...utils import (
3233
is_accelerate_available,
@@ -382,11 +383,11 @@ def __call__(
382383
#num_channels_latents = self.image_encoder.config.z_channels
383384

384385
# get h, w for latents
385-
height, width = get_new_h_w(height, width)
386+
sample_height, sample_width = get_new_h_w(height, width)
386387

387388
# create initial latent
388389
latents = self.prepare_latents(
389-
(batch_size, num_channels_latents, height, width),
390+
(batch_size, num_channels_latents, sample_height, sample_width),
390391
text_encoder_hidden_states.dtype,
391392
device,
392393
generator,
@@ -448,4 +449,17 @@ def __call__(
448449

449450
_, latents = latents.chunk(2)
450451

451-
return latents
452+
# post-processing
453+
image = self.image_encoder.decode(latents)
454+
455+
image = image * 0.5 + 0.5
456+
image = image.clamp(0, 1)
457+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
458+
459+
if output_type == "pil":
460+
image = self.numpy_to_pil(image)
461+
462+
if not return_dict:
463+
return (image,)
464+
465+
return ImagePipelineOutput(images=image)

0 commit comments

Comments
 (0)