Skip to content

Commit 886575e

Browse files
Refactor controlnet and add img2img and inpaint (#3386)
* refactor controlnet and add img2img and inpaint * First draft to get pipelines to work * make style * Fix more * Fix more * More tests * Fix more * Make inpainting work * make style and more tests * Apply suggestions from code review * up * make style * Fix imports * Fix more * Fix more * Improve examples * add test * Make sure import is correctly deprecated * Make sure everything works in compile mode * make sure authorship is correctly attributed
1 parent 9d44e2f commit 886575e

25 files changed

+4878
-1644
lines changed

docs/source/en/_toctree.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@
148148
title: Audio Diffusion
149149
- local: api/pipelines/audioldm
150150
title: AudioLDM
151+
- local: api/pipelines/controlnet
152+
title: ControlNet
151153
- local: api/pipelines/cycle_diffusion
152154
title: Cycle Diffusion
153155
- local: api/pipelines/dance_diffusion
@@ -203,8 +205,6 @@
203205
title: Self-Attention Guidance
204206
- local: api/pipelines/stable_diffusion/panorama
205207
title: MultiDiffusion Panorama
206-
- local: api/pipelines/stable_diffusion/controlnet
207-
title: Text-to-Image Generation with ControlNet Conditioning
208208
- local: api/pipelines/stable_diffusion/model_editing
209209
title: Text-to-Image Model Editing
210210
- local: api/pipelines/stable_diffusion/diffedit

docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx renamed to docs/source/en/api/pipelines/controlnet.mdx

+44-17
Large diffs are not rendered by default.

docs/source/en/api/pipelines/overview.mdx

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ available a colab notebook to directly try them out.
4646
|---|---|:---:|:---:|
4747
| [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
4848
| [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation |
49-
| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
49+
| [controlnet](./api/pipelines/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
5050
| [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
5151
| [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
5252
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |

docs/source/en/index.mdx

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ The library has three main components:
5353
|---|---|:---:|
5454
| [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
5555
| [audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation |
56-
| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
56+
| [controlnet](./api/pipelines/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
5757
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
5858
| [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
5959
| [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |

src/diffusers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@
132132
PaintByExamplePipeline,
133133
SemanticStableDiffusionPipeline,
134134
StableDiffusionAttendAndExcitePipeline,
135+
StableDiffusionControlNetImg2ImgPipeline,
136+
StableDiffusionControlNetInpaintPipeline,
135137
StableDiffusionControlNetPipeline,
136138
StableDiffusionDepth2ImgPipeline,
137139
StableDiffusionDiffEditPipeline,

src/diffusers/pipeline_utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,13 @@
1717
# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works
1818

1919
from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401
20+
from .utils import deprecate
21+
22+
23+
deprecate(
24+
"pipelines_utils",
25+
"0.22.0",
26+
"Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
27+
standard_warn=False,
28+
stacklevel=3,
29+
)

src/diffusers/pipelines/__init__.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@
4444
else:
4545
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
4646
from .audioldm import AudioLDMPipeline
47+
from .controlnet import (
48+
StableDiffusionControlNetImg2ImgPipeline,
49+
StableDiffusionControlNetInpaintPipeline,
50+
StableDiffusionControlNetPipeline,
51+
)
4752
from .deepfloyd_if import (
4853
IFImg2ImgPipeline,
4954
IFImg2ImgSuperResolutionPipeline,
@@ -58,7 +63,6 @@
5863
from .stable_diffusion import (
5964
CycleDiffusionPipeline,
6065
StableDiffusionAttendAndExcitePipeline,
61-
StableDiffusionControlNetPipeline,
6266
StableDiffusionDepth2ImgPipeline,
6367
StableDiffusionDiffEditPipeline,
6468
StableDiffusionImageVariationPipeline,
@@ -133,8 +137,8 @@
133137
except OptionalDependencyNotAvailable:
134138
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
135139
else:
140+
from .controlnet import FlaxStableDiffusionControlNetPipeline
136141
from .stable_diffusion import (
137-
FlaxStableDiffusionControlNetPipeline,
138142
FlaxStableDiffusionImg2ImgPipeline,
139143
FlaxStableDiffusionInpaintPipeline,
140144
FlaxStableDiffusionPipeline,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from ...utils import (
2+
OptionalDependencyNotAvailable,
3+
is_flax_available,
4+
is_torch_available,
5+
is_transformers_available,
6+
)
7+
8+
9+
try:
10+
if not (is_transformers_available() and is_torch_available()):
11+
raise OptionalDependencyNotAvailable()
12+
except OptionalDependencyNotAvailable:
13+
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
14+
else:
15+
from .multicontrolnet import MultiControlNetModel
16+
from .pipeline_controlnet import StableDiffusionControlNetPipeline
17+
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
18+
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
19+
20+
21+
if is_transformers_available() and is_flax_available():
22+
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import Any, Dict, List, Optional, Tuple, Union
2+
3+
import torch
4+
from torch import nn
5+
6+
from ...models.controlnet import ControlNetModel, ControlNetOutput
7+
from ...models.modeling_utils import ModelMixin
8+
9+
10+
class MultiControlNetModel(ModelMixin):
11+
r"""
12+
Multiple `ControlNetModel` wrapper class for Multi-ControlNet
13+
14+
This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
15+
compatible with `ControlNetModel`.
16+
17+
Args:
18+
controlnets (`List[ControlNetModel]`):
19+
Provides additional conditioning to the unet during the denoising process. You must set multiple
20+
`ControlNetModel` as a list.
21+
"""
22+
23+
def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
24+
super().__init__()
25+
self.nets = nn.ModuleList(controlnets)
26+
27+
def forward(
28+
self,
29+
sample: torch.FloatTensor,
30+
timestep: Union[torch.Tensor, float, int],
31+
encoder_hidden_states: torch.Tensor,
32+
controlnet_cond: List[torch.tensor],
33+
conditioning_scale: List[float],
34+
class_labels: Optional[torch.Tensor] = None,
35+
timestep_cond: Optional[torch.Tensor] = None,
36+
attention_mask: Optional[torch.Tensor] = None,
37+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
38+
guess_mode: bool = False,
39+
return_dict: bool = True,
40+
) -> Union[ControlNetOutput, Tuple]:
41+
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
42+
down_samples, mid_sample = controlnet(
43+
sample,
44+
timestep,
45+
encoder_hidden_states,
46+
image,
47+
scale,
48+
class_labels,
49+
timestep_cond,
50+
attention_mask,
51+
cross_attention_kwargs,
52+
guess_mode,
53+
return_dict,
54+
)
55+
56+
# merge samples
57+
if i == 0:
58+
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
59+
else:
60+
down_block_res_samples = [
61+
samples_prev + samples_curr
62+
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
63+
]
64+
mid_block_res_sample += mid_sample
65+
66+
return down_block_res_samples, mid_block_res_sample

0 commit comments

Comments
 (0)