Skip to content

Commit 4aee722

Browse files
Add semantic guidance pipeline (huggingface#2223)
* Add semantic guidance pipeline * Fix style * Refactor Pipeline * Pipeline documentation * Add documentation * Fix style and quality * Fix doctree * Add tests for SEGA * Update src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py Co-authored-by: Patrick von Platen <[email protected]> * Make compatible with half precision * Change deprecation warning to throw an exception * update --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent b1edd07 commit 4aee722

File tree

5 files changed

+750
-0
lines changed

5 files changed

+750
-0
lines changed

__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@
110110
CycleDiffusionPipeline,
111111
LDMTextToImagePipeline,
112112
PaintByExamplePipeline,
113+
SemanticStableDiffusionPipeline,
113114
StableDiffusionAttendAndExcitePipeline,
114115
StableDiffusionDepth2ImgPipeline,
115116
StableDiffusionImageVariationPipeline,

pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
4545
from .latent_diffusion import LDMTextToImagePipeline
4646
from .paint_by_example import PaintByExamplePipeline
47+
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
4748
from .stable_diffusion import (
4849
CycleDiffusionPipeline,
4950
StableDiffusionAttendAndExcitePipeline,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from dataclasses import dataclass
2+
from enum import Enum
3+
from typing import List, Optional, Union
4+
5+
import numpy as np
6+
import PIL
7+
from PIL import Image
8+
9+
from ...utils import BaseOutput, is_torch_available, is_transformers_available
10+
11+
12+
@dataclass
13+
class SemanticStableDiffusionPipelineOutput(BaseOutput):
14+
"""
15+
Output class for Stable Diffusion pipelines.
16+
17+
Args:
18+
images (`List[PIL.Image.Image]` or `np.ndarray`)
19+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
20+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
21+
nsfw_content_detected (`List[bool]`)
22+
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
23+
(nsfw) content, or `None` if safety checking could not be performed.
24+
"""
25+
26+
images: Union[List[PIL.Image.Image], np.ndarray]
27+
nsfw_content_detected: Optional[List[bool]]
28+
29+
30+
if is_transformers_available() and is_torch_available():
31+
from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline

0 commit comments

Comments
 (0)