|
27 | 27 |
|
28 | 28 | from ...models import UNet2DConditionModel
|
29 | 29 | from ...pipelines import DiffusionPipeline
|
| 30 | +from ...pipelines.pipeline_utils import ImagePipelineOutput |
30 | 31 | from ...schedulers import UnCLIPScheduler
|
31 | 32 | from ...utils import (
|
32 | 33 | is_accelerate_available,
|
@@ -382,11 +383,11 @@ def __call__(
|
382 | 383 | #num_channels_latents = self.image_encoder.config.z_channels
|
383 | 384 |
|
384 | 385 | # get h, w for latents
|
385 |
| - height, width = get_new_h_w(height, width) |
| 386 | + sample_height, sample_width = get_new_h_w(height, width) |
386 | 387 |
|
387 | 388 | # create initial latent
|
388 | 389 | latents = self.prepare_latents(
|
389 |
| - (batch_size, num_channels_latents, height, width), |
| 390 | + (batch_size, num_channels_latents, sample_height, sample_width), |
390 | 391 | text_encoder_hidden_states.dtype,
|
391 | 392 | device,
|
392 | 393 | generator,
|
@@ -448,4 +449,17 @@ def __call__(
|
448 | 449 |
|
449 | 450 | _, latents = latents.chunk(2)
|
450 | 451 |
|
451 |
| - return latents |
| 452 | + # post-processing |
| 453 | + image = self.image_encoder.decode(latents) |
| 454 | + |
| 455 | + image = image * 0.5 + 0.5 |
| 456 | + image = image.clamp(0, 1) |
| 457 | + image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
| 458 | + |
| 459 | + if output_type == "pil": |
| 460 | + image = self.numpy_to_pil(image) |
| 461 | + |
| 462 | + if not return_dict: |
| 463 | + return (image,) |
| 464 | + |
| 465 | + return ImagePipelineOutput(images=image) |
0 commit comments