Commit 32ea214

[Kandinsky] Improve kandinsky API a bit (#3636)
* Improve docs
* up
* Update docs/source/en/api/pipelines/kandinsky.mdx
* up
* up
* correct more
* further improve
* Update docs/source/en/api/pipelines/kandinsky.mdx

Co-authored-by: YiYi Xu <[email protected]>
1 parent 55dbfa0 commit 32ea214

10 files changed (+182, -119 lines)

docs/source/en/api/pipelines/kandinsky.mdx

Lines changed: 119 additions & 74 deletions
@@ -19,81 +19,78 @@ The Kandinsky model is created by [Arseniy Shakhmatov](https://github.com/cene55
 
 ## Available Pipelines:
 
-| Pipeline | Tasks | Colab
-|---|---|:---:|
-| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* | - |
-| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* | - |
-| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* | - |
+| Pipeline | Tasks |
+|---|---|
+| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* |
+| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* |
+| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* |
 
 ## Usage example
 
-In the following, we will walk you through some cool examples of using the Kandinsky pipelines to create some visually aesthetic artwork.
+In the following, we will walk you through some examples of how to use the Kandinsky pipelines to create some visually aesthetic artwork.
 
 ### Text-to-Image Generation
 
-For text-to-image generation, we need to use both [`KandinskyPriorPipeline`] and [`KandinskyPipeline`]. The first step is to encode text prompts with CLIP and then diffuse the CLIP text embeddings to CLIP image embeddings, as first proposed in [DALL-E 2](https://cdn.openai.com/papers/dall-e-2.pdf). Let's throw a fun prompt at Kandinsky to see what it comes up with :)
+For text-to-image generation, we need to use both [`KandinskyPriorPipeline`] and [`KandinskyPipeline`].
+The first step is to encode text prompts with CLIP and then diffuse the CLIP text embeddings to CLIP image embeddings,
+as first proposed in [DALL-E 2](https://cdn.openai.com/papers/dall-e-2.pdf).
+Let's throw a fun prompt at Kandinsky to see what it comes up with.
 
-```python
+```py
 prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
-negative_prompt = "low quality, bad quality"
 ```
 
-We will pass both the `prompt` and `negative_prompt` to our prior diffusion pipeline. In contrast to other diffusion pipelines, such as Stable Diffusion, the `prompt` and `negative_prompt` shall be passed separately so that we can retrieve a CLIP image embedding for each prompt input. You can use `guidance_scale`, and `num_inference_steps` arguments to guide this process, just like how you would normally do with all other pipelines in diffusers.
+First, let's instantiate the prior pipeline and the text-to-image pipeline. Both
+pipelines are diffusion models.
 
-```python
-from diffusers import KandinskyPriorPipeline
+
+```py
+from diffusers import DiffusionPipeline
 import torch
 
-# create prior
-pipe_prior = KandinskyPriorPipeline.from_pretrained(
-    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
-)
+pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16)
 pipe_prior.to("cuda")
 
-generator = torch.Generator(device="cuda").manual_seed(12)
-image_emb = pipe_prior(
-    prompt, guidance_scale=1.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
-).images
+t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
+t2i_pipe.to("cuda")
+```
 
-zero_image_emb = pipe_prior(
-    negative_prompt, guidance_scale=1.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
-).images
+Now we pass the prompt through the prior to generate image embeddings. The prior
+returns both the image embeddings corresponding to the prompt and negative/unconditional image
+embeddings corresponding to an empty string.
+
+```py
+generator = torch.Generator(device="cuda").manual_seed(12)
+image_embeds, negative_image_embeds = pipe_prior(prompt, generator=generator).to_tuple()
 ```
 
-Once we create the image embedding, we can use [`KandinskyPipeline`] to generate images.
+<Tip warning={true}>
 
-```python
-from PIL import Image
-from diffusers import KandinskyPipeline
+The text-to-image pipeline expects both `image_embeds` and `negative_image_embeds` as well as the original
+`prompt`, as the text-to-image pipeline uses another text encoder to better guide the second diffusion
+process of `t2i_pipe`.
 
+By default, the prior returns unconditioned negative image embeddings corresponding to the negative prompt of `""`.
+For better results, you can also pass a `negative_prompt` to the prior. This will increase the effective batch size
+of the prior by a factor of 2.
 
-def image_grid(imgs, rows, cols):
-    assert len(imgs) == rows * cols
+```py
+prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
+negative_prompt = "low quality, bad quality"
 
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    grid_w, grid_h = grid.size
+image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, generator=generator).to_tuple()
+```
 
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
+</Tip>
 
 
-# create diffuser pipeline
-pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
-pipe.to("cuda")
+Next, we can pass the embeddings as well as the prompt to the text-to-image pipeline. Remember that
+if you are using a customized negative prompt, you should also pass it to the text-to-image pipeline
+with `negative_prompt=negative_prompt`:
 
-images = pipe(
-    prompt,
-    image_embeds=image_emb,
-    negative_image_embeds=zero_image_emb,
-    num_images_per_prompt=2,
-    height=768,
-    width=768,
-    num_inference_steps=100,
-    guidance_scale=4.0,
-    generator=generator,
-).images
+```py
+image = t2i_pipe(prompt, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0]
+image.save("cheeseburger_monster.png")
 ```
 
 One cheeseburger monster coming up! Enjoy!
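
Read together, the added lines in this hunk describe the new two-stage API end to end. Below is a minimal sketch assembled from those snippets (the model IDs, `to_tuple()` accessor, and argument names are taken from the diff; a CUDA device with fp16 support is assumed):

```py
import torch
from diffusers import DiffusionPipeline

# Stage 1: the prior diffuses CLIP text embeddings into CLIP image embeddings.
pipe_prior = DiffusionPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

# Stage 2: the text-to-image pipeline decodes the image embeddings into pixels.
t2i_pipe = DiffusionPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
)
t2i_pipe.to("cuda")

prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
negative_prompt = "low quality, bad quality"

generator = torch.Generator(device="cuda").manual_seed(12)

# Passing a negative_prompt doubles the prior's effective batch size (see the Tip above).
image_embeds, negative_image_embeds = pipe_prior(
    prompt, negative_prompt, generator=generator
).to_tuple()

# The prompt is passed again because the decoder has its own text encoder.
image = t2i_pipe(
    prompt,
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
    negative_prompt=negative_prompt,
).images[0]
image.save("cheeseburger_monster.png")
```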
@@ -164,22 +161,15 @@ prompt = "A fantasy landscape, Cinematic lighting"
 negative_prompt = "low quality, bad quality"
 
 generator = torch.Generator(device="cuda").manual_seed(30)
-image_emb = pipe_prior(
-    prompt, guidance_scale=4.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
-).images
-
-zero_image_emb = pipe_prior(
-    negative_prompt, guidance_scale=4.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
-).images
+image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, generator=generator).to_tuple()
 
 out = pipe(
     prompt,
     image=original_image,
-    image_embeds=image_emb,
-    negative_image_embeds=zero_image_emb,
+    image_embeds=image_embeds,
+    negative_image_embeds=negative_image_embeds,
     height=768,
     width=768,
-    num_inference_steps=500,
     strength=0.3,
 )
 
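This hunk edits the image-to-image example, whose setup lines fall outside the diff context. A plausible reconstruction of that setup, assuming the pipeline classes shown elsewhere in this commit (the source image URL here is illustrative, not taken from the diff):

```py
import torch
from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline
from diffusers.utils import load_image

pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

pipe = KandinskyImg2ImgPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
)
pipe.to("cuda")

# Illustrative source image; any RGB image resized to the target size works.
original_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
).resize((768, 768))

# With strength=0.3 only the tail of the noise schedule is applied, so most of
# the original image is preserved; this is also why the deleted
# num_inference_steps=500 was unnecessary.
```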
@@ -193,7 +183,7 @@ out.images[0].save("fantasy_land.png")
 
 You can use [`KandinskyInpaintPipeline`] to edit images. In this example, we will add a hat to the portrait of a cat.
 
-```python
+```py
 from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline
 from diffusers.utils import load_image
 import torch
@@ -205,7 +195,7 @@ pipe_prior = KandinskyPriorPipeline.from_pretrained(
 pipe_prior.to("cuda")
 
 prompt = "a hat"
-image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
+prior_output = pipe_prior(prompt)
 
 pipe = KandinskyInpaintPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16)
 pipe.to("cuda")
@@ -222,8 +212,7 @@ out = pipe(
     prompt,
     image=init_image,
     mask_image=mask,
-    image_embeds=image_emb,
-    negative_image_embeds=zero_image_emb,
+    **prior_output,
     height=768,
     width=768,
     num_inference_steps=150,
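
The `**prior_output` shorthand works because diffusers pipeline outputs behave like ordered dicts keyed by their field names, and for the prior those names line up with the decoder's keyword arguments. A sketch of the equivalence, assuming the field names used elsewhere in this commit (`image_embeds`, `negative_image_embeds`):

```py
prior_output = pipe_prior(prompt)

# Pipeline outputs support dict-style access, so ** unpacks them into keyword
# arguments. The call with **prior_output is therefore equivalent to spelling
# the two fields out explicitly:
out = pipe(
    prompt,
    image=init_image,
    mask_image=mask,
    image_embeds=prior_output.image_embeds,
    negative_image_embeds=prior_output.negative_image_embeds,
    height=768,
    width=768,
    num_inference_steps=150,
)
```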
@@ -246,7 +235,6 @@ from diffusers.utils import load_image
 import PIL
 
 import torch
-from torchvision import transforms
 
 pipe_prior = KandinskyPriorPipeline.from_pretrained(
     "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
@@ -263,22 +251,80 @@ img2 = load_image(
 
 # add all the conditions we want to interpolate, can be either text or image
 images_texts = ["a cat", img1, img2]
+
 # specify the weights for each condition in images_texts
 weights = [0.3, 0.3, 0.4]
-image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
+
+# We can leave the prompt empty
+prompt = ""
+prior_out = pipe_prior.interpolate(images_texts, weights)
 
 pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
 pipe.to("cuda")
 
-image = pipe(
-    "", image_embeds=image_emb, negative_image_embeds=zero_image_emb, height=768, width=768, num_inference_steps=150
-).images[0]
+image = pipe(prompt, **prior_out, height=768, width=768).images[0]
 
 image.save("starry_cat.png")
 ```
 ![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/starry_cat.png)
 
 
+## Optimization
+
+Running Kandinsky in inference requires running both a first prior pipeline: [`KandinskyPriorPipeline`]
+and a second image decoding pipeline which is one of [`KandinskyPipeline`], [`KandinskyImg2ImgPipeline`], or [`KandinskyInpaintPipeline`].
+
+The bulk of the computation time will always be spent in the second image decoding pipeline, so that is
+where optimization efforts should focus.
+
+When running with PyTorch < 2.0, we strongly recommend making use of [`xformers`](https://github.com/facebookresearch/xformers)
+to speed up inference. This can be done by simply running:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
+t2i_pipe.enable_xformers_memory_efficient_attention()
+```
+
+When running on PyTorch >= 2.0, PyTorch's SDPA attention will automatically be used. For more information on
+PyTorch's SDPA, feel free to have a look at [this blog post](https://pytorch.org/blog/accelerated-diffusers-pt-20/).
+
+To have explicit control, you can also manually set the pipeline to use PyTorch's 2.0 efficient attention:
+
+```py
+from diffusers.models.attention_processor import AttnAddedKVProcessor2_0
+
+t2i_pipe.unet.set_attn_processor(AttnAddedKVProcessor2_0())
+```
+
+The slowest and most memory-intensive attention processor is the default `AttnAddedKVProcessor`.
+We do **not** recommend using it except for testing purposes or cases where highly deterministic behaviour is desired.
+You can set it with:
+
+```py
+from diffusers.models.attention_processor import AttnAddedKVProcessor
+
+t2i_pipe.unet.set_attn_processor(AttnAddedKVProcessor())
+```
+
+With PyTorch >= 2.0, you can also use Kandinsky with `torch.compile`, which, depending
+on your hardware, can significantly speed up your inference time once the model is compiled.
+To use Kandinsky with `torch.compile`, you can do:
+
+```py
+t2i_pipe.unet.to(memory_format=torch.channels_last)
+t2i_pipe.unet = torch.compile(t2i_pipe.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+After compilation you should see a very fast inference time. For more information,
+feel free to have a look at [our PyTorch 2.0 benchmark](https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0).
+
 ## KandinskyPriorPipeline
 
 [[autodoc]] KandinskyPriorPipeline
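
The optimization snippets added above compose. As a rough sketch for PyTorch >= 2.0 (SDPA is then picked automatically, so only `torch.compile` is added here; compile speed-ups depend on hardware):

```py
import torch
from diffusers import DiffusionPipeline

pipe_prior = DiffusionPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
).to("cuda")
t2i_pipe = DiffusionPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
).to("cuda")

# Compile only the decoder UNet: the bulk of the compute is in the second pipeline.
t2i_pipe.unet.to(memory_format=torch.channels_last)
t2i_pipe.unet = torch.compile(t2i_pipe.unet, mode="reduce-overhead", fullgraph=True)

prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
image_embeds, negative_image_embeds = pipe_prior(prompt).to_tuple()

# The first call is slow while compilation happens; subsequent calls are fast.
image = t2i_pipe(
    prompt, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds
).images[0]
```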
@@ -292,15 +338,14 @@ image.save("starry_cat.png")
     - all
     - __call__
 
-## KandinskyInpaintPipeline
-
-[[autodoc]] KandinskyInpaintPipeline
-    - all
-    - __call__
-
 ## KandinskyImg2ImgPipeline
 
 [[autodoc]] KandinskyImg2ImgPipeline
     - all
     - __call__
 
+## KandinskyInpaintPipeline
+
+[[autodoc]] KandinskyInpaintPipeline
+    - all
+    - __call__

src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py

Lines changed: 6 additions & 5 deletions
@@ -304,12 +304,12 @@ def __call__(
         prompt: Union[str, List[str]],
         image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
         negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 512,
         width: int = 512,
         num_inference_steps: int = 100,
         guidance_scale: float = 4.0,
         num_images_per_prompt: int = 1,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -325,6 +325,9 @@ def __call__(
                 The clip image embeddings for text prompt, that will be used to condition the image generation.
             negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                 The clip image embeddings for negative text prompt, will be used to condition the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -340,9 +343,6 @@ def __call__(
                 usually at the expense of lower image quality.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
@@ -418,7 +418,8 @@ def __call__(
                 timestep=t,
                 encoder_hidden_states=text_encoder_hidden_states,
                 added_cond_kwargs=added_cond_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
 
             if do_classifier_free_guidance:
                 noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
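
The `.sample` to `return_dict=False` change in the denoising loop is a pure refactor: with `return_dict=False` the UNet returns a plain tuple whose first element is the same tensor as `.sample`, skipping the output-object construction in the hot loop. A standalone sketch of the equivalence, using a deliberately tiny, hypothetical UNet config (not Kandinsky's real one):

```py
import torch
from diffusers import UNet2DConditionModel

# Tiny illustrative config; real Kandinsky UNets are configured differently.
unet = UNet2DConditionModel(
    sample_size=8,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    cross_attention_dim=32,
    norm_num_groups=8,
).eval()

sample = torch.randn(1, 4, 8, 8)
encoder_hidden_states = torch.randn(1, 2, 32)

with torch.no_grad():
    # Default: the forward pass returns an output object with a .sample field.
    a = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states).sample
    # return_dict=False: the forward pass returns a tuple; index 0 is the same tensor.
    b = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states, return_dict=False)[0]

assert torch.equal(a, b)
```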

src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py

Lines changed: 6 additions & 5 deletions
@@ -368,13 +368,13 @@ def __call__(
         image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
         image_embeds: torch.FloatTensor,
         negative_image_embeds: torch.FloatTensor,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 512,
         width: int = 512,
         num_inference_steps: int = 100,
         strength: float = 0.3,
         guidance_scale: float = 7.0,
         num_images_per_prompt: int = 1,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -392,6 +392,9 @@ def __call__(
                 The clip image embeddings for text prompt, that will be used to condition the image generation.
             negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                 The clip image embeddings for negative text prompt, will be used to condition the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -413,9 +416,6 @@ def __call__(
                 usually at the expense of lower image quality.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
@@ -512,7 +512,8 @@ def __call__(
                 timestep=t,
                 encoder_hidden_states=text_encoder_hidden_states,
                 added_cond_kwargs=added_cond_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
 
             if do_classifier_free_guidance:
                 noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
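
In both pipeline files, moving `negative_prompt` up the `__call__` signature only changes positional order; callers that use keyword arguments (as all the documented examples do) are unaffected. A self-contained sketch with hypothetical stand-in functions:

```py
# Hypothetical stand-ins mirroring the old and new parameter order.
def call_old(prompt, image_embeds, negative_image_embeds,
             height=512, width=512, num_inference_steps=100,
             guidance_scale=4.0, num_images_per_prompt=1, negative_prompt=None):
    return negative_prompt, height

def call_new(prompt, image_embeds, negative_image_embeds, negative_prompt=None,
             height=512, width=512, num_inference_steps=100,
             guidance_scale=4.0, num_images_per_prompt=1):
    return negative_prompt, height

# Keyword usage is identical before and after the reordering:
assert call_old("p", None, None, negative_prompt="n", height=768) == \
       call_new("p", None, None, negative_prompt="n", height=768)

# Only a caller passing arguments positionally past the embeddings would notice:
# call_old("p", None, None, 768) sets height=768, whereas
# call_new("p", None, None, 768) would set negative_prompt=768 instead.
```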
