|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import importlib
|
| 16 | +import inspect |
16 | 17 | import warnings
|
17 | 18 | from typing import Callable, List, Optional, Union
|
18 | 19 |
|
19 | 20 | import torch
|
20 | 21 | from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
|
21 |
| -from k_diffusion.sampling import get_sigmas_karras |
| 22 | +from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras |
22 | 23 |
|
23 | 24 | from ...image_processor import VaeImageProcessor
|
24 | 25 | from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
@@ -464,6 +465,7 @@ def __call__(
|
464 | 465 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
465 | 466 | callback_steps: int = 1,
|
466 | 467 | use_karras_sigmas: Optional[bool] = False,
|
| 468 | + noise_sampler_seed: Optional[int] = None, |
467 | 469 | ):
|
468 | 470 | r"""
|
469 | 471 | Function invoked when calling the pipeline for generation.
|
@@ -524,6 +526,8 @@ def __call__(
|
524 | 526 | Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to
|
525 | 527 | `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M
|
526 | 528 | Karras`.
|
| 529 | + noise_sampler_seed (`int`, *optional*, defaults to `None`): |
| 530 | + The random seed to use for the noise sampler. If `None`, a random seed will be generated. |
527 | 531 | Returns:
|
528 | 532 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
529 | 533 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
@@ -608,7 +612,14 @@ def model_fn(x, t):
|
608 | 612 | return noise_pred
|
609 | 613 |
|
610 | 614 | # 8. Run k-diffusion solver
|
611 |
| - latents = self.sampler(model_fn, latents, sigmas) |
| 615 | + sampler_kwargs = {} |
| 616 | + |
| 617 | + if "noise_sampler" in inspect.signature(self.sampler).parameters: |
| 618 | + min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max() |
| 619 | + noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) |
| 620 | + sampler_kwargs["noise_sampler"] = noise_sampler |
| 621 | + |
| 622 | + latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) |
612 | 623 |
|
613 | 624 | if not output_type == "latent":
|
614 | 625 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
|
0 commit comments