Skip to content

Commit c3e3a1e

Browse files
patrickvonplatensayakpaul
authored andcommitted
[SDXL] Make watermarker optional under certain circumstances to improve usability of SDXL 1.0 (#4346)
* improve sdxl * more fixes * improve sdxl * improve sdxl * improve sdxl * finish
1 parent 9cde56a commit c3e3a1e

15 files changed

+202
-137
lines changed

docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,25 @@ You can install the libraries as follows:
3838
pip install transformers
3939
pip install accelerate
4040
pip install safetensors
41+
```
42+
43+
### Watermarker
44+
45+
We recommend to add an invisible watermark to images generating by Stable Diffusion XL, this can help with identifying if an image is machine-synthesised for downstream applications. To do so, please install
46+
the [invisible-watermark library](https://pypi.org/project/invisible-watermark/) via:
47+
48+
```
4149
pip install invisible-watermark>=0.2.0
4250
```
4351

52+
If the `invisible-watermark` library is installed the watermarker will be used **by default**.
53+
54+
If you have other provisions for generating or deploying images safely, you can disable the watermarker as follows:
55+
56+
```py
57+
pipe = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
58+
```
59+
4460
### Text-to-Image
4561

4662
You can use SDXL as follows for *text-to-image*:

examples/controlnet/requirements_sdxl.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@ transformers>=4.25.1
44
ftfy
55
tensorboard
66
Jinja2
7-
invisible-watermark>=0.2.0
87
datasets
98
wandb

examples/dreambooth/requirements_sdxl.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,3 @@ transformers>=4.25.1
44
ftfy
55
tensorboard
66
Jinja2
7-
invisible-watermark>=0.2.0

src/diffusers/__init__.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@
185185
StableDiffusionPix2PixZeroPipeline,
186186
StableDiffusionSAGPipeline,
187187
StableDiffusionUpscalePipeline,
188+
StableDiffusionXLControlNetPipeline,
189+
StableDiffusionXLImg2ImgPipeline,
190+
StableDiffusionXLInpaintPipeline,
191+
StableDiffusionXLInstructPix2PixPipeline,
192+
StableDiffusionXLPipeline,
188193
StableUnCLIPImg2ImgPipeline,
189194
StableUnCLIPPipeline,
190195
TextToVideoSDPipeline,
@@ -202,20 +207,6 @@
202207
VQDiffusionPipeline,
203208
)
204209

205-
try:
206-
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
207-
raise OptionalDependencyNotAvailable()
208-
except OptionalDependencyNotAvailable:
209-
from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
210-
else:
211-
from .pipelines import (
212-
StableDiffusionXLControlNetPipeline,
213-
StableDiffusionXLImg2ImgPipeline,
214-
StableDiffusionXLInpaintPipeline,
215-
StableDiffusionXLInstructPix2PixPipeline,
216-
StableDiffusionXLPipeline,
217-
)
218-
219210
try:
220211
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
221212
raise OptionalDependencyNotAvailable()

src/diffusers/pipelines/__init__.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from ..utils import (
22
OptionalDependencyNotAvailable,
33
is_flax_available,
4-
is_invisible_watermark_available,
54
is_k_diffusion_available,
65
is_librosa_available,
76
is_note_seq_available,
@@ -51,6 +50,7 @@
5150
StableDiffusionControlNetImg2ImgPipeline,
5251
StableDiffusionControlNetInpaintPipeline,
5352
StableDiffusionControlNetPipeline,
53+
StableDiffusionXLControlNetPipeline,
5454
)
5555
from .deepfloyd_if import (
5656
IFImg2ImgPipeline,
@@ -108,6 +108,12 @@
108108
StableUnCLIPPipeline,
109109
)
110110
from .stable_diffusion_safe import StableDiffusionPipelineSafe
111+
from .stable_diffusion_xl import (
112+
StableDiffusionXLImg2ImgPipeline,
113+
StableDiffusionXLInpaintPipeline,
114+
StableDiffusionXLInstructPix2PixPipeline,
115+
StableDiffusionXLPipeline,
116+
)
111117
from .t2i_adapter import StableDiffusionAdapterPipeline
112118
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
113119
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
@@ -121,20 +127,6 @@
121127
from .vq_diffusion import VQDiffusionPipeline
122128

123129

124-
try:
125-
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
126-
raise OptionalDependencyNotAvailable()
127-
except OptionalDependencyNotAvailable:
128-
from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
129-
else:
130-
from .controlnet import StableDiffusionXLControlNetPipeline
131-
from .stable_diffusion_xl import (
132-
StableDiffusionXLImg2ImgPipeline,
133-
StableDiffusionXLInpaintPipeline,
134-
StableDiffusionXLInstructPix2PixPipeline,
135-
StableDiffusionXLPipeline,
136-
)
137-
138130
try:
139131
if not is_onnx_available():
140132
raise OptionalDependencyNotAvailable()

src/diffusers/pipelines/controlnet/__init__.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,11 @@
11
from ...utils import (
22
OptionalDependencyNotAvailable,
33
is_flax_available,
4-
is_invisible_watermark_available,
54
is_torch_available,
65
is_transformers_available,
76
)
87

98

10-
try:
11-
if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()):
12-
raise OptionalDependencyNotAvailable()
13-
except OptionalDependencyNotAvailable:
14-
from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
15-
else:
16-
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
17-
18-
199
try:
2010
if not (is_transformers_available() and is_torch_available()):
2111
raise OptionalDependencyNotAvailable()
@@ -26,6 +16,7 @@
2616
from .pipeline_controlnet import StableDiffusionControlNetPipeline
2717
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
2818
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
19+
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
2920

3021

3122
if is_transformers_available() and is_flax_available():

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import torch.nn.functional as F
2323
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
2424

25+
from diffusers.utils.import_utils import is_invisible_watermark_available
26+
2527
from ...image_processor import VaeImageProcessor
2628
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
2729
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
@@ -42,7 +44,11 @@
4244
)
4345
from ..pipeline_utils import DiffusionPipeline
4446
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
45-
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
47+
48+
49+
if is_invisible_watermark_available():
50+
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
51+
4652
from .multicontrolnet import MultiControlNetModel
4753

4854

@@ -109,6 +115,7 @@ def __init__(
109115
controlnet: ControlNetModel,
110116
scheduler: KarrasDiffusionSchedulers,
111117
force_zeros_for_empty_prompt: bool = True,
118+
add_watermarker: Optional[bool] = None,
112119
):
113120
super().__init__()
114121

@@ -130,7 +137,13 @@ def __init__(
130137
self.control_image_processor = VaeImageProcessor(
131138
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
132139
)
133-
self.watermark = StableDiffusionXLWatermarker()
140+
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
141+
142+
if add_watermarker:
143+
self.watermark = StableDiffusionXLWatermarker()
144+
else:
145+
self.watermark = None
146+
134147
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
135148

136149
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
@@ -995,7 +1008,10 @@ def __call__(
9951008
image = latents
9961009
return StableDiffusionXLPipelineOutput(images=image)
9971010

998-
image = self.watermark.apply_watermark(image)
1011+
# apply watermark if available
1012+
if self.watermark is not None:
1013+
image = self.watermark.apply_watermark(image)
1014+
9991015
image = self.image_processor.postprocess(image, output_type=output_type)
10001016

10011017
# Offload last model to CPU

src/diffusers/pipelines/stable_diffusion_xl/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from ...utils import (
88
BaseOutput,
99
OptionalDependencyNotAvailable,
10-
is_invisible_watermark_available,
1110
is_torch_available,
1211
is_transformers_available,
1312
)
@@ -28,10 +27,10 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
2827

2928

3029
try:
31-
if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()):
30+
if not (is_transformers_available() and is_torch_available()):
3231
raise OptionalDependencyNotAvailable()
3332
except OptionalDependencyNotAvailable:
34-
from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
33+
from ...utils.dummy_torch_and_transformers_and_objects import * # noqa F403
3534
else:
3635
from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
3736
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,17 @@
3232
from ...utils import (
3333
is_accelerate_available,
3434
is_accelerate_version,
35+
is_invisible_watermark_available,
3536
logging,
3637
randn_tensor,
3738
replace_example_docstring,
3839
)
3940
from ..pipeline_utils import DiffusionPipeline
4041
from . import StableDiffusionXLPipelineOutput
41-
from .watermark import StableDiffusionXLWatermarker
42+
43+
44+
if is_invisible_watermark_available():
45+
from .watermark import StableDiffusionXLWatermarker
4246

4347

4448
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -125,6 +129,7 @@ def __init__(
125129
unet: UNet2DConditionModel,
126130
scheduler: KarrasDiffusionSchedulers,
127131
force_zeros_for_empty_prompt: bool = True,
132+
add_watermarker: Optional[bool] = None,
128133
):
129134
super().__init__()
130135

@@ -142,7 +147,12 @@ def __init__(
142147
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
143148
self.default_sample_size = self.unet.config.sample_size
144149

145-
self.watermark = StableDiffusionXLWatermarker()
150+
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
151+
152+
if add_watermarker:
153+
self.watermark = StableDiffusionXLWatermarker()
154+
else:
155+
self.watermark = None
146156

147157
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
148158
def enable_vae_slicing(self):
@@ -839,7 +849,10 @@ def __call__(
839849
image = latents
840850
return StableDiffusionXLPipelineOutput(images=image)
841851

842-
image = self.watermark.apply_watermark(image)
852+
# apply watermark if available
853+
if self.watermark is not None:
854+
image = self.watermark.apply_watermark(image)
855+
843856
image = self.image_processor.postprocess(image, output_type=output_type)
844857

845858
# Offload last model to CPU

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,17 @@
3333
from ...utils import (
3434
is_accelerate_available,
3535
is_accelerate_version,
36+
is_invisible_watermark_available,
3637
logging,
3738
randn_tensor,
3839
replace_example_docstring,
3940
)
4041
from ..pipeline_utils import DiffusionPipeline
4142
from . import StableDiffusionXLPipelineOutput
42-
from .watermark import StableDiffusionXLWatermarker
43+
44+
45+
if is_invisible_watermark_available():
46+
from .watermark import StableDiffusionXLWatermarker
4347

4448

4549
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -131,6 +135,7 @@ def __init__(
131135
scheduler: KarrasDiffusionSchedulers,
132136
requires_aesthetics_score: bool = False,
133137
force_zeros_for_empty_prompt: bool = True,
138+
add_watermarker: Optional[bool] = None,
134139
):
135140
super().__init__()
136141

@@ -148,7 +153,12 @@ def __init__(
148153
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
149154
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
150155

151-
self.watermark = StableDiffusionXLWatermarker()
156+
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
157+
158+
if add_watermarker:
159+
self.watermark = StableDiffusionXLWatermarker()
160+
else:
161+
self.watermark = None
152162

153163
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
154164
def enable_vae_slicing(self):
@@ -990,7 +1000,10 @@ def denoising_value_valid(dnv):
9901000
image = latents
9911001
return StableDiffusionXLPipelineOutput(images=image)
9921002

993-
image = self.watermark.apply_watermark(image)
1003+
# apply watermark if available
1004+
if self.watermark is not None:
1005+
image = self.watermark.apply_watermark(image)
1006+
9941007
image = self.image_processor.postprocess(image, output_type=output_type)
9951008

9961009
# Offload last model to CPU

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,20 @@
3030
XFormersAttnProcessor,
3131
)
3232
from ...schedulers import KarrasDiffusionSchedulers
33-
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
33+
from ...utils import (
34+
is_accelerate_available,
35+
is_accelerate_version,
36+
is_invisible_watermark_available,
37+
logging,
38+
randn_tensor,
39+
replace_example_docstring,
40+
)
3441
from ..pipeline_utils import DiffusionPipeline
3542
from . import StableDiffusionXLPipelineOutput
36-
from .watermark import StableDiffusionXLWatermarker
43+
44+
45+
if is_invisible_watermark_available():
46+
from .watermark import StableDiffusionXLWatermarker
3747

3848

3949
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -265,6 +275,7 @@ def __init__(
265275
scheduler: KarrasDiffusionSchedulers,
266276
requires_aesthetics_score: bool = False,
267277
force_zeros_for_empty_prompt: bool = True,
278+
add_watermarker: Optional[bool] = None,
268279
):
269280
super().__init__()
270281

@@ -282,7 +293,12 @@ def __init__(
282293
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
283294
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
284295

285-
self.watermark = StableDiffusionXLWatermarker()
296+
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
297+
298+
if add_watermarker:
299+
self.watermark = StableDiffusionXLWatermarker()
300+
else:
301+
self.watermark = None
286302

287303
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
288304
def enable_vae_slicing(self):
@@ -1266,6 +1282,10 @@ def denoising_value_valid(dnv):
12661282
else:
12671283
return StableDiffusionXLPipelineOutput(images=latents)
12681284

1285+
# apply watermark if available
1286+
if self.watermark is not None:
1287+
image = self.watermark.apply_watermark(image)
1288+
12691289
image = self.image_processor.postprocess(image, output_type=output_type)
12701290

12711291
# Offload last model to CPU

0 commit comments

Comments
 (0)