Skip to content

Commit 1074a46

Browse files
tolgacangoz authored and cmdr2 committed
[Cont'd] Add the SDE variant of ~~DPM-Solver~~ and DPM-Solver++ to DPM Single Step (#8269)
* Add the SDE variant of DPM-Solver and DPM-Solver++ to DPM Single Step --------- Co-authored-by: cmdr2 <[email protected]>
1 parent ebcab06 commit 1074a46

File tree

2 files changed

+74
-25
lines changed

2 files changed

+74
-25
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

+63-18
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from ..configuration_utils import ConfigMixin, register_to_config
2424
from ..utils import deprecate, logging
25+
from ..utils.torch_utils import randn_tensor
2526
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
2627

2728

@@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
108109
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
109110
`algorithm_type="dpmsolver++"`.
110111
algorithm_type (`str`, defaults to `dpmsolver++`):
111-
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
112-
algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
113-
implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is
114-
recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in
115-
Stable Diffusion.
112+
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, or `sde-dpmsolver++`. The `dpmsolver`
113+
type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
114+
`dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
115+
paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
116+
sampling like in Stable Diffusion.
116117
solver_type (`str`, defaults to `midpoint`):
117118
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
118119
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
@@ -186,7 +187,7 @@ def __init__(
186187
self.init_noise_sigma = 1.0
187188

188189
# settings for DPM-Solver
189-
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
190+
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
190191
if algorithm_type == "deis":
191192
self.register_to_config(algorithm_type="dpmsolver++")
192193
else:
@@ -197,7 +198,7 @@ def __init__(
197198
else:
198199
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
199200

200-
if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
201+
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
201202
raise ValueError(
202203
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
203204
)
@@ -493,10 +494,10 @@ def convert_model_output(
493494
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
494495
)
495496
# DPM-Solver++ needs to solve an integral of the data prediction model.
496-
if self.config.algorithm_type == "dpmsolver++":
497+
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
497498
if self.config.prediction_type == "epsilon":
498499
# DPM-Solver and DPM-Solver++ only need the "mean" output.
499-
if self.config.variance_type in ["learned_range"]:
500+
if self.config.variance_type in ["learned", "learned_range"]:
500501
model_output = model_output[:, :3]
501502
sigma = self.sigmas[self.step_index]
502503
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
@@ -517,34 +518,43 @@ def convert_model_output(
517518
x0_pred = self._threshold_sample(x0_pred)
518519

519520
return x0_pred
521+
520522
# DPM-Solver needs to solve an integral of the noise prediction model.
521523
elif self.config.algorithm_type == "dpmsolver":
522524
if self.config.prediction_type == "epsilon":
523525
# DPM-Solver and DPM-Solver++ only need the "mean" output.
524-
if self.config.variance_type in ["learned_range"]:
525-
model_output = model_output[:, :3]
526-
return model_output
526+
if self.config.variance_type in ["learned", "learned_range"]:
527+
epsilon = model_output[:, :3]
528+
else:
529+
epsilon = model_output
527530
elif self.config.prediction_type == "sample":
528531
sigma = self.sigmas[self.step_index]
529532
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
530533
epsilon = (sample - alpha_t * model_output) / sigma_t
531-
return epsilon
532534
elif self.config.prediction_type == "v_prediction":
533535
sigma = self.sigmas[self.step_index]
534536
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
535537
epsilon = alpha_t * model_output + sigma_t * sample
536-
return epsilon
537538
else:
538539
raise ValueError(
539540
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
540541
" `v_prediction` for the DPMSolverSinglestepScheduler."
541542
)
542543

544+
if self.config.thresholding:
545+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
546+
x0_pred = (sample - sigma_t * epsilon) / alpha_t
547+
x0_pred = self._threshold_sample(x0_pred)
548+
epsilon = (sample - alpha_t * x0_pred) / sigma_t
549+
550+
return epsilon
551+
543552
def dpm_solver_first_order_update(
544553
self,
545554
model_output: torch.Tensor,
546555
*args,
547556
sample: torch.Tensor = None,
557+
noise: Optional[torch.Tensor] = None,
548558
**kwargs,
549559
) -> torch.Tensor:
550560
"""
@@ -594,13 +604,21 @@ def dpm_solver_first_order_update(
594604
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
595605
elif self.config.algorithm_type == "dpmsolver":
596606
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
607+
elif self.config.algorithm_type == "sde-dpmsolver++":
608+
assert noise is not None
609+
x_t = (
610+
(sigma_t / sigma_s * torch.exp(-h)) * sample
611+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
612+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
613+
)
597614
return x_t
598615

599616
def singlestep_dpm_solver_second_order_update(
600617
self,
601618
model_output_list: List[torch.Tensor],
602619
*args,
603620
sample: torch.Tensor = None,
621+
noise: Optional[torch.Tensor] = None,
604622
**kwargs,
605623
) -> torch.Tensor:
606624
"""
@@ -688,6 +706,22 @@ def singlestep_dpm_solver_second_order_update(
688706
- (sigma_t * (torch.exp(h) - 1.0)) * D0
689707
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
690708
)
709+
elif self.config.algorithm_type == "sde-dpmsolver++":
710+
assert noise is not None
711+
if self.config.solver_type == "midpoint":
712+
x_t = (
713+
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
714+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
715+
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
716+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
717+
)
718+
elif self.config.solver_type == "heun":
719+
x_t = (
720+
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
721+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
722+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
723+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
724+
)
691725
return x_t
692726

693727
def singlestep_dpm_solver_third_order_update(
@@ -800,6 +834,7 @@ def singlestep_dpm_solver_update(
800834
*args,
801835
sample: torch.Tensor = None,
802836
order: int = None,
837+
noise: Optional[torch.Tensor] = None,
803838
**kwargs,
804839
) -> torch.Tensor:
805840
"""
@@ -848,9 +883,9 @@ def singlestep_dpm_solver_update(
848883
)
849884

850885
if order == 1:
851-
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
886+
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise)
852887
elif order == 2:
853-
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
888+
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
854889
elif order == 3:
855890
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
856891
else:
@@ -894,6 +929,7 @@ def step(
894929
model_output: torch.Tensor,
895930
timestep: int,
896931
sample: torch.Tensor,
932+
generator=None,
897933
return_dict: bool = True,
898934
) -> Union[SchedulerOutput, Tuple]:
899935
"""
@@ -929,6 +965,13 @@ def step(
929965
self.model_outputs[i] = self.model_outputs[i + 1]
930966
self.model_outputs[-1] = model_output
931967

968+
if self.config.algorithm_type == "sde-dpmsolver++":
969+
noise = randn_tensor(
970+
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
971+
)
972+
else:
973+
noise = None
974+
932975
order = self.order_list[self.step_index]
933976

934977
# For img2img denoising might start with order>1 which is not possible
@@ -940,9 +983,11 @@ def step(
940983
if order == 1:
941984
self.sample = sample
942985

943-
prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
986+
prev_sample = self.singlestep_dpm_solver_update(
987+
self.model_outputs, sample=self.sample, order=order, noise=noise
988+
)
944989

945-
# upon completion increase step index by one
990+
# upon completion increase step index by one
946991
self._step_index += 1
947992

948993
if not return_dict:

tests/schedulers/test_scheduler_dpm_single.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -194,16 +194,20 @@ def test_prediction_type(self):
194194
self.check_over_configs(prediction_type=prediction_type)
195195

196196
def test_solver_order_and_type(self):
197-
for algorithm_type in ["dpmsolver", "dpmsolver++"]:
197+
for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
198198
for solver_type in ["midpoint", "heun"]:
199199
for order in [1, 2, 3]:
200200
for prediction_type in ["epsilon", "sample"]:
201-
self.check_over_configs(
202-
solver_order=order,
203-
solver_type=solver_type,
204-
prediction_type=prediction_type,
205-
algorithm_type=algorithm_type,
206-
)
201+
if algorithm_type == "sde-dpmsolver++":
202+
if order == 3:
203+
continue
204+
else:
205+
self.check_over_configs(
206+
solver_order=order,
207+
solver_type=solver_type,
208+
prediction_type=prediction_type,
209+
algorithm_type=algorithm_type,
210+
)
207211
sample = self.full_loop(
208212
solver_order=order,
209213
solver_type=solver_type,

0 commit comments

Comments
 (0)