Skip to content

Commit e5ca2b2

Browse files
committed
add noise_sampler to StableDiffusionKDiffusionPipeline
1 parent 174dcd6 commit e5ca2b2

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414

1515
import importlib
16+
import inspect
1617
import warnings
1718
from typing import Callable, List, Optional, Union
1819

1920
import torch
2021
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
21-
from k_diffusion.sampling import get_sigmas_karras
22+
from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
2223

2324
from ...image_processor import VaeImageProcessor
2425
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
@@ -464,6 +465,7 @@ def __call__(
464465
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
465466
callback_steps: int = 1,
466467
use_karras_sigmas: Optional[bool] = False,
468+
noise_sampler_seed: Optional[int] = None,
467469
):
468470
r"""
469471
Function invoked when calling the pipeline for generation.
@@ -524,6 +526,8 @@ def __call__(
524526
Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to
525527
`DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M
526528
Karras`.
529+
noise_sampler_seed (`int`, *optional*, defaults to `None`):
530+
The random seed to use for the noise sampler. If `None`, a random seed will be generated.
527531
Returns:
528532
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
529533
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
@@ -608,7 +612,14 @@ def model_fn(x, t):
608612
return noise_pred
609613

610614
# 8. Run k-diffusion solver
611-
latents = self.sampler(model_fn, latents, sigmas)
615+
sampler_kwargs = {}
616+
617+
if "noise_sampler" in inspect.signature(self.sampler).parameters:
618+
min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
619+
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
620+
sampler_kwargs["noise_sampler"] = noise_sampler
621+
622+
latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
612623

613624
if not output_type == "latent":
614625
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

0 commit comments

Comments
 (0)