
Refactor controlnet and add img2img and inpaint #3386


Merged: 24 commits from `add_inpaint_img2img_controlnet` into `main` on May 16, 2023.

Commits
- fd7712d  refactor controlnet and add img2img and inpaint (patrickvonplaten, May 9, 2023)
- 94e0d4f  change (patrickvonplaten, May 11, 2023)
- 7e0f6bc  First draft to get pipelines to work (patrickvonplaten, May 11, 2023)
- c7d5221  make style (patrickvonplaten, May 11, 2023)
- 57fa09f  Fix more (patrickvonplaten, May 12, 2023)
- 9dd9e74  Fix more (patrickvonplaten, May 12, 2023)
- edc6700  More tests (patrickvonplaten, May 12, 2023)
- feb7b2b  Fix more (patrickvonplaten, May 12, 2023)
- fc96b0c  Make inpainting work (patrickvonplaten, May 12, 2023)
- 0724b2f  make style and more tests (patrickvonplaten, May 12, 2023)
- 7d10292  Apply suggestions from code review (patrickvonplaten, May 12, 2023)
- 2ef2dbe  up (patrickvonplaten, May 12, 2023)
- 6c6a685  Merge branch 'add_inpaint_img2img_controlnet' of https://github.com/h… (patrickvonplaten, May 12, 2023)
- f9cb03f  make style (patrickvonplaten, May 12, 2023)
- 9e95499  Fix imports (patrickvonplaten, May 12, 2023)
- e58f216  Merge branch 'add_inpaint_img2img_controlnet' of https://github.com/h… (patrickvonplaten, May 12, 2023)
- 68ce349  Fix more (patrickvonplaten, May 12, 2023)
- bbbd989  Fix more (patrickvonplaten, May 12, 2023)
- 2b965d9  Improve examples (patrickvonplaten, May 12, 2023)
- ad9dc64  add test (patrickvonplaten, May 12, 2023)
- 4907c13  Make sure import is correctly deprecated (patrickvonplaten, May 12, 2023)
- 46c32d4  Merge branch 'main' into add_inpaint_img2img_controlnet (patrickvonplaten, May 16, 2023)
- 1dd63ca  Make sure everything works in compile mode (patrickvonplaten, May 16, 2023)
- a2912b2  make sure authorship is correctly attributed (patrickvonplaten, May 16, 2023)
4 changes: 2 additions & 2 deletions docs/source/en/_toctree.yml
@@ -148,6 +148,8 @@
title: Audio Diffusion
- local: api/pipelines/audioldm
title: AudioLDM
- local: api/pipelines/controlnet
title: ControlNet
- local: api/pipelines/cycle_diffusion
title: Cycle Diffusion
- local: api/pipelines/dance_diffusion
@@ -203,8 +205,6 @@
title: Self-Attention Guidance
- local: api/pipelines/stable_diffusion/panorama
title: MultiDiffusion Panorama
- local: api/pipelines/stable_diffusion/controlnet
title: Text-to-Image Generation with ControlNet Conditioning
- local: api/pipelines/stable_diffusion/model_editing
title: Text-to-Image Model Editing
- local: api/pipelines/stable_diffusion/diffedit

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/en/api/pipelines/overview.mdx
@@ -46,7 +46,7 @@ available a colab notebook to directly try them out.
|---|---|:---:|:---:|
| [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
| [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation |
| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
| [controlnet](./api/pipelines/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
| [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -53,7 +53,7 @@ The library has three main components:
|---|---|:---:|
| [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
| [audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation |
| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
| [controlnet](./api/pipelines/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -132,6 +132,8 @@
        PaintByExamplePipeline,
        SemanticStableDiffusionPipeline,
        StableDiffusionAttendAndExcitePipeline,
        StableDiffusionControlNetImg2ImgPipeline,
        StableDiffusionControlNetInpaintPipeline,
        StableDiffusionControlNetPipeline,
        StableDiffusionDepth2ImgPipeline,
        StableDiffusionDiffEditPipeline,
10 changes: 10 additions & 0 deletions src/diffusers/pipeline_utils.py
@@ -17,3 +17,13 @@
# It only exists so that temporarily `from diffusers.pipelines import DiffusionPipeline` works

from .pipelines import DiffusionPipeline, ImagePipelineOutput  # noqa: F401
from .utils import deprecate


deprecate(
    "pipelines_utils",
    "0.22.0",
    "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
    standard_warn=False,
    stacklevel=3,
)
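For reviewers, a minimal sketch of what this shim means for downstream code (illustrative; the warning is emitted at module import time via `diffusers.utils.deprecate`):

```python
# Deprecated path: still works until 0.22.0, but importing the module
# now emits a deprecation warning.
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput

# Preferred path going forward:
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
```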
8 changes: 6 additions & 2 deletions src/diffusers/pipelines/__init__.py
@@ -44,6 +44,11 @@
else:
    from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
    from .audioldm import AudioLDMPipeline
    from .controlnet import (
        StableDiffusionControlNetImg2ImgPipeline,
        StableDiffusionControlNetInpaintPipeline,
        StableDiffusionControlNetPipeline,
    )
    from .deepfloyd_if import (
        IFImg2ImgPipeline,
        IFImg2ImgSuperResolutionPipeline,
@@ -58,7 +63,6 @@
    from .stable_diffusion import (
        CycleDiffusionPipeline,
        StableDiffusionAttendAndExcitePipeline,
        StableDiffusionControlNetPipeline,
        StableDiffusionDepth2ImgPipeline,
        StableDiffusionDiffEditPipeline,
        StableDiffusionImageVariationPipeline,
@@ -133,8 +137,8 @@
except OptionalDependencyNotAvailable:
    from ..utils.dummy_flax_and_transformers_objects import *  # noqa F403
else:
    from .controlnet import FlaxStableDiffusionControlNetPipeline
    from .stable_diffusion import (
        FlaxStableDiffusionControlNetPipeline,
        FlaxStableDiffusionImg2ImgPipeline,
        FlaxStableDiffusionInpaintPipeline,
        FlaxStableDiffusionPipeline,
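Since the Flax pipeline now lives in the new `controlnet` package but is still re-exported here, the public import path should be unaffected; a quick sanity check (assuming `flax` and `transformers` are installed):

```python
# Both of these should resolve to the same class after this PR:
from diffusers import FlaxStableDiffusionControlNetPipeline
from diffusers.pipelines.controlnet import FlaxStableDiffusionControlNetPipeline
```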
22 changes: 22 additions & 0 deletions src/diffusers/pipelines/controlnet/__init__.py
@@ -0,0 +1,22 @@
from ...utils import (
    OptionalDependencyNotAvailable,
    is_flax_available,
    is_torch_available,
    is_transformers_available,
)


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
    from .multicontrolnet import MultiControlNetModel
    from .pipeline_controlnet import StableDiffusionControlNetPipeline
    from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
    from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline


if is_transformers_available() and is_flax_available():
    from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
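As a usage sketch of the new img2img variant exposed by this module (the checkpoint names are examples, not part of this PR; `init_image` and `canny_image` are assumed PIL images):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline

# Attach a canny-edge ControlNet to the new img2img pipeline.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# `image` is the init image to transform; `control_image` is the conditioning map.
# result = pipe("a futuristic city", image=init_image, control_image=canny_image).images[0]
```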
66 changes: 66 additions & 0 deletions src/diffusers/pipelines/controlnet/multicontrolnet.py
@@ -0,0 +1,66 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from ...models.controlnet import ControlNetModel, ControlNetOutput
from ...models.modeling_utils import ModelMixin


class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet

    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the UNet during the denoising process. The multiple
            `ControlNetModel` instances must be provided as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states,
                image,
                scale,
                class_labels,
                timestep_cond,
                attention_mask,
                cross_attention_kwargs,
                guess_mode,
                return_dict,
            )

            # merge samples: sum the residuals contributed by each ControlNet
            if i == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample
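The loop above sums the down-block and mid-block residuals from each ControlNet, so conditionings compose additively. In practice the wrapper is reached by passing a list of controlnets to a pipeline, which wraps them in `MultiControlNetModel`; a hedged sketch (checkpoint names are illustrative, and `canny_image`/`pose_image` are assumed pre-computed conditioning images):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Two conditionings, e.g. canny edges plus human pose.
controlnets = [
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16),
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16),
]
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnets, torch_dtype=torch.float16
)

# One control image per ControlNet; the per-net strengths map to
# `conditioning_scale` in the forward() above.
# image = pipe("a dancer on a rooftop", image=[canny_image, pose_image],
#              controlnet_conditioning_scale=[1.0, 0.8]).images[0]
```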