Add Consistency Models Pipeline #3492


Merged: 79 commits into huggingface:main, Jul 5, 2023

Conversation

@dg845 (Contributor) commented May 20, 2023

This PR implements a pipeline for Consistency Models as discussed in #3073.

Model/Pipeline Description

Consistency Models (paper, code) are a new family of generative models similar to continuous-time diffusion models which support fast one-step generation. Diffusion models can be distilled into a consistency model for faster sampling, and consistency models can also be trained from scratch. From the paper abstract:

Diffusion models have made significant breakthroughs in image, audio, and video generation, but they depend on an iterative generation process that causes slow sampling speed and caps their potential for real-time applications. To overcome this limitation, we propose consistency models, a new family of generative models that achieve high sample quality without adversarial training. They support fast one-step generation by design, while still allowing for few-step sampling to trade compute for sample quality. They also support zero-shot data editing, like image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either as a way to distill pre-trained diffusion models, or as standalone generative models. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step generation. For example, we achieve the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained as standalone generative models, consistency models also outperform single-step, non-adversarial generative models on standard benchmarks like CIFAR-10, ImageNet 64x64 and LSUN 256x256.

In this PR, we implement a consistency model sampling pipeline as described in the paper.

[Figure: consistency models diagram]

Usage Examples

import torch

from diffusers import ConsistencyModelPipeline

device = "cuda"
# Load the cd_imagenet64_l2 checkpoint.
model_id_or_path = "dg845/consistency-model-pipelines"
pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)

# One-step sampling
image = pipe(num_inference_steps=1).images[0]
image.save("consistency_model_onestep_sample.png")

# One-step sampling, class-conditional image generation
# ImageNet-64 class label 145 corresponds to king penguins
image = pipe(num_inference_steps=1, class_labels=145).images[0]
image.save("consistency_model_onestep_sample_penguin.png")

# Multistep sampling, class-conditional image generation
# Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo.
image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0]
image.save("consistency_model_multistep_sample_penguin.png")

TODO

  • Implement Consistency Models pipeline (ConsistencyModelPipeline)
  • Script to convert Consistency Model checkpoints to diffusers checkpoints
  • Upload pre-trained Consistency Model checkpoints [original checkpoints]
  • Add model cards for checkpoints and move checkpoints to appropriate repositories
  • Create tests for Consistency Models
  • Add docstrings for new classes
  • Create documentation for Consistency Models
  • Add usage example(s)

Discussion

  • (TBD)

CC

@patrickvonplaten
@yang-song (author of original paper and code)
@tyshiwo1
@ayushtues
@xenova

dg845 added 3 commits May 15, 2023 08:04
@dg845 dg845 changed the title Consistency models pipeline [WIP] Add Consistency models pipeline May 20, 2023
@dg845 dg845 mentioned this pull request May 20, 2023
2 tasks
@dg845 dg845 changed the title [WIP] Add Consistency models pipeline [WIP] Add Consistency Models Pipeline May 20, 2023
@HuggingFaceDocBuilderDev commented May 20, 2023

The documentation is not available anymore as the PR was closed or merged.

@dg845 (Contributor, Author) commented May 20, 2023

The sampling code is still very much a work in progress right now; I think I need to understand Karras schedulers better 😅.

@ayushtues, I'm not sure what the best way to collaborate is. One suggestion: I can continue to work on the sampling code while you start work on scripts/examples to distill diffusion models into consistency models, which I think people will probably want alongside a consistency models pipeline. Also open to other suggestions :).

@ayushtues (Contributor) commented May 20, 2023

Hey @dg845, if you're working on sampling, I can pick up replicating the model architecture and porting the weights from the original checkpoints to diffusers for now; does that sound okay?

@ayushtues (Contributor) commented May 20, 2023

As a side note, have you been able to sample any good images from the original repo? I was trying to sample some ImageNet images, but could only get random garbage.

I was not able to install Flash Attention, but even using normal attention should not give such different results.

I created a branch with some changes to the original code, removing distributed training/argparse and stuff to make it easily usable here: https://github.com/ayushtues/consistency_models/tree/local, feel free to use it.

Also, set use_fp16 to false; it was leading to numerical instability and giving NaNs.

@ayushtues (Contributor) commented May 20, 2023

Also for the sampler, can we not use HeunDiscreteScheduler or EulerDiscreteScheduler already present in Diffusers?

@dg845 (Contributor, Author) commented May 20, 2023

Hey @dg845, if you're working on sampling, I can pick up replicating the model architecture and porting the weights from the original checkpoints to diffusers for now; does that sound okay?

Sounds good :).

As a side note, have you been able to sample any good images from the original repo? I was trying to sample some ImageNet images, but could only get random garbage.

I ran into a problem installing the dependencies (I think it was related to distributed training/sampling) and haven't fixed it yet. I might try out your branch.

Also for the sampler, can we not use HeunDiscreteScheduler or EulerDiscreteScheduler already present in Diffusers?

In my understanding (admittedly a little shaky), the original implementation supports two types of schedulers:

  1. Karras-style schedulers like Euler or Heun.
  2. The multi-step scheduler (e.g. stochastic_iterative_sampler in the original code), of which the one-step scheduler could be considered a special case. This implements Algorithm 1 in the paper and doesn't correspond to any currently implemented scheduler, although KarrasVeScheduler is similar (to the best of my knowledge).

I think what I have currently supports Karras-style schedulers (and leverages already implemented Karras schedulers like HeunDiscreteScheduler), but I'm not sure about the best way to implement the multi-step scheduler.
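
For concreteness, here is a minimal sketch of the paper's Algorithm 1 (multistep consistency sampling), written as standalone code rather than against the diffusers scheduler API; model, sigmas, and sigma_min are illustrative names:

import torch

def multistep_consistency_sampling(model, sigmas, sigma_min, shape):
    # Start from pure noise at the maximum noise level.
    x = torch.randn(shape) * sigmas[0]
    # One evaluation of the consistency function already yields a clean sample.
    x = model(x, sigmas[0])
    for sigma in sigmas[1:]:
        # Re-noise the clean estimate up to the current noise level...
        z = torch.randn(shape)
        x_hat = x + (sigma**2 - sigma_min**2) ** 0.5 * z
        # ...then map it back to a clean sample with the consistency model.
        x = model(x_hat, sigma)
    return x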

@dg845 (Contributor, Author) commented May 21, 2023

Hi, would someone be able to do a preliminary design review of ConsistencyModelPipeline and CMStochasticIterativeScheduler? In particular, I have a few questions:

  1. Should CMStochasticIterativeScheduler.__init__ configure a beta schedule like most of the samplers supported in the original implementation (e.g. Heun, DPM)? Right now it takes arguments in "original" sigma space, like KarrasDenoiser in the original implementation or KarrasVeScheduler.
  2. Does it make sense to support schedulers like KarrasVeScheduler or ScoreSdeVeScheduler (since their implementation is different from the KarrasDiffusionSchedulers)? These schedulers aren't supported in the original code but I'm not sure if they're theoretically compatible with consistency models. Right now I'm trying to support them.

@ayushtues (Contributor) commented May 21, 2023

Hey @dg845, I'll also help with reviewing the schedulers; I've been wanting to understand how they work.

I wasn't able to get much work done over the weekend, but will update soon in the upcoming week!

@ayushtues (Contributor) commented
Started a PR for adding the Unet and conversion scripts dg845#1.

@patrickvonplaten (Contributor) commented
Also cc @williamberman here

@williamberman (Contributor) commented
Awesome, thanks for getting this started! My high-level understanding of the paper is that the consistency model directly outputs the unnoised result. If my understanding is correct, do we need a separate consistency model scheduler? If we want to do an iterative denoising process, can we re-use an existing scheduler?

@dg845 (Contributor, Author) commented May 23, 2023

I'm not sure what the best design should be. Currently, the new scheduler CMStochasticIterativeScheduler's step() method is trivial because (unless I'm misunderstanding) after we add_noise_to_input, all we need to do is return the consistency model evaluated at the noised sample sample_hat:

# Assume model output is the consistency model evaluated at sample_hat.
sample_prev = model_output
if not return_dict:
    return (sample_prev,)

An alternative design would be to put the get_sigmas_karras/add_noise_to_input methods into the ConsistencyModelPipeline pipeline and then use those to run multi-step sampling. Under this design the pipeline kind of takes on some of the functions that a scheduler normally might, which is why I'm not sure about it.

To the best of my knowledge, I don't think there is a currently-implemented scheduler which is equivalent to the multi-step sampler in the original code; KarrasVeScheduler is probably the most similar, but the add_noise_to_input and step methods are not the same. For other samplers used by the original implementation, such as Euler and Heun, the existing diffusers implementations (e.g. EulerDiscreteScheduler, etc.) should be reused.
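
For reference, here is the get_sigmas_karras schedule mentioned above, as it is commonly implemented following Karras et al. (2022) with the paper's default rho = 7; the in-repo version may differ slightly (e.g. by appending a trailing zero):

import torch

def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0):
    # Interpolate in sigma^(1/rho) space, which concentrates steps
    # near sigma_min.
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho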

@ayushtues (Contributor) commented May 23, 2023

I was able to convert the UNet into diffusers, load pretrained checkpoints, and get the same outputs from forward passes:

ckpt : https://huggingface.co/ayushtues/consistency_models/tree/main

unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2")

@patrickvonplaten (Contributor) commented
I was able to convert the UNet into diffusers, load pretrained checkpoints, and get the same outputs from forward passes:

ckpt : https://huggingface.co/ayushtues/consistency_models/tree/main

unet = UNet2DModel.from_pretrained("ayushtues/consistency_models", subfolder="diffusers_cd_imagenet64_l2")

Wow great work @ayushtues ! Exciting!

@williamberman (Contributor) commented May 23, 2023

I'm not sure what the best design should be. Currently, the new scheduler CMStochasticIterativeScheduler's step() method is trivial because (unless I'm misunderstanding) after we add_noise_to_input, all we need to do is return the consistency model evaluated at the noised sample sample_hat:

# Assume model output is the consistency model evaluated at sample_hat.
sample_prev = model_output
if not return_dict:
    return (sample_prev,)

An alternative design would be to put the get_sigmas_karras/add_noise_to_input methods into the ConsistencyModelPipeline pipeline and then use those to run multi-step sampling. Under this design the pipeline kind of takes on some of the functions that a scheduler normally might, which is why I'm not sure about it.

To the best of my knowledge, I don't think there is a currently-implemented scheduler which is equivalent to the multi-step sampler in the original code; KarrasVeScheduler is probably the most similar, but the add_noise_to_input and step methods are not the same. For other samplers used by the original implementation, such as Euler and Heun, the existing diffusers implementations (e.g. EulerDiscreteScheduler, etc.) should be reused.

I'd avoid putting them on the pipeline. Can you state how the two methods are different? I'm not familiar off the top of my head. If the only way the methods differ is in the configured noise schedule but the formulas are the same, we can re-use the existing scheduler.

@dg845 (Contributor, Author) commented May 24, 2023

For the add_noise_to_input method, KarrasVeScheduler and the consistency models implementation both make a noise update of the form

$$\hat{x}_n \leftarrow x + \sqrt{\tau_n^2 - \epsilon^2}\,\mathbf{z}$$

where $\mathbf{z} \sim \mathcal{N}(0, s_{\text{noise}}^2\mathbf{I})$, but in KarrasVeScheduler $\epsilon = \sigma_{n}$ while in the consistency models implementation $\epsilon = \sigma_{min}$ (not in Karras space). Although $\tau_{n} = \hat{\sigma}_{n}$ in both implementations, $\hat{\sigma}_{n}$ is calculated differently: KarrasVeScheduler uses a gamma factor, $\hat{\sigma}_{n} = \sigma_n + \gamma \cdot \sigma_n$, while the consistency models implementation takes the next Karras sigma $\sigma_{n + 1}$ in the schedule as $\hat{\sigma}_n$.

For the step method, KarrasVeScheduler takes a Euler step from $\sigma_n$ to $\sigma_{n + 1}$, while the consistency model implementation simply evaluates the consistency model at $\hat{x}_n$ and $\hat{\sigma}_n$.
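
To make the contrast concrete, here is a hedged sketch of the two noise-update rules (names follow the math above, not library code):

import torch

def karras_ve_add_noise(x, sigma, gamma, s_noise=1.0):
    # KarrasVeScheduler: sigma_hat = sigma + gamma * sigma, and eps = sigma.
    sigma_hat = sigma + gamma * sigma
    z = s_noise * torch.randn_like(x)
    return x + (sigma_hat**2 - sigma**2) ** 0.5 * z, sigma_hat

def consistency_add_noise(x, sigma_next, sigma_min, s_noise=1.0):
    # Consistency models: sigma_hat is the next Karras sigma, and eps = sigma_min.
    z = s_noise * torch.randn_like(x)
    return x + (sigma_next**2 - sigma_min**2) ** 0.5 * z, sigma_next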

@williamberman (Contributor) commented
OK, great, super helpful! I think I have my head wrapped around it.

  1. Just to note, the karras_ve and sde_ve are deprecated interfaces for interacting with diffeq-based schedulers. We no longer explicitly call the higher-order solving steps or noising in the pipelines; this has been encapsulated through repeated scheduler calls which track the current correction step and apply noise if needed. This means we don't want to use the add_noise_to_input method.

  2. This does need a new scheduler, since existing schedulers that noise a "sample" prediction_type do not do so according to a sigma-based noise schedule (DDPM is the only scheduler noising a "sample" prediction type I know of off the top of my head, and it uses the particular alpha/beta parameterization). However, we can re-use components (sigma schedule, adding noise) from existing schedulers pretty easily, as what the code you linked is doing is straightforward.

Sigma schedule: We can re-use all existing components for constructing the sigma schedule; see scheduling_dpmsolver_sde.py for an example of how it copies the methods from the Euler scheduler.

Adding noise: I believe the noise-adding step to update the prediction of x_0 to x_{t-1} is just our standard add_noise method for a sigma-based scheduler, but updated for a starting noise amount, which in the case of the code you linked is always t_min. So for this, I'd just adapt the add_noise method from the Euler scheduler with an additional parameter for the starting noise amount, as sketched below. (I'm guessing a bit here that it is a starting noise amount; if you could confirm, that would be helpful.)
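
A rough sketch of that adaptation (the name and signature here are hypothetical, not the merged diffusers method):

def add_noise_from(sample, noise, sigma_target, sigma_start=0.0):
    # Standard sigma-based noising when sigma_start == 0; the consistency
    # models code effectively fixes sigma_start to sigma_min.
    return sample + (sigma_target**2 - sigma_start**2) ** 0.5 * noise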

[WIP] Add Unet for consistency models
@patrickvonplaten (Contributor) commented
We don't have to tackle it now but I think it should be a simple fix. Please correct me if I am wrong.

I was trying to get the UNet to compile with torch.compile(), and it errors out when running inference:

Unsupported: call_function BuiltinVariable(iadd) [ConstantVariable(tuple), TupleVariable()] {}

from user code:
   File "/usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d.py", line 283, in forward
    sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d_blocks.py", line 935, in forward
    output_states += (hidden_states,)

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

Full code:

import torch
from diffusers import ConsistencyModelPipeline

device = "cuda"
model_id_or_path = "dg845/diffusers-cd_bedroom256_l2"
pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)

pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

_ = pipe(num_inference_steps=5)

Happens when num_inference_steps is set to 1 too.

This is now fixed: torch.compile works!
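
For context, the graph break came from the in-place tuple update shown in the traceback; a sketch of the kind of change that avoids it (not necessarily the exact merged diff):

# Before: augmented assignment on a tuple trips dynamo (BuiltinVariable(iadd)).
output_states += (hidden_states,)
# After: rebuild the tuple explicitly.
output_states = output_states + (hidden_states,)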

@patrickvonplaten (Contributor) commented
@patrickvonplaten for now I have modified the tests in SchedulerCommonTest so that the tests for the new scheduler pass. Not super confident in the design; would appreciate it if you could review :).

That works!

@patrickvonplaten patrickvonplaten changed the title [WIP] Add Consistency Models Pipeline Add Consistency Models Pipeline Jul 5, 2023
|:---:|:---:|:---:|:---:|
| [ConsistencyModelPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_consistency_models.py) | *Unconditional Image Generation* | | |

This pipeline was contributed by our community members [dg845](https://github.com/dg845) and [ayushtues](https://huggingface.co/ayushtues) :heart:
Review comment: Make sure contributors are mentioned

@patrickvonplaten (Contributor) left a review

Good to be merged from my side. I pushed all checkpoints here:
https://huggingface.co/models?other=consistency-model

@ayushtues and @dg845 - really amazing work here 🚀

I've made sure to mention you on every model card here: https://huggingface.co/openai/diffusers-ct_cat256

as well as on the official docs: https://github.com/huggingface/diffusers/pull/3492/files#r1253304997

Hope that is ok?

@dg845 (Contributor, Author) commented Jul 5, 2023

I've made sure to mention you on every model card here: https://huggingface.co/openai/diffusers-ct_cat256

as well as on the official docs: https://github.com/huggingface/diffusers/pull/3492/files#r1253304997

Hope that is ok?

Sounds good! Thanks for helping out with the PR :).

I think one additional thing that might need to be updated is scripts/convert_consistency_to_diffusers.py, which still uses old UNet2DModel configs with AttnDownsampleBlock2D/AttnUpsampleBlock2D instead of AttnDownBlock2D/AttnUpBlock2D. The conversion logic might also need to be changed slightly to match the new block keys.
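
A hypothetical illustration of the rename described above (the actual conversion script's key handling may be structured differently):

# Old -> new UNet2DModel block type names.
UNET_BLOCK_RENAMES = {
    "AttnDownsampleBlock2D": "AttnDownBlock2D",
    "AttnUpsampleBlock2D": "AttnUpBlock2D",
}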

Comment on lines +240 to +242
@require_torch_2
def test_consistency_model_cd_multistep_flash_attn(self):
unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
@williamberman (Contributor) commented Jul 5, 2023

Do we need a specific test for this? It seems outside the scope of testing whether this specific model works.

Reply:

Think it won't really hurt tbh no?

@williamberman (Contributor) left a review

Nice! One quick question about whether we need the flash attention tests, but looks great!

@patrickvonplaten patrickvonplaten merged commit aed7499 into huggingface:main Jul 5, 2023
sunhs pushed a commit to sunhs/diffusers that referenced this pull request Jul 6, 2023
* initial commit

* Improve consistency models sampling implementation.

* Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.

* Add Unet blocks for consistency models

* Add conversion script for Unet

* Fix bug in new unet blocks

* Fix attention weight loading

* Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests.

* make style

* Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints.

* Add support for converting a test UNet and non-class-conditional UNets to the consistency models conversion script.

* make style

* Change num_class_embeds to 1000 to better match the original consistency models implementation.

* Add support for distillation in pipeline_consistency_models.py.

* Improve consistency model tests:
	- Get small testing checkpoints from hub
	- Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline
	- Add onestep, multistep tests for distillation and distillation + class conditional
	- Add expected image slices for onestep tests

* make style

* Improve ConsistencyModelPipeline:
	- Add initial support for class-conditional generation
	- Fix initial sigma for onestep generation
	- Fix some sigma shape issues

* make style

* Improve ConsistencyModelPipeline:
	- add latents __call__ argument and prepare_latents method
	- add check_inputs method
	- add initial docstrings for ConsistencyModelPipeline.__call__

* make style

* Fix bug when randomly generating class labels for class-conditional generation.

* Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests.

* Remove some unused code and make style.

* Fix small bug in CMStochasticIterativeScheduler.

* Add expected slices for multistep sampling tests and make them pass.

* Work on consistency model fast tests:
	- in pipeline, call self.scheduler.scale_model_input before denoising
	- get expected slices for Euler and Heun scheduler tests
	- make Euler test pass
	- mark Heun test as expected fail because it doesn't support prediction_type "sample" yet
	- remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas

* make style

* Refactor conversion script to make it easier to add more model architectures to convert in the future.

* Work on ConsistencyModelPipeline tests:
	- Fix device bug when handling class labels in ConsistencyModelPipeline.__call__
	- Add slow tests for onestep and multistep sampling and make them pass
	- Refactor fast tests
	- Refactor ConsistencyModelPipeline.__init__

* make style

* Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now.

* Run python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite to make dummy objects for new pipeline and scheduler.

* Make fast tests from PipelineTesterMixin pass.

* make style

* Refactor consistency models pipeline and scheduler:
	- Remove support for Karras schedulers (only support CMStochasticIterativeScheduler)
	- Move sigma manipulation, input scaling, denoising from pipeline to scheduler
	- Make corresponding changes to tests and ensure they pass

* make style

* Add docstrings and further refactor pipeline and scheduler.

* make style

* Add initial version of the consistency models documentation.

* Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging.

* make style

* Convert current slow tests to use fp16 and flash attention.

* make style

* Add slow tests for normal attention on cuda device.

* make style

* Fix attention weights loading

* Update consistency model fast tests for new test checkpoints with attention fix.

* make style

* apply suggestions

* Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler).

* Conversion script now outputs pipeline instead of UNet and add support for LSUN-256 models and different schedulers.

* When both timesteps and num_inference_steps are supplied, raise warning instead of error (timesteps take precedence).

* make style

* Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example.

* apply suggestions from review

* make style

* fix attention naming

* Add tests for CMStochasticIterativeScheduler.

* make style

* Make CMStochasticIterativeScheduler tests pass.

* make style

* Override test_step_shape in CMStochasticIterativeSchedulerTest instead of modifying it in SchedulerCommonTest.

* make style

* rename some models

* Improve API

* rename some models

* Remove duplicated block

* Add docstring and make torch compile work

* More fixes

* Fixes

* Apply suggestions from code review

* Apply suggestions from code review

* add more docstring

* update consistency conversion script

---------

Co-authored-by: ayushmangal <[email protected]>
Co-authored-by: Ayush Mangal <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
patrickvonplaten added a commit that referenced this pull request Jul 17, 2023 (#3911).
@svjack commented Jul 21, 2023

The output seems bad. Can you release a stable config and an example image output for generation?

orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023
…__` (huggingface#3911)

* add noise_sampler to StableDiffusionKDiffusionPipeline

* fix/docs: Fix the broken doc links (huggingface#3897)

* fix/docs: Fix the broken doc links

Signed-off-by: GitHub <[email protected]>

* Update docs/source/en/using-diffusers/write_own_pipeline.mdx

Co-authored-by: Pedro Cuenca <[email protected]>

---------

Signed-off-by: GitHub <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>

* Add video img2img (huggingface#3900)

* Add image to image video

* Improve

* better naming

* make fix copies

* add docs

* finish tests

* trigger tests

* make style

* correct

* finish

* Fix more

* make style

* finish

* fix/doc-code: Updating to the latest version parameters (huggingface#3924)

fix/doc-code: update to use the new parameter

Signed-off-by: GitHub <[email protected]>

* fix/doc: no import torch issue (huggingface#3923)

Ffix/doc: no import torch issue

Signed-off-by: GitHub <[email protected]>

* Correct controlnet out of list error (huggingface#3928)

* Correct controlnet out of list error

* Apply suggestions from code review

* correct tests

* correct tests

* fix

* test all

* Apply suggestions from code review

* test all

* test all

* Apply suggestions from code review

* Apply suggestions from code review

* fix more tests

* Fix more

* Apply suggestions from code review

* finish

* Apply suggestions from code review

* Update src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

* finish

* Adding better way to define multiple concepts and also validation capabilities. (huggingface#3807)

* - Added validation parameters
- Changed some parameter descriptions to better explain their use.
- Fixed a few typos.
- Added concept_list parameter for better management of multiple subjects
- changed logic for image validation

* - Fixed bad logic for class data root directories

* Defaulting validation_steps to None for an easier logic

* Fixed multiple validation prompts

* Fixed bug on validation negative prompt

* Changed validation logic for tracker.

* Added uuid for validation image labeling

* Fix error when comparing validation prompts and validation negative prompts

* Improved error message when negative prompts for validation are more than the number of prompts

* - Changed image tracking number from epoch to global_step
- Added Typing for functions

* Added some validations more when using concept_list parameter and the regular ones.

* Fixed error message

* Added more validations for validation parameters

* Improved messaging for errors

* Fixed validation error for parameters with default values

* - Added train step to image name for validation
- reformatted code

* - Added train step to image's name for validation
- reformatted code

* Updated README.md file.

* reverted back original script of train_dreambooth.py

* reverted back original script of train_dreambooth.py

* left one blank line at the eof

* reverted back setup.py

* reverted back setup.py

* added same logic for when parameters for prior preservation are used without enabling the flag while using concept_list parameter.

* Ran black formatter.

* fixed a few strings

* fixed import sort with isort and removed fstrings without placeholder

* fixed import order with ruff (since with isort wasn't ok)

---------

Co-authored-by: Patrick von Platen <[email protected]>

* [ldm3d] Update code to be functional with the new checkpoints (huggingface#3875)

* fixed typo

* updated doc to be consistent in naming

* make style/quality

* preprocessing for 4 channels and not 6

* make style

* test for 4c

* make style/quality

* fixed test on cpu

---------

Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>

* Improve memory text to video (huggingface#3930)

* Improve memory text to video

* Apply suggestions from code review

* add test

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* finish test setup

---------

Co-authored-by: Pedro Cuenca <[email protected]>

* revert automatic chunking (huggingface#3934)

* revert automatic chunking

* Apply suggestions from code review

* revert automatic chunking

* avoid upcasting by assigning dtype to noise tensor (huggingface#3713)

* avoid upcasting by assigning dtype to noise tensor

* make style

* Update train_unconditional.py

* Update train_unconditional.py

* make style

* add unit test for pickle

* revert change

---------

Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>

* Fix failing np tests (huggingface#3942)

* Fix failing np tests

* Apply suggestions from code review

* Update tests/pipelines/test_pipelines_common.py

* Add `timestep_spacing` and `steps_offset` to schedulers (huggingface#3947)

* Add timestep_spacing to DDPM, LMSDiscrete, PNDM.

* Remove spurious line.

* More easy schedulers.

* Add `linspace` to DDIM

* Noise sigma for `trailing`.

* Add timestep_spacing to DEISMultistepScheduler.

Not sure the range is the way it was intended.

* Fix: remove line used to debug.

* Support timestep_spacing in DPMSolverMultistep, DPMSolverSDE, UniPC

* Fix: convert to numpy.

* Use sched. defaults when instantiating from_config

For params not present in the original configuration.

This makes it possible to switch pipeline schedulers even if they use
different timestep_spacing (or any other param).

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* Missing args in DPMSolverMultistep

* Test: default args not in config

* Style

* Fix scheduler name in test

* Remove duplicated entries

* Add test for solver_type

This test currently fails in main. When switching from DEIS to UniPC,
solver_type is "logrho" (the default value from DEIS), which gets
translated to "bh1" by UniPC. This is different to the default value for
UniPC: "bh2". This is where the translation happens: https://github.com/huggingface/diffusers/blob/36d22d0709dc19776e3016fb3392d0f5578b0ab2/src/diffusers/schedulers/scheduling_unipc_multistep.py#L171

* UniPC: use same default for solver_type

Fixes a bug when switching from UniPC from another scheduler (i.e.,
DEIS) that uses a different solver type. The solver is now the same as
if we had instantiated the scheduler directly.

* do not save use default values

* fix more

* fix all

* fix schedulers

* fix more

* finish for real

* finish for real

* flaky tests

* Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py

* Default steps_offset to 0.

* Add missing docstrings

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <[email protected]>

* Add Consistency Models Pipeline (huggingface#3492)

* initial commit

* Improve consistency models sampling implementation.

* Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.

* Add Unet blocks for consistency models

* Add conversion script for Unet

* Fix bug in new unet blocks

* Fix attention weight loading

* Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests.

* make style

* Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints.

* Add support for converting a test UNet and non-class-conditional UNets to the consistency models conversion script.

* make style

* Change num_class_embeds to 1000 to better match the original consistency models implementation.

* Add support for distillation in pipeline_consistency_models.py.

* Improve consistency model tests:
	- Get small testing checkpoints from hub
	- Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline
	- Add onestep, multistep tests for distillation and distillation + class conditional
	- Add expected image slices for onestep tests

* make style

* Improve ConsistencyModelPipeline:
	- Add initial support for class-conditional generation
	- Fix initial sigma for onestep generation
	- Fix some sigma shape issues

* make style

* Improve ConsistencyModelPipeline:
	- add latents __call__ argument and prepare_latents method
	- add check_inputs method
	- add initial docstrings for ConsistencyModelPipeline.__call__

* make style

* Fix bug when randomly generating class labels for class-conditional generation.

* Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests.

* Remove some unused code and make style.

* Fix small bug in CMStochasticIterativeScheduler.

* Add expected slices for multistep sampling tests and make them pass.

* Work on consistency model fast tests:
	- in pipeline, call self.scheduler.scale_model_input before denoising
	- get expected slices for Euler and Heun scheduler tests
	- make Euler test pass
	- mark Heun test as expected fail because it doesn't support prediction_type "sample" yet
	- remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas

* make style

* Refactor conversion script to make it easier to add more model architectures to convert in the future.

* Work on ConsistencyModelPipeline tests:
	- Fix device bug when handling class labels in ConsistencyModelPipeline.__call__
	- Add slow tests for onestep and multistep sampling and make them pass
	- Refactor fast tests
	- Refactor ConsistencyModelPipeline.__init__

* make style

* Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now.

* Run python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite to make dummy objects for new pipeline and scheduler.

* Make fast tests from PipelineTesterMixin pass.

* make style

* Refactor consistency models pipeline and scheduler:
	- Remove support for Karras schedulers (only support CMStochasticIterativeScheduler)
	- Move sigma manipulation, input scaling, denoising from pipeline to scheduler
	- Make corresponding changes to tests and ensure they pass

* make style

* Add docstrings and further refactor pipeline and scheduler.

* make style

* Add initial version of the consistency models documentation.

* Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging.

* make style

* Convert current slow tests to use fp16 and flash attention.

* make style

* Add slow tests for normal attention on cuda device.

* make style

* Fix attention weights loading

* Update consistency model fast tests for new test checkpoints with attention fix.

* make style

* apply suggestions

* Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler).

* Conversion script now outputs pipeline instead of UNet and add support for LSUN-256 models and different schedulers.

* When both timesteps and num_inference_steps are supplied, raise warning instead of error (timesteps take precedence).

* make style

* Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example.

* apply suggestions from review

* make style

* fix attention naming

* Add tests for CMStochasticIterativeScheduler.

* make style

* Make CMStochasticIterativeScheduler tests pass.

* make style

* Override test_step_shape in CMStochasticIterativeSchedulerTest instead of modifying it in SchedulerCommonTest.

* make style

* rename some models

* Improve API

* rename some models

* Remove duplicated block

* Add docstring and make torch compile work

* More fixes

* Fixes

* Apply suggestions from code review

* Apply suggestions from code review

* add more docstring

* update consistency conversion script

---------

Co-authored-by: ayushmangal <[email protected]>
Co-authored-by: Ayush Mangal <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* add test case for StableDiffusionKDiffusionPipeline noise_sampler

---------

Signed-off-by: GitHub <[email protected]>
Co-authored-by: Aisuko <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Andrés Mauricio Repetto Ferrero <[email protected]>
Co-authored-by: estelleafl <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Prathik Rao <[email protected]>
Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: ayushmangal <[email protected]>
Co-authored-by: Ayush Mangal <[email protected]>
orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023
…__` (huggingface#3911)

* add noise_sampler to StableDiffusionKDiffusionPipeline

* fix/docs: Fix the broken doc links (huggingface#3897)

* fix/docs: Fix the broken doc links

Signed-off-by: GitHub <[email protected]>

* Update docs/source/en/using-diffusers/write_own_pipeline.mdx

Co-authored-by: Pedro Cuenca <[email protected]>

---------

Signed-off-by: GitHub <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>

* Add video img2img (huggingface#3900)

* Add image to image video

* Improve

* better naming

* make fix copies

* add docs

* finish tests

* trigger tests

* make style

* correct

* finish

* Fix more

* make style

* finish

* fix/doc-code: Updating to the latest version parameters (huggingface#3924)

fix/doc-code: update to use the new parameter

Signed-off-by: GitHub <[email protected]>

* fix/doc: no import torch issue (huggingface#3923)

Ffix/doc: no import torch issue

Signed-off-by: GitHub <[email protected]>

* Correct controlnet out of list error (huggingface#3928)

* Correct controlnet out of list error

* Apply suggestions from code review

* correct tests

* correct tests

* fix

* test all

* Apply suggestions from code review

* test all

* test all

* Apply suggestions from code review

* Apply suggestions from code review

* fix more tests

* Fix more

* Apply suggestions from code review

* finish

* Apply suggestions from code review

* Update src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

* finish

* Adding better way to define multiple concepts and also validation capabilities. (huggingface#3807)

* - Added validation parameters
- Changed some parameter descriptions to better explain their use.
- Fixed a few typos.
- Added concept_list parameter for better management of multiple subjects
- changed logic for image validation

* - Fixed bad logic for class data root directories

* Defaulting validation_steps to None for an easier logic

* Fixed multiple validation prompts

* Fixed bug on validation negative prompt

* Changed validation logic for tracker.

* Added uuid for validation image labeling

* Fix error when comparing validation prompts and validation negative prompts

* Improved error message when negative prompts for validation are more than the number of prompts

* - Changed image tracking number from epoch to global_step
- Added Typing for functions

* Added some validations more when using concept_list parameter and the regular ones.

* Fixed error message

* Added more validations for validation parameters

* Improved messaging for errors

* Fixed validation error for parameters with default values

* - Added train step to image name for validation
- reformatted code

* - Added train step to image's name for validation
- reformatted code

* Updated README.md file.

* reverted back original script of train_dreambooth.py

* reverted back original script of train_dreambooth.py

* left one blank line at the eof

* reverted back setup.py

* reverted back setup.py

* added same logic for when parameters for prior preservation are used without enabling the flag while using concept_list parameter.

* Ran black formatter.

* fixed a few strings

* fixed import sort with isort and removed fstrings without placeholder

* fixed import order with ruff (since with isort wasn't ok)

---------

Co-authored-by: Patrick von Platen <[email protected]>

* [ldm3d] Update code to be functional with the new checkpoints (huggingface#3875)

* fixed typo

* updated doc to be consistent in naming

* make style/quality

* preprocessing for 4 channels and not 6

* make style

* test for 4c

* make style/quality

* fixed test on cpu

---------

Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>

* Improve memory text to video (huggingface#3930)

* Improve memory text to video

* Apply suggestions from code review

* add test

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* finish test setup

---------

Co-authored-by: Pedro Cuenca <[email protected]>

* revert automatic chunking (huggingface#3934)

* revert automatic chunking

* Apply suggestions from code review

* revert automatic chunking

* avoid upcasting by assigning dtype to noise tensor (huggingface#3713)

* avoid upcasting by assigning dtype to noise tensor

* make style

* Update train_unconditional.py

* Update train_unconditional.py

* make style

* add unit test for pickle

* revert change

---------

Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>

* Fix failing np tests (huggingface#3942)

* Fix failing np tests

* Apply suggestions from code review

* Update tests/pipelines/test_pipelines_common.py

* Add `timestep_spacing` and `steps_offset` to schedulers (huggingface#3947)

* Add timestep_spacing to DDPM, LMSDiscrete, PNDM.

* Remove spurious line.

* More easy schedulers.

* Add `linspace` to DDIM

* Noise sigma for `trailing`.

* Add timestep_spacing to DEISMultistepScheduler.

Not sure the range is the way it was intended.

* Fix: remove line used to debug.

* Support timestep_spacing in DPMSolverMultistep, DPMSolverSDE, UniPC

* Fix: convert to numpy.

* Use sched. defaults when instantiating from_config

For params not present in the original configuration.

This makes it possible to switch pipeline schedulers even if they use
different timestep_spacing (or any other param).
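
In practice this is what makes the usual scheduler-swapping idiom safe; a hedged example (the checkpoint id is illustrative):

```python
import torch
from diffusers import DiffusionPipeline, DDIMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Params missing from the serialized config (e.g. timestep_spacing on older
# checkpoints) now fall back to DDIMScheduler's own defaults; they can also be
# overridden explicitly as keyword arguments.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
```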

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* Missing args in DPMSolverMultistep

* Test: default args not in config

* Style

* Fix scheduler name in test

* Remove duplicated entries

* Add test for solver_type

This test currently fails in main. When switching from DEIS to UniPC,
solver_type is "logrho" (the default value from DEIS), which gets
translated to "bh1" by UniPC. This is different to the default value for
UniPC: "bh2". This is where the translation happens: https://github.com/huggingface/diffusers/blob/36d22d0709dc19776e3016fb3392d0f5578b0ab2/src/diffusers/schedulers/scheduling_unipc_multistep.py#L171

* UniPC: use same default for solver_type

Fixes a bug when switching to UniPC from another scheduler (e.g.,
DEIS) that uses a different solver type. The solver is now the same as
if we had instantiated the scheduler directly.
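
A sketch of the scenario these commits cover (per the description above, the assertion holds after the fix):

```python
from diffusers import DEISMultistepScheduler, UniPCMultistepScheduler

# Switching schedulers reuses the old config, including DEIS's solver_type
# ("logrho"), which UniPC must translate to one of its own solver types.
unipc_from_deis = UniPCMultistepScheduler.from_config(DEISMultistepScheduler().config)

# After the fix, the translated value matches direct instantiation.
assert unipc_from_deis.config.solver_type == UniPCMultistepScheduler().config.solver_type
```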

* do not save default values

* fix more

* fix all

* fix schedulers

* fix more

* finish for real

* finish for real

* flaky tests

* Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py

* Default steps_offset to 0.

* Add missing docstrings

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <[email protected]>

* Add Consistency Models Pipeline (huggingface#3492)

* initial commit

* Improve consistency models sampling implementation.

* Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.
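
For reference, the multi-step sampler being implemented is the paper's stochastic iterative sampling loop; a minimal sketch (function and argument names are mine, not the scheduler's API):

```python
import torch

def stochastic_iterative_sample(consistency_model, x, sigmas, sigma_min):
    """Multistep consistency sampling: alternate denoising and partial re-noising."""
    x = consistency_model(x, sigmas[0])  # one-step estimate from the largest sigma
    for sigma in sigmas[1:]:
        z = torch.randn_like(x)
        # re-noise the current estimate up to level sigma, then map back to data
        x_hat = x + (sigma**2 - sigma_min**2) ** 0.5 * z
        x = consistency_model(x_hat, sigma)
    return x
```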

* Add Unet blocks for consistency models

* Add conversion script for Unet

* Fix bug in new unet blocks

* Fix attention weight loading

* Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests.

* make style

* Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints.

* Add support to the consistency models conversion script for converting a test UNet and non-class-conditional UNets.

* make style

* Change num_class_embeds to 1000 to better match the original consistency models implementation.

* Add support for distillation in pipeline_consistency_models.py.

* Improve consistency model tests:
	- Get small testing checkpoints from hub
	- Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline
	- Add onestep, multistep tests for distillation and distillation + class conditional
	- Add expected image slices for onestep tests

* make style

* Improve ConsistencyModelPipeline:
	- Add initial support for class-conditional generation
	- Fix initial sigma for onestep generation
	- Fix some sigma shape issues

* make style

* Improve ConsistencyModelPipeline:
	- add latents __call__ argument and prepare_latents method
	- add check_inputs method
	- add initial docstrings for ConsistencyModelPipeline.__call__

* make style

* Fix bug when randomly generating class labels for class-conditional generation.

* Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests.
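
The sigma schedule in question is the Karras-style rho schedule the paper uses; a hedged sketch with the paper's typical values (parameter names are assumptions):

```python
import numpy as np

def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Interpolate in sigma**(1/rho) space, then raise back to the rho-th power.
    ramp = np.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
```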

* Remove some unused code and make style.

* Fix small bug in CMStochasticIterativeScheduler.

* Add expected slices for multistep sampling tests and make them pass.

* Work on consistency model fast tests:
	- in pipeline, call self.scheduler.scale_model_input before denoising
	- get expected slices for Euler and Heun scheduler tests
	- make Euler test pass
	- mark Heun test as expected fail because it doesn't support prediction_type "sample" yet
	- remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas

* make style

* Refactor conversion script to make it easier to add more model architectures to convert in the future.

* Work on ConsistencyModelPipeline tests:
	- Fix device bug when handling class labels in ConsistencyModelPipeline.__call__
	- Add slow tests for onestep and multistep sampling and make them pass
	- Refactor fast tests
	- Refactor ConsistencyModelPipeline.__init__

* make style

* Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now.

* Run `python utils/check_copies.py --fix_and_overwrite` and `python utils/check_dummies.py --fix_and_overwrite` to make dummy objects for the new pipeline and scheduler.

* Make fast tests from PipelineTesterMixin pass.

* make style

* Refactor consistency models pipeline and scheduler:
	- Remove support for Karras schedulers (only support CMStochasticIterativeScheduler)
	- Move sigma manipulation, input scaling, denoising from pipeline to scheduler
	- Make corresponding changes to tests and ensure they pass

* make style

* Add docstrings and further refactor pipeline and scheduler.

* make style

* Add initial version of the consistency models documentation.

* Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging.
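
The torch 2.0 SDPA kernel selection referred to here looks roughly like the context manager below (a debugging sketch; tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 64, 64, device="cuda", dtype=torch.float16)

# Force the flash-attention backend, disabling the math and memory-efficient
# fallbacks, to pin down which kernel is actually being exercised.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
```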

* make style

* Convert current slow tests to use fp16 and flash attention.

* make style

* Add slow tests for normal attention on cuda device.

* make style

* Fix attention weights loading

* Update consistency model fast tests for new test checkpoints with attention fix.

* make style

* apply suggestions

* Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler).
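
For intuition, EulerDiscreteScheduler-style noising (which this method copies) is additive in sigma space; a simplified sketch, not the scheduler's exact code:

```python
import torch

def add_noise(original_samples, noise, sigma):
    # Karras/EDM-style schedulers perturb samples as x_t = x_0 + sigma_t * noise,
    # with sigma looked up from the scheduler's schedule for the given timestep.
    return original_samples + sigma * noise

sample = torch.randn(1, 3, 64, 64)
noisy = add_noise(sample, torch.randn_like(sample), sigma=80.0)
```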

* Conversion script now outputs a pipeline instead of a UNet; add support for LSUN-256 models and different schedulers.

* When both timesteps and num_inference_steps are supplied, raise a warning instead of an error (timesteps take precedence).
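
A sketch of the resulting input handling (the helper and its names are illustrative, not the pipeline's exact code):

```python
import warnings

def resolve_timesteps(timesteps=None, num_inference_steps=None):
    # Custom timesteps win; passing both is tolerated with a warning.
    if timesteps is not None:
        if num_inference_steps is not None:
            warnings.warn(
                "Both `timesteps` and `num_inference_steps` were supplied; "
                "`timesteps` takes precedence."
            )
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
```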

* make style

* Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example.

* apply suggestions from review

* make style

* fix attention naming

* Add tests for CMStochasticIterativeScheduler.

* make style

* Make CMStochasticIterativeScheduler tests pass.

* make style

* Override test_step_shape in CMStochasticIterativeSchedulerTest instead of modifying it in SchedulerCommonTest.

* make style

* rename some models

* Improve API

* rename some models

* Remove duplicated block

* Add docstring and make torch compile work
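
Making torch.compile work means the one-liner below runs without graph breaks; a hedged usage sketch (requires torch 2.0+):

```python
import torch
from diffusers import ConsistencyModelPipeline

pipe = ConsistencyModelPipeline.from_pretrained(
    "dg845/consistency-model-pipelines", torch_dtype=torch.float16
).to("cuda")
pipe.unet = torch.compile(pipe.unet)  # compile the UNet's forward pass
image = pipe(num_inference_steps=1).images[0]  # first call triggers compilation
```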

* More fixes

* Fixes

* Apply suggestions from code review

* Apply suggestions from code review

* add more docstring

* update consistency conversion script

---------

Co-authored-by: ayushmangal <[email protected]>
Co-authored-by: Ayush Mangal <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* add test case for StableDiffusionKDiffusionPipeline noise_sampler

---------

Signed-off-by: GitHub <[email protected]>
Co-authored-by: Aisuko <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Andrés Mauricio Repetto Ferrero <[email protected]>
Co-authored-by: estelleafl <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Prathik Rao <[email protected]>
Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: ayushmangal <[email protected]>
Co-authored-by: Ayush Mangal <[email protected]>
orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023
…__` (huggingface#3911)

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…__` (huggingface#3911)

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…__` (huggingface#3911)
