Commit 9ff7727

dg845, ayushmangal, ayushtues, patrickvonplaten authored and committed
Add Consistency Models Pipeline (huggingface#3492)
* initial commit
* Improve consistency models sampling implementation.
* Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.
* Add Unet blocks for consistency models
* Add conversion script for Unet
* Fix bug in new unet blocks
* Fix attention weight loading
* Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests.
* make style
* Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints.
* Add support for converting a test UNet and non-class-conditional UNets to the consistency models conversion script.
* make style
* Change num_class_embeds to 1000 to better match the original consistency models implementation.
* Add support for distillation in pipeline_consistency_models.py.
* Improve consistency model tests:
  - Get small testing checkpoints from hub
  - Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline
  - Add onestep, multistep tests for distillation and distillation + class conditional
  - Add expected image slices for onestep tests
* make style
* Improve ConsistencyModelPipeline:
  - Add initial support for class-conditional generation
  - Fix initial sigma for onestep generation
  - Fix some sigma shape issues
* make style
* Improve ConsistencyModelPipeline:
  - add latents __call__ argument and prepare_latents method
  - add check_inputs method
  - add initial docstrings for ConsistencyModelPipeline.__call__
* make style
* Fix bug when randomly generating class labels for class-conditional generation.
* Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests.
* Remove some unused code and make style.
* Fix small bug in CMStochasticIterativeScheduler.
* Add expected slices for multistep sampling tests and make them pass.
* Work on consistency model fast tests:
  - in pipeline, call self.scheduler.scale_model_input before denoising
  - get expected slices for Euler and Heun scheduler tests
  - make Euler test pass
  - mark Heun test as expected fail because it doesn't support prediction_type "sample" yet
  - remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas
* make style
* Refactor conversion script to make it easier to add more model architectures to convert in the future.
* Work on ConsistencyModelPipeline tests:
  - Fix device bug when handling class labels in ConsistencyModelPipeline.__call__
  - Add slow tests for onestep and multistep sampling and make them pass
  - Refactor fast tests
  - Refactor ConsistencyModelPipeline.__init__
* make style
* Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now.
* Run python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite to make dummy objects for new pipeline and scheduler.
* Make fast tests from PipelineTesterMixin pass.
* make style
* Refactor consistency models pipeline and scheduler:
  - Remove support for Karras schedulers (only support CMStochasticIterativeScheduler)
  - Move sigma manipulation, input scaling, denoising from pipeline to scheduler
  - Make corresponding changes to tests and ensure they pass
* make style
* Add docstrings and further refactor pipeline and scheduler.
* make style
* Add initial version of the consistency models documentation.
* Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging.
* make style
* Convert current slow tests to use fp16 and flash attention.
* make style
* Add slow tests for normal attention on cuda device.
* make style
* Fix attention weights loading
* Update consistency model fast tests for new test checkpoints with attention fix.
* make style
* apply suggestions
* Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler).
* Conversion script now outputs pipeline instead of UNet and add support for LSUN-256 models and different schedulers.
* When both timesteps and num_inference_steps are supplied, raise warning instead of error (timesteps take precedence).
* make style
* Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example.
* apply suggestions from review
* make style
* fix attention naming
* Add tests for CMStochasticIterativeScheduler.
* make style
* Make CMStochasticIterativeScheduler tests pass.
* make style
* Override test_step_shape in CMStochasticIterativeSchedulerTest instead of modifying it in SchedulerCommonTest.
* make style
* rename some models
* Improve API
* rename some models
* Remove duplicated block
* Add docstring and make torch compile work
* More fixes
* Fixes
* Apply suggestions from code review
* Apply suggestions from code review
* add more docstring
* update consistency conversion script

---------

Co-authored-by: ayushmangal <[email protected]>
Co-authored-by: Ayush Mangal <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent a166fd6 commit 9ff7727

17 files changed: +1710 -13 lines changed

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
@@ -184,6 +184,8 @@
       title: Audio Diffusion
     - local: api/pipelines/audioldm
       title: AudioLDM
+    - local: api/pipelines/consistency_models
+      title: Consistency Models
     - local: api/pipelines/controlnet
       title: ControlNet
     - local: api/pipelines/cycle_diffusion
@@ -274,6 +276,8 @@
   - sections:
     - local: api/schedulers/overview
       title: Overview
+    - local: api/schedulers/cm_stochastic_iterative
+      title: Consistency Model Multistep Scheduler
     - local: api/schedulers/ddim
       title: DDIM
     - local: api/schedulers/ddim_inverse
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
# Consistency Models

Consistency Models were proposed in [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever.

The abstract of the [paper](https://arxiv.org/pdf/2303.01469.pdf) is as follows:

*Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.*

Resources:

* [Paper](https://arxiv.org/abs/2303.01469)
* [Original Code](https://github.com/openai/consistency_models)

Available checkpoints:
- *cd_imagenet64_l2 (64x64 resolution)* [openai/diffusers-cd_imagenet64_l2](https://huggingface.co/openai/diffusers-cd_imagenet64_l2)
- *cd_imagenet64_lpips (64x64 resolution)* [openai/diffusers-cd_imagenet64_lpips](https://huggingface.co/openai/diffusers-cd_imagenet64_lpips)
- *ct_imagenet64 (64x64 resolution)* [openai/diffusers-ct_imagenet64](https://huggingface.co/openai/diffusers-ct_imagenet64)
- *cd_bedroom256_l2 (256x256 resolution)* [openai/diffusers-cd_bedroom256_l2](https://huggingface.co/openai/diffusers-cd_bedroom256_l2)
- *cd_bedroom256_lpips (256x256 resolution)* [openai/diffusers-cd_bedroom256_lpips](https://huggingface.co/openai/diffusers-cd_bedroom256_lpips)
- *ct_bedroom256 (256x256 resolution)* [openai/diffusers-ct_bedroom256](https://huggingface.co/openai/diffusers-ct_bedroom256)
- *cd_cat256_l2 (256x256 resolution)* [openai/diffusers-cd_cat256_l2](https://huggingface.co/openai/diffusers-cd_cat256_l2)
- *cd_cat256_lpips (256x256 resolution)* [openai/diffusers-cd_cat256_lpips](https://huggingface.co/openai/diffusers-cd_cat256_lpips)
- *ct_cat256 (256x256 resolution)* [openai/diffusers-ct_cat256](https://huggingface.co/openai/diffusers-ct_cat256)

## Available Pipelines

| Pipeline | Tasks | Demo | Colab |
|:---:|:---:|:---:|:---:|
| [ConsistencyModelPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_consistency_models.py) | *Unconditional Image Generation* | | |

This pipeline was contributed by our community members [dg845](https://github.com/dg845) and [ayushtues](https://huggingface.co/ayushtues) :heart:

## Usage Example

```python
import torch

from diffusers import ConsistencyModelPipeline

device = "cuda"
# Load the cd_imagenet64_l2 checkpoint.
model_id_or_path = "openai/diffusers-cd_imagenet64_l2"
pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)

# Onestep sampling
image = pipe(num_inference_steps=1).images[0]
image.save("consistency_model_onestep_sample.png")

# Onestep sampling, class-conditional image generation
# ImageNet-64 class label 145 corresponds to king penguins
image = pipe(num_inference_steps=1, class_labels=145).images[0]
image.save("consistency_model_onestep_sample_penguin.png")

# Multistep sampling, class-conditional image generation
# Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:
# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77
image = pipe(timesteps=[22, 0], class_labels=145).images[0]
image.save("consistency_model_multistep_sample_penguin.png")
```
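
The explicit `timesteps` above are indices into a discretized noise schedule, not noise levels. As a rough illustration, the sketch below converts such an index into a noise level using a Karras-style schedule. It assumes the convention of the original repository's `ts` argument (index 0 is the largest noise level) and assumes 40 discretization steps with `sigma_min=0.002`, `sigma_max=80.0`, and `rho=7.0`. The helper name `karras_sigma` and these defaults are illustrative assumptions, not values read from the pipeline, and the diffusers scheduler may order or map its timesteps differently.

```python
def karras_sigma(index, num_steps=40, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise level for a schedule index under a Karras-style schedule.

    All default values are assumptions for illustration only.
    Index 0 maps to sigma_max and index num_steps - 1 maps to sigma_min.
    """
    if not 0 <= index < num_steps:
        raise ValueError(f"index must be in [0, {num_steps - 1}], got {index}")
    # Fraction of the way through the schedule, from 0.0 to 1.0.
    ramp = index / (num_steps - 1)
    # Interpolate linearly in sigma^(1/rho) space, then map back.
    min_inv_rho = sigma_min ** (1.0 / rho)
    max_inv_rho = sigma_max ** (1.0 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


if __name__ == "__main__":
    # Print the noise level for the first, an intermediate, and the last index.
    for index in (0, 22, 39):
        print(index, karras_sigma(index))
```

Because the interpolation happens in `sigma ** (1 / rho)` space, the noise levels are spaced more densely near `sigma_min` than near `sigma_max`.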

For an additional speed-up, wrap the UNet in `torch.compile`. After the initial compilation, multiple images can be generated in under one second:

```py
import torch
from diffusers import ConsistencyModelPipeline

device = "cuda"
# Load the cd_bedroom256_lpips checkpoint.
model_id_or_path = "openai/diffusers-cd_bedroom256_lpips"
pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)

pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

# Multistep sampling
# Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:
# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L83
for _ in range(10):
    image = pipe(timesteps=[17, 0]).images[0]
    image.show()
```
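
With a compiled module, the first call includes compilation time, so a speed-up claim is easier to check when the first call is timed separately from the later ones. The helper below is a minimal sketch of that measurement. The name `time_calls` is hypothetical, and a cheap stand-in workload is used so the snippet runs without a GPU.

```python
import time


def time_calls(fn, num_calls=10):
    """Return the wall-clock duration, in seconds, of each call to fn."""
    durations = []
    for _ in range(num_calls):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


if __name__ == "__main__":
    # Stand-in workload; in practice this could be something like
    # `lambda: pipe(timesteps=[17, 0])` using the pipeline from the example above.
    durations = time_calls(lambda: sum(range(100_000)), num_calls=5)
    later = durations[1:]
    print(f"first call: {durations[0]:.4f}s")
    print(f"later calls (mean): {sum(later) / len(later):.4f}s")
```

Reporting the first call on its own keeps the one-off compilation cost from distorting the per-image figure.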

## ConsistencyModelPipeline
[[autodoc]] ConsistencyModelPipeline
    - all
    - __call__
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Consistency Model Multistep Scheduler

## Overview

The multistep and onestep scheduler (Algorithm 1) was introduced alongside consistency models in the paper [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever.
It is based on the [original consistency models implementation](https://github.com/openai/consistency_models).
It should generate good samples from [`ConsistencyModelPipeline`] in one or a small number of steps.
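
Algorithm 1 of the paper alternates between denoising with the consistency function and re-noising to a smaller noise level. The sketch below illustrates that loop in pure Python. It assumes a generic `consistency_fn(sample, sigma)` callable and samples stored as lists of floats; both are hypothetical stand-ins and not the `CMStochasticIterativeScheduler` API. The toy function in the demo is not a trained model.

```python
import math
import random


def multistep_consistency_sampling(consistency_fn, noise, sigma_max, sigmas, sigma_min=0.002, rng=None):
    """Sketch of multistep consistency sampling (Algorithm 1 in the paper).

    consistency_fn(sample, sigma) -> sample is a hypothetical stand-in for the model.
    sigmas is a decreasing sequence of intermediate noise levels below sigma_max.
    """
    rng = rng or random.Random()
    # Onestep generation: map the initial noise directly to a sample.
    sample = consistency_fn(noise, sigma_max)
    for sigma in sigmas:
        # Re-noise the current sample up to noise level sigma ...
        scale = math.sqrt(max(sigma**2 - sigma_min**2, 0.0))
        noised = [x + scale * rng.gauss(0.0, 1.0) for x in sample]
        # ... then denoise again with the consistency function.
        sample = consistency_fn(noised, sigma)
    return sample


if __name__ == "__main__":
    sigma_min, sigma_max = 0.002, 80.0

    def toy_consistency_fn(sample, sigma):
        # Toy stand-in that satisfies the boundary condition f(x, sigma_min) = x.
        return [x * sigma_min / sigma for x in sample]

    rng = random.Random(0)
    noise = [sigma_max * rng.gauss(0.0, 1.0) for _ in range(4)]
    print(multistep_consistency_sampling(toy_consistency_fn, noise, sigma_max, [10.0, 1.0], sigma_min, rng))
```

Passing an empty `sigmas` sequence reduces the loop to onestep generation, which matches the trade-off described above: more intermediate noise levels cost more model evaluations in exchange for sample quality.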

## CMStochasticIterativeScheduler
[[autodoc]] CMStochasticIterativeScheduler
