diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index dddfc3591b66..c266e8b20e74 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -947,9 +947,9 @@ def __call__( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] - for i in range(num_inference_steps): + for i in range(len(timesteps)): keeps = [ - 1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e) + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index c7a0db96e8c0..fd013c4974f1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -1040,9 +1040,9 @@ def __call__( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] - for i in range(num_inference_steps): + for i in range(len(timesteps)): keeps = [ - 1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e) + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index bfaaaae49401..7de3f1dd9d88 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -1275,9 +1275,9 @@ def __call__( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] - for i in range(num_inference_steps): + for i in range(len(timesteps)): keeps = [ - 1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e) + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 25102ae7cf4a..367e401d57f8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -374,7 +374,7 @@ def __call__( # predicted_original_sample instead of the noise_pred. So we need to compute the # predicted_original_sample here if we are using a karras style scheduler. 
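# [Editor's note: illustrative sketch, not part of the patch.] The three ControlNet
# pipelines above now index `controlnet_keep` by `len(timesteps)` instead of
# `num_inference_steps`, because the scheduler can return a different number of
# timesteps than requested (for example after deduplicating rounded Karras timesteps,
# or when img2img `strength` shortens the schedule). A minimal standalone version of
# that bookkeeping, with hypothetical guidance windows, could look like this:
def build_controlnet_keep(timesteps, control_guidance_start, control_guidance_end):
    controlnet_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
    return controlnet_keep

# With 8 actual timesteps and guidance applied only to the first half of sampling:
# build_controlnet_keep(list(range(8)), [0.0], [0.5]) == [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]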
if scheduler_is_in_sigma_space: - step_index = (self.scheduler.timesteps == t).nonzero().item() + step_index = (self.scheduler.timesteps == t).nonzero()[0].item() sigma = self.scheduler.sigmas[step_index] noise_pred = latent_model_input - sigma * noise_pred diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 8ea001a882d0..56c362018c18 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -103,7 +103,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10. - + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -125,6 +128,7 @@ def __init__( algorithm_type: str = "deis", solver_type: str = "logrho", lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -188,6 +192,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic .astype(np.int64) ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. 
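# [Editor's note: illustrative sketch, not part of the patch.] `_convert_to_karras`,
# called in the new `set_timesteps` branch above, is expected to implement the noise
# schedule of Karras et al. (2022), Eq. (5):
# sigma_i = (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho,
# typically with rho = 7. A standalone NumPy sketch of that schedule:
import numpy as np

def karras_sigmas(sigma_min, sigma_max, num_inference_steps, rho=7.0):
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# karras_sigmas(0.03, 14.6, 5) falls from ~14.6 down to 0.03, concentrating steps at low
# noise levels; `_sigma_to_t` then maps each chosen sigma back to a training timestep.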
_, unique_indices = np.unique(timesteps, return_index=True) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index e72b1bdc23b5..d7c29d5488a5 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -203,7 +203,6 @@ def __init__( self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 - self.use_karras_sigmas = use_karras_sigmas def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ @@ -225,13 +224,15 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc .astype(np.int64) ) - if self.use_karras_sigmas: - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) + self.sigmas = torch.from_numpy(sigmas) + # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. _, unique_indices = np.unique(timesteps, return_index=True) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 7fa8eabb5a15..721dd5e5bb85 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -202,7 +202,6 @@ def __init__( self.model_outputs = [None] * solver_order self.sample = None self.order_list = self.get_order_list(num_train_timesteps) - self.use_karras_sigmas = use_karras_sigmas def get_order_list(self, num_inference_steps: int) -> List[int]: """ @@ -259,13 +258,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic .astype(np.int64) ) - if self.use_karras_sigmas: - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device) self.model_outputs = [None] * self.config.solver_order self.sample = None diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 2cce68f7d962..7233258a4766 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -117,6 +117,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): by disable the corrector at the first few steps (e.g., disable_corrector=[0]) solver_p (`SchedulerMixin`, default `None`): can be any other scheduler. If specified, the algorithm will become solver_p + UniC. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. 
(2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -140,6 +144,7 @@ def __init__( lower_order_final: bool = True, disable_corrector: List[int] = [], solver_p: SchedulerMixin = None, + use_karras_sigmas: Optional[bool] = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -201,6 +206,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic .astype(np.int64) ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. _, unique_indices = np.unique(timesteps, return_index=True) diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py index 1344d33a2552..3f16964fd567 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -29,13 +29,15 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() -class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class AltDiffusionPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = AltDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 906e1e7ee66f..a548983c3841 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -46,7 +46,11 @@ TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, +) enable_full_determinism() @@ -97,7 +101,9 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): out_queue.join() -class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class ControlNetPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -207,7 +213,9 @@ def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) -class 
StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionMultiControlNetPipelineFastTests( + PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index bc0e96b2f92b..c46593f03e5e 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -42,13 +42,19 @@ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, +) enable_full_determinism() -class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class ControlNetImg2ImgPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS @@ -161,7 +167,9 @@ def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) -class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionMultiControlNetPipelineFastTests( + PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index 81647d968b6b..cf423f4c49d0 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -42,13 +42,19 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, +) enable_full_determinism() -class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class ControlNetInpaintPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS @@ -237,7 +243,9 @@ def get_dummy_components(self): return components -class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class MultiControlNetInpaintPipelineFastTests( + PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionControlNetInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py 
b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 93abe7ae58bc..7daf3fcda4a2 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -50,7 +50,7 @@ from ...models.test_lora_layers import create_unet_lora_layers from ...models.test_models_unet_2d_condition import create_lora_layers from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @@ -88,7 +88,9 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): out_queue.join() -class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index e16478f06112..580c78675a92 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -33,14 +33,14 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() class StableDiffusionImageVariationPipelineFastTests( - PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionImageVariationPipeline params = IMAGE_VARIATION_PARAMS diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index eefbc83ce9d7..d1f7a49467e6 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -46,7 +46,7 @@ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @@ -84,7 +84,9 @@ def _test_img2img_compile(in_queue, out_queue, timeout): out_queue.join() -class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionImg2ImgPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} diff --git 
a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index f761f245883f..e7b084acb280 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -42,7 +42,7 @@ from ...models.test_models_unet_2d_condition import create_lora_layers from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @@ -82,7 +82,9 @@ def _test_inpaint_compile(in_queue, out_queue, timeout): out_queue.join() -class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionInpaintPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py index 691427b1c6eb..513e11c105d5 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py @@ -40,14 +40,14 @@ TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() class StableDiffusionInstructPix2PixPipelineFastTests( - PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionInstructPix2PixPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"} diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py index f47a70c4ece8..81d1baed5df6 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py @@ -32,14 +32,16 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @skip_mps -class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionModelEditingPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionModelEditingPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = 
TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 6f41d2c43c8e..1b17f8b31be9 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -41,7 +41,11 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference +from ..test_pipelines_common import ( + PipelineLatentTesterMixin, + PipelineTesterMixin, + assert_mean_pixel_difference, +) enable_full_determinism() diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 33cc7f638ec2..a26abfa50096 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -36,13 +36,15 @@ from diffusers.utils.testing_utils import CaptureLogger, enable_full_determinism, require_torch_gpu from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() -class StableDiffusion2PipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusion2PipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py index 304ddacd2c36..b4d49c92425c 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py @@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import require_torch_gpu from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False @@ -38,7 +38,7 @@ @skip_mps class StableDiffusionAttendAndExcitePipelineFastTests( - PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionAttendAndExcitePipeline test_attention_slicing = False diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index f393967c7de4..fe2cf73da096 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -57,14 +57,16 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common 
import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @skip_mps -class StableDiffusionDepth2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionDepth2ImgPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionDepth2ImgPipeline test_save_load_optional_components = False params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 37c254f367f3..68a4b5132375 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -27,13 +27,15 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() -class StableDiffusion2InpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusion2InpaintPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py index b94aaca4258a..5e1d610efcaf 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py @@ -21,6 +21,7 @@ import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer +import diffusers from diffusers import ( AutoencoderKL, EulerDiscreteScheduler, @@ -28,17 +29,25 @@ StableDiffusionPipeline, UNet2DConditionModel, ) +from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() -class StableDiffusionLatentUpscalePipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +def check_same_shape(tensor_list): + shapes = [tensor.shape for tensor in tensor_list] + return all(shape == shapes[0] for shape in shapes[1:]) + + +class StableDiffusionLatentUpscalePipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionLatentUpscalePipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - { "height", @@ -185,6 +194,42 @@ def test_save_load_local(self): 
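# [Editor's note: illustrative usage sketch, not part of the patch; the model id and
# prompt below are placeholders.] The scheduler changes above expose `use_karras_sigmas`
# through the scheduler config, so it can be enabled when swapping schedulers, e.g.:
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
image = pipe("an astronaut riding a horse on mars", num_inference_steps=20).images[0]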
def test_save_load_optional_components(self): super().test_save_load_optional_components(expected_max_difference=3e-3) + def test_karras_schedulers_shape(self): + skip_schedulers = [ + "DDIMScheduler", + "DDPMScheduler", + "PNDMScheduler", + "HeunDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "KDPM2AncestralDiscreteScheduler", + "DPMSolverSDEScheduler", + ] + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + # make sure that PNDM does not need warm-up + pipe.scheduler.register_to_config(skip_prk_steps=True) + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + inputs = self.get_dummy_inputs(torch_device) + inputs["num_inference_steps"] = 2 + + outputs = [] + for scheduler_enum in KarrasDiffusionSchedulers: + if scheduler_enum.name in skip_schedulers: + # no sigma schedulers are not supported + # no schedulers + continue + + scheduler_cls = getattr(diffusers, scheduler_enum.name) + pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config) + output = pipe(**inputs)[0] + outputs.append(output) + + assert check_same_shape(outputs) + @require_torch_gpu @slow diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index 4bbbad757edf..8d5edda16904 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -16,13 +16,20 @@ from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, require_torch_gpu, slow, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + assert_mean_pixel_difference, +) enable_full_determinism() -class StableUnCLIPPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableUnCLIPPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableUnCLIPPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index 741343066133..52581eb574e0 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -30,6 +30,7 @@ from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference, @@ -39,7 +40,9 @@ enable_full_determinism() -class StableUnCLIPImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableUnCLIPImg2ImgPipelineFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableUnCLIPImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 9fb3e167facc..52dd4afd6b21 100644 --- 
a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -14,6 +14,7 @@
 import diffusers
 from diffusers import DiffusionPipeline
 from diffusers.image_processor import VaeImageProcessor
+from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
 from diffusers.utils.testing_utils import require_torch, torch_device
@@ -26,6 +27,11 @@ def to_np(tensor):
     return tensor
 
 
+def check_same_shape(tensor_list):
+    shapes = [tensor.shape for tensor in tensor_list]
+    return all(shape == shapes[0] for shape in shapes[1:])
+
+
 class PipelineLatentTesterMixin:
     """
     This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
@@ -155,6 +161,46 @@ def test_latents_input(self):
         self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
 
 
+@require_torch
+class PipelineKarrasSchedulerTesterMixin:
+    """
+    This mixin is designed to be used with unittest.TestCase classes.
+    It provides a set of common tests for PyTorch pipelines that can run with any of the
+    KarrasDiffusionSchedulers, e.g. checking that every compatible scheduler produces outputs of the same shape.
+    """
+
+    def test_karras_schedulers_shape(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+
+        # make sure that PNDM does not need warm-up
+        pipe.scheduler.register_to_config(skip_prk_steps=True)
+
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = 2
+
+        if "strength" in inputs:
+            inputs["num_inference_steps"] = 4
+            inputs["strength"] = 0.5
+
+        outputs = []
+        for scheduler_enum in KarrasDiffusionSchedulers:
+            if "KDPM2" in scheduler_enum.name:
+                inputs["num_inference_steps"] = 5
+
+            scheduler_cls = getattr(diffusers, scheduler_enum.name)
+            pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
+            output = pipe(**inputs)[0]
+            outputs.append(output)
+
+            if "KDPM2" in scheduler_enum.name:
+                inputs["num_inference_steps"] = 2
+
+        assert check_same_shape(outputs)
+
+
 @require_torch
 class PipelineTesterMixin:
     """
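# [Editor's note: illustrative note, not part of the patch.] `check_same_shape`, added
# above, treats empty and single-element lists as trivially consistent; a quick
# self-contained check of its behaviour:
import torch

def check_same_shape(tensor_list):  # same helper as above, repeated so this snippet runs alone
    shapes = [tensor.shape for tensor in tensor_list]
    return all(shape == shapes[0] for shape in shapes[1:])

assert check_same_shape([])
assert check_same_shape([torch.zeros(1, 8, 8, 3), torch.zeros(1, 8, 8, 3)])
assert not check_same_shape([torch.zeros(1, 8, 8, 3), torch.zeros(1, 4, 4, 3)])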