Commit 01a8080

Add semantic guidance pipeline (#2223)
* Add semantic guidance pipeline
* Fix style
* Refactor Pipeline
* Pipeline documentation
* Add documentation
* Fix style and quality
* Fix doctree
* Add tests for SEGA
* Update src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Make compatible with half precision
* Change deprecation warning to throw an exception
* update

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 291ecda commit 01a8080

11 files changed: +1435 -2 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -134,6 +134,8 @@
134134
title: Safe Stable Diffusion
135135
- local: api/pipelines/score_sde_ve
136136
title: Score SDE VE
137+
- local: api/pipelines/semantic_stable_diffusion
138+
title: Semantic Guidance
137139
- sections:
138140
- local: api/pipelines/stable_diffusion/overview
139141
title: Overview
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
1+
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
13+
# Semantic Guidance
14+
15+
Semantic Guidance for Diffusion Models was proposed in [SEGA: Instructing Diffusion using Semantic Dimensions](https://arxiv.org/abs/2301.12247) and provides strong semantic control over image generation.
16+
Small changes to the text prompt usually result in entirely different output images. With SEGA, however, a variety of image edits can be controlled easily and intuitively while staying true to the original image composition.
17+
18+
The abstract of the paper is the following:
19+
20+
*Text-to-image diffusion models have recently received a lot of interest for their astonishing ability to produce high-fidelity images from text only. However, achieving one-shot generation that aligns with the user's intent is nearly impossible, yet small changes to the input prompt often result in very different images. This leaves the user with little semantic control. To put the user in control, we show how to interact with the diffusion process to flexibly steer it along semantic directions. This semantic guidance (SEGA) allows for subtle and extensive edits, changes in composition and style, as well as optimizing the overall artistic conception. We demonstrate SEGA's effectiveness on a variety of tasks and provide evidence for its versatility and flexibility.*
21+
22+
23+
*Overview*:
24+
25+
| Pipeline | Tasks | Colab | Demo |
26+
|---|---|:---:|:---:|
27+
| [pipeline_semantic_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/semantic-image-editing/blob/main/examples/SemanticGuidance.ipynb) | [Coming Soon](https://huggingface.co/AIML-TUDA) |
28+
29+
## Tips
30+
31+
- The Semantic Guidance pipeline can be used with any [Stable Diffusion](./api/pipelines/stable_diffusion/text2img) checkpoint.
32+
33+
### Run Semantic Guidance
34+
35+
The interface of [`SemanticStableDiffusionPipeline`] provides several additional parameters to influence the image generation.
36+
Example usage may look like this:
37+
38+
```python
39+
import torch
40+
from diffusers import SemanticStableDiffusionPipeline
41+
42+
pipe = SemanticStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
43+
pipe = pipe.to("cuda")
44+
45+
out = pipe(
46+
    prompt="a photo of the face of a woman",
47+
    num_images_per_prompt=1,
48+
    guidance_scale=7,
49+
    editing_prompt=[
50+
        "smiling, smile",  # Concepts to apply
51+
        "glasses, wearing glasses",
52+
        "curls, wavy hair, curly hair",
53+
        "beard, full beard, mustache",
54+
    ],
55+
    reverse_editing_direction=[False, False, False, False],  # Direction of guidance, i.e. increase all concepts
56+
    edit_warmup_steps=[10, 10, 10, 10],  # Warmup period for each concept
57+
    edit_guidance_scale=[4, 5, 5, 5.4],  # Guidance scale for each concept
58+
    edit_threshold=[
59+
        0.99,
60+
        0.975,
61+
        0.925,
62+
        0.96,
63+
    ],  # Threshold for each concept: the percentile of the latent space that is discarded, e.g. threshold=0.99 uses 1% of the latent dimensions
64+
    edit_momentum_scale=0.3,  # Momentum scale that will be added to the latent guidance
65+
    edit_mom_beta=0.6,  # Momentum beta
66+
    edit_weights=[1, 1, 1, 1],  # Weights of the individual concepts against each other, one per editing prompt
67+
)
68+
```
69+
70+
For more examples, check the Colab notebook.
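
To make the per-concept comments in the example more concrete, here is a minimal pure-Python sketch. It is not the pipeline's implementation. It assumes that `edit_threshold` zeroes every guidance entry whose magnitude falls below the given quantile, and that `edit_momentum_scale` and `edit_mom_beta` drive an exponential moving average of past guidance. The helpers `keep_above_quantile` and `momentum_step` are hypothetical names introduced only for illustration.

```python
from typing import List, Tuple


def keep_above_quantile(values: List[float], threshold: float) -> List[float]:
    """Zero every entry whose magnitude is below the `threshold` quantile (hypothetical helper)."""
    magnitudes = sorted(abs(v) for v in values)
    # Index of the cut-off magnitude; clamp so that threshold=1.0 stays in range.
    cut_index = min(int(threshold * len(magnitudes)), len(magnitudes) - 1)
    cutoff = magnitudes[cut_index]
    return [v if abs(v) >= cutoff else 0.0 for v in values]


def momentum_step(momentum: float, guidance: float, scale: float, beta: float) -> Tuple[float, float]:
    """One scalar momentum update, assuming an exponential moving average (hypothetical helper)."""
    # Add the scaled running momentum to the current guidance term.
    boosted = guidance + scale * momentum
    # Blend the previous momentum with the boosted guidance.
    new_momentum = beta * momentum + (1.0 - beta) * boosted
    return boosted, new_momentum


# 100 toy guidance values standing in for the latent dimensions.
guidance = [float(i) for i in range(-50, 50)]
sparse = keep_above_quantile(guidance, threshold=0.99)
kept = [v for v in sparse if v != 0.0]

# One momentum update starting from zero momentum.
boosted, new_momentum = momentum_step(momentum=0.0, guidance=1.0, scale=0.3, beta=0.6)
```

With 100 toy values and `threshold=0.99`, a single entry survives, which mirrors the "uses 1% of the latent dimensions" comment in the example. The real pipeline operates on tensors, so treat this only as an illustration of the idea.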
71+
72+
## SemanticStableDiffusionPipelineOutput
73+
[[autodoc]] pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput
74+
- all
75+
76+
## SemanticStableDiffusionPipeline
77+
[[autodoc]] SemanticStableDiffusionPipeline
78+
- all
79+
- __call__

docs/source/en/api/pipelines/stable_diffusion_safe.mdx

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ The abstract of the paper is the following:
2424

2525
| Pipeline | Tasks | Colab | Demo
2626
|---|---|:---:|:---:|
27-
| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | -
27+
| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | [![Huggingface Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/AIML-TUDA/unsafe-vs-safe-stable-diffusion)
2828

2929
## Tips
3030

@@ -58,7 +58,7 @@ You may use the 4 configurations defined in the [Safe Latent Diffusion paper](ht
5858
>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
5959
```
6060

61-
The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONg`, and `SafetyConfig.MAX`.
61+
The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONG`, and `SafetyConfig.MAX`.
6262

6363
### How to load and use different schedulers.
6464

docs/source/en/index.mdx

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ available a colab notebook to directly try them out.
4747
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
4848
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
4949
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
50+
| [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [**Semantic Guidance**](https://arxiv.org/abs/2301.12247) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/semantic-image-editing/blob/main/examples/SemanticGuidance.ipynb)
5051
| [stable_diffusion](./api/pipelines/stable_diffusion/text2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
5152
| [stable_diffusion](./api/pipelines/stable_diffusion/img2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
5253
| [stable_diffusion](./api/pipelines/stable_diffusion/inpaint) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@
110110
CycleDiffusionPipeline,
111111
LDMTextToImagePipeline,
112112
PaintByExamplePipeline,
113+
SemanticStableDiffusionPipeline,
113114
StableDiffusionAttendAndExcitePipeline,
114115
StableDiffusionDepth2ImgPipeline,
115116
StableDiffusionImageVariationPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@
4444
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
4545
from .latent_diffusion import LDMTextToImagePipeline
4646
from .paint_by_example import PaintByExamplePipeline
47+
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
4748
from .stable_diffusion import (
4849
CycleDiffusionPipeline,
4950
StableDiffusionAttendAndExcitePipeline,
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
1+
from dataclasses import dataclass
2+
from enum import Enum
3+
from typing import List, Optional, Union
4+
5+
import numpy as np
6+
import PIL
7+
from PIL import Image
8+
9+
from ...utils import BaseOutput, is_torch_available, is_transformers_available
10+
11+
12+
@dataclass
13+
class SemanticStableDiffusionPipelineOutput(BaseOutput):
14+
    """
15+
    Output class for Stable Diffusion pipelines.
16+
17+
    Args:
18+
        images (`List[PIL.Image.Image]` or `np.ndarray`)
19+
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
20+
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
21+
        nsfw_content_detected (`List[bool]`)
22+
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
23+
            (nsfw) content, or `None` if safety checking could not be performed.
24+
    """
25+
26+
    images: Union[List[PIL.Image.Image], np.ndarray]
27+
    nsfw_content_detected: Optional[List[bool]]
28+
29+
30+
if is_transformers_available() and is_torch_available():
31+
    from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
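
The guarded import at the end of this file exposes the pipeline only when its optional dependencies are installed. The sketch below illustrates that pattern with the standard library alone. It is an assumption about the general technique, not the code behind `is_torch_available` or `is_transformers_available`; `is_package_available` and `exported_names` are hypothetical helpers.

```python
import importlib.util
from typing import List


def is_package_available(name: str) -> bool:
    """Return True if a top-level package can be located without importing it."""
    return importlib.util.find_spec(name) is not None


def exported_names(required: List[str]) -> List[str]:
    """List what a module would export, given its optional dependencies (hypothetical helper)."""
    # The output dataclass has no heavy dependencies, so it is always exported.
    names = ["SemanticStableDiffusionPipelineOutput"]
    # The pipeline itself is exported only when every required package is present.
    if all(is_package_available(pkg) for pkg in required):
        names.append("SemanticStableDiffusionPipeline")
    return names


# "json" ships with Python; the second name is deliberately nonexistent.
with_deps = exported_names(["json"])
without_deps = exported_names(["json", "no_such_package_for_this_sketch"])
```

Checking for availability before importing keeps the package importable on machines that lack the optional dependencies, at the cost of the pipeline name being undefined there.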
