Migrate blog content to docs #2477

Merged
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -22,6 +22,8 @@
title: Using KerasCV Stable Diffusion Checkpoints in Diffusers
title: Loading & Hub
- sections:
- local: using-diffusers/write_own_pipeline
title: Write your own inference pipeline
- local: using-diffusers/unconditional_image_generation
title: Unconditional Image Generation
- local: using-diffusers/conditional_image_generation
199 changes: 199 additions & 0 deletions docs/source/en/using-diffusers/write_own_pipeline.mdx
@@ -0,0 +1,199 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Write your own inference pipeline

[[open-in-colab]]

🧨 Diffusers allows you to freely exchange models and schedulers in the pipeline to create your own custom pipeline for inference. This guide will show you how to create a custom Stable Diffusion pipeline for inference with the [`LMSDiscreteScheduler`] instead of the default [`PNDMScheduler`].

## Load pipeline components

The pretrained [`stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main) checkpoint includes all the components required to set up a complete diffusion pipeline. They are stored in the following subfolders:

* `text_encoder`: the model used to turn the input prompt into a text representation. Stable Diffusion uses CLIP, but other diffusion models may use other encoders such as BERT.
* `tokenizer`: must match the tokenizer used by the `text_encoder` model.
* `scheduler`: the scheduling algorithm used to progressively add noise to the image during training.
* `unet`: the model that predicts the noise residual used to progressively denoise the latent representation.
* `vae`: the autoencoder module used to decode latent representations into real images.

Load these components individually by passing the `subfolder` argument to the [`~ModelMixin.from_pretrained`] method:

<Tip>

💡 While you can load the entire checkpoint and all its components by calling the [`StableDiffusionPipeline`] with the [`~DiffusionPipeline.from_pretrained`] method, the goal of this guide is to show how you can pick and choose the individual models and scheduler you want to use to create a custom inference system.

</Tip>
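
For comparison, here is a minimal sketch of that all-in-one approach with the same checkpoint:

```py
>>> from diffusers import StableDiffusionPipeline

>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
```

Now, back to loading the components individually: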

```py
>>> from PIL import Image
>>> import torch
>>> from transformers import CLIPTextModel, CLIPTokenizer
>>> from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler

>>> # 1. Load the autoencoder model which will be used to decode the latents into image space
>>> vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

>>> # 2. Load the tokenizer and text encoder to tokenize and encode the text
>>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
>>> text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

>>> # 3. The UNet model for generating the latents
>>> unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
```

Instead of loading the default [`PNDMScheduler`], load the [`LMSDiscreteScheduler`] and feel free to configure some of the parameters:

```py
>>> from diffusers import LMSDiscreteScheduler

>>> scheduler = LMSDiscreteScheduler(
... beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
... )
```
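
Alternatively, since the checkpoint also stores a scheduler configuration in its `scheduler` subfolder, a sketch like this loads the [`LMSDiscreteScheduler`] from those saved values instead of spelling out the parameters by hand:

```py
>>> from diffusers import LMSDiscreteScheduler

>>> scheduler = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
```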

Move the models to a GPU to speed up inference:

```py
>>> torch_device = "cuda"
>>> vae.to(torch_device)
>>> text_encoder.to(torch_device)
>>> unet.to(torch_device)
```
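
If you're not sure whether a GPU is available, a small guard like this sketch picks the device automatically (the same code runs on CPU, just much more slowly):

```py
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
```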

## Create text embeddings

The pipeline generates an image from a text prompt, so the next step is to tokenize the text and generate embeddings that condition the UNet model and steer the diffusion process towards an image that resembles the input prompt.

<Tip>

💡 The `guidance_scale` parameter determines how much weight is given to the prompt when generating an image; higher values keep the image more closely aligned with the prompt.

</Tip>

Feel free to choose any prompt you'd like!

```py
>>> prompt = ["a photograph of an astronaut riding a horse"]
>>> height = 512 # default height of Stable Diffusion
>>> width = 512 # default width of Stable Diffusion
>>> num_inference_steps = 100 # Number of denoising steps
>>> guidance_scale = 7.5 # Scale for classifier-free guidance
>>> generator = torch.manual_seed(0) # Seed generator to create the initial latent noise
>>> batch_size = len(prompt)
```

First, tokenize the text and generate the text embeddings from the prompt:

```py
>>> text_input = tokenizer(
... prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
... )

>>> text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
```

You'll also need to generate the *unconditional text embeddings*, which are the embeddings for the padding token. These need to have the same shape (`batch_size` and `seq_length`) as the conditional `text_embeddings`:

```py
>>> max_length = text_input.input_ids.shape[-1]
>>> uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
>>> uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
```

For classifier-free guidance, you need two forward passes: one with the conditional embeddings and one with the unconditional embeddings. In practice, you can concatenate both into a single batch to avoid doing two separate forward passes:

```py
>>> text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
```
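
As a quick sanity check, the concatenated batch should be twice the original batch size. With the CLIP ViT-L/14 text encoder used here (hidden size 768, maximum length 77), the shape for a single prompt should look like this:

```py
>>> print(text_embeddings.shape)  # expect torch.Size([2, 77, 768]) for batch_size = 1
```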

## Create random noise

Next, generate some initial random noise as a starting point for the diffusion process. This is the latent representation of the image, and it'll be gradually denoised. At this point, the `latents` are smaller than the final image size, but that's okay because the model will transform them into the final 512x512 image dimensions later.

```py
>>> latents = torch.randn(
... (batch_size, unet.in_channels, height // 8, width // 8),
... generator=generator,
... )
>>> latents = latents.to(torch_device)
```
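
For Stable Diffusion v1, the VAE downsamples images by a factor of 8 and the UNet works with 4 latent channels, so for a single 512x512 image the latents should have the following shape:

```py
>>> print(latents.shape)  # expect torch.Size([1, 4, 64, 64]) for batch_size = 1
```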

## Denoise the image

One of the last steps is to create the denoising loop that'll progressively transform the pure noise in `latents` to an image described by your prompt.

Start by initializing the scheduler with `num_inference_steps`. This computes the `sigmas` (the noise scale values) and the exact timestep values to use during denoising. You'll also need to scale the initial noise by the scheduler's starting noise value, `init_noise_sigma`:

```py
>>> scheduler.set_timesteps(num_inference_steps)
>>> latents = latents * scheduler.init_noise_sigma
```
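
If you're curious about what the scheduler computed, you can inspect it; this sketch prints the number of timesteps and the initial noise scale (the exact values depend on the scheduler and `num_inference_steps`):

```py
>>> print(scheduler.timesteps.shape)  # one timestep per denoising step
>>> print(scheduler.init_noise_sigma)  # the value used to scale the initial latents above
```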

Finally, write a denoising loop that'll turn the noise into an image!

```py
>>> from tqdm.auto import tqdm

>>> for t in tqdm(scheduler.timesteps):
... # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
... latent_model_input = torch.cat([latents] * 2)

... latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

... # predict the noise residual
... with torch.no_grad():
... noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

... # perform guidance
... noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
... noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

... # compute the previous noisy sample x_t -> x_t-1
... latents = scheduler.step(noise_pred, t, latents).prev_sample
```

## Decode the image

The final step is to use the `vae` to decode the latent representation into an image and retrieve the decoded output with [`~diffusers.models.vae.DecoderOutput.sample`]:

```py
>>> # scale and decode the image latents with the vae
>>> latents = 1 / 0.18215 * latents
>>> with torch.no_grad():
...     image = vae.decode(latents).sample
```

Lastly, convert the image to a `PIL.Image` to see your generated image!

```py
>>> image = (image / 2 + 0.5).clamp(0, 1)
>>> image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
>>> images = (image * 255).round().astype("uint8")
>>> pil_images = [Image.fromarray(image) for image in images]
>>> pil_images[0]
```
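
If you're running this in a script instead of a notebook, you can save the result to disk instead of displaying it; the filename here is just an example:

```py
>>> pil_images[0].save("astronaut_rides_horse.png")
```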

<div class="flex justify-center">
<img src="https://huggingface.co/blog/assets/98_stable_diffusion/stable_diffusion_k_lms.png"/>
</div>

## Summary

In this guide you learned how to:

* individually load the components of a Stable Diffusion model and replace the default scheduler with another one
* create text embeddings from the prompt to guide the UNet model
* write a denoising loop with the [`LMSDiscreteScheduler`] to generate an image
* decode and display the generated image