diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 6a895a6d0f29..3218eed31ef9 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -17,6 +17,7 @@ import torch from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser +from k_diffusion.sampling import get_sigmas_karras from ...pipelines import DiffusionPipeline from ...schedulers import LMSDiscreteScheduler @@ -400,6 +401,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + use_karras_sigmas: Optional[bool] = False, ): r""" Function invoked when calling the pipeline for generation. @@ -456,7 +458,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to + `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M + Karras`. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. @@ -494,10 +499,18 @@ def __call__( # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) - sigmas = self.scheduler.sigmas + + # 5. Prepare sigmas + if use_karras_sigmas: + sigma_min: float = self.k_diffusion_model.sigmas[0].item() + sigma_max: float = self.k_diffusion_model.sigmas[-1].item() + sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max) + sigmas = sigmas.to(device) + else: + sigmas = self.scheduler.sigmas sigmas = sigmas.to(prompt_embeds.dtype) - # 5. Prepare latent variables + # 6. Prepare latent variables num_channels_latents = self.unet.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, @@ -513,7 +526,7 @@ def __call__( self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) - # 6. Define model function + # 7. Define model function def model_fn(x, t): latent_model_input = torch.cat([x] * 2) t = torch.cat([t] * 2) @@ -524,16 +537,16 @@ def model_fn(x, t): noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) return noise_pred - # 7. Run k-diffusion solver + # 8. Run k-diffusion solver latents = self.sampler(model_fn, latents, sigmas) - # 8. Post-processing + # 9. Post-processing image = self.decode_latents(latents) - # 9. Run safety checker + # 10. Run safety checker image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - # 10. Convert to PIL + # 11. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py index 7869790c6218..546b1d21252c 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py @@ -75,3 +75,32 @@ def test_stable_diffusion_2(self): expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112]) assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1 + + def test_stable_diffusion_karras_sigmas(self): + sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + sd_pipe.set_scheduler("sample_dpmpp_2m") + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=7.5, + num_inference_steps=15, + output_type="np", + use_karras_sigmas=True, + ) + + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array( + [0.11381689, 0.12112921, 0.1389457, 0.12549606, 0.1244964, 0.10831517, 0.11562866, 0.10867816, 0.10499048] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2