Skip to content

Commit ae6444d

Browse files
LuChengTHU authored and hari10599 committed
Add the SDE variant of DPM-Solver and DPM-Solver++ (huggingface#3344)
* add SDE variant of DPM-Solver and DPM-Solver++ * add test * fix typo * fix typo
1 parent 880a83b commit ae6444d

File tree

2 files changed

+98
-23
lines changed

2 files changed

+98
-23
lines changed

Diff for: src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

+87-16
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222

2323
from ..configuration_utils import ConfigMixin, register_to_config
24+
from ..utils import randn_tensor
2425
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
2526

2627

@@ -70,6 +71,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
7071
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
7172
stable-diffusion).
7273
74+
We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse
75+
diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the
76+
second-order `sde-dpmsolver++`.
77+
7378
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
7479
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
7580
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
@@ -103,10 +108,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
103108
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
104109
`algorithm_type="dpmsolver++"`.
105110
algorithm_type (`str`, default `dpmsolver++`):
106-
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
107-
algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
108-
https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
109-
sampling (e.g. stable-diffusion).
111+
the algorithm type for the solver. One of `dpmsolver`, `dpmsolver++`, `sde-dpmsolver`, or
112+
`sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
113+
the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend using
114+
`dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
110115
solver_type (`str`, default `midpoint`):
111116
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
112117
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
@@ -180,7 +185,7 @@ def __init__(
180185
self.init_noise_sigma = 1.0
181186

182187
# settings for DPM-Solver
183-
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
188+
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
184189
if algorithm_type == "deis":
185190
self.register_to_config(algorithm_type="dpmsolver++")
186191
else:
@@ -212,7 +217,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
212217
"""
213218
# Clipping the minimum of all lambda(t) for numerical stability.
214219
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
215-
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
220+
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
216221
timesteps = (
217222
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
218223
.round()[::-1][:-1]
@@ -338,10 +343,10 @@ def convert_model_output(
338343
"""
339344

340345
# DPM-Solver++ needs to solve an integral of the data prediction model.
341-
if self.config.algorithm_type == "dpmsolver++":
346+
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
342347
if self.config.prediction_type == "epsilon":
343348
# DPM-Solver and DPM-Solver++ only need the "mean" output.
344-
if self.config.variance_type in ["learned_range"]:
349+
if self.config.variance_type in ["learned", "learned_range"]:
345350
model_output = model_output[:, :3]
346351
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
347352
x0_pred = (sample - sigma_t * model_output) / alpha_t
@@ -360,33 +365,42 @@ def convert_model_output(
360365
x0_pred = self._threshold_sample(x0_pred)
361366

362367
return x0_pred
368+
363369
# DPM-Solver needs to solve an integral of the noise prediction model.
364-
elif self.config.algorithm_type == "dpmsolver":
370+
elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
365371
if self.config.prediction_type == "epsilon":
366372
# DPM-Solver and DPM-Solver++ only need the "mean" output.
367-
if self.config.variance_type in ["learned_range"]:
368-
model_output = model_output[:, :3]
369-
return model_output
373+
if self.config.variance_type in ["learned", "learned_range"]:
374+
epsilon = model_output[:, :3]
375+
else:
376+
epsilon = model_output
370377
elif self.config.prediction_type == "sample":
371378
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
372379
epsilon = (sample - alpha_t * model_output) / sigma_t
373-
return epsilon
374380
elif self.config.prediction_type == "v_prediction":
375381
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
376382
epsilon = alpha_t * model_output + sigma_t * sample
377-
return epsilon
378383
else:
379384
raise ValueError(
380385
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
381386
" `v_prediction` for the DPMSolverMultistepScheduler."
382387
)
383388

389+
if self.config.thresholding:
390+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
391+
x0_pred = (sample - sigma_t * epsilon) / alpha_t
392+
x0_pred = self._threshold_sample(x0_pred)
393+
epsilon = (sample - alpha_t * x0_pred) / sigma_t
394+
395+
return epsilon
396+
384397
def dpm_solver_first_order_update(
385398
self,
386399
model_output: torch.FloatTensor,
387400
timestep: int,
388401
prev_timestep: int,
389402
sample: torch.FloatTensor,
403+
noise: Optional[torch.FloatTensor] = None,
390404
) -> torch.FloatTensor:
391405
"""
392406
One step for the first-order DPM-Solver (equivalent to DDIM).
@@ -411,6 +425,20 @@ def dpm_solver_first_order_update(
411425
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
412426
elif self.config.algorithm_type == "dpmsolver":
413427
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
428+
elif self.config.algorithm_type == "sde-dpmsolver++":
429+
assert noise is not None
430+
x_t = (
431+
(sigma_t / sigma_s * torch.exp(-h)) * sample
432+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
433+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
434+
)
435+
elif self.config.algorithm_type == "sde-dpmsolver":
436+
assert noise is not None
437+
x_t = (
438+
(alpha_t / alpha_s) * sample
439+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
440+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
441+
)
414442
return x_t
415443

416444
def multistep_dpm_solver_second_order_update(
@@ -419,6 +447,7 @@ def multistep_dpm_solver_second_order_update(
419447
timestep_list: List[int],
420448
prev_timestep: int,
421449
sample: torch.FloatTensor,
450+
noise: Optional[torch.FloatTensor] = None,
422451
) -> torch.FloatTensor:
423452
"""
424453
One step for the second-order multistep DPM-Solver.
@@ -470,6 +499,38 @@ def multistep_dpm_solver_second_order_update(
470499
- (sigma_t * (torch.exp(h) - 1.0)) * D0
471500
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
472501
)
502+
elif self.config.algorithm_type == "sde-dpmsolver++":
503+
assert noise is not None
504+
if self.config.solver_type == "midpoint":
505+
x_t = (
506+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
507+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
508+
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
509+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
510+
)
511+
elif self.config.solver_type == "heun":
512+
x_t = (
513+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
514+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
515+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
516+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
517+
)
518+
elif self.config.algorithm_type == "sde-dpmsolver":
519+
assert noise is not None
520+
if self.config.solver_type == "midpoint":
521+
x_t = (
522+
(alpha_t / alpha_s0) * sample
523+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
524+
- (sigma_t * (torch.exp(h) - 1.0)) * D1
525+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
526+
)
527+
elif self.config.solver_type == "heun":
528+
x_t = (
529+
(alpha_t / alpha_s0) * sample
530+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
531+
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
532+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
533+
)
473534
return x_t
474535

475536
def multistep_dpm_solver_third_order_update(
@@ -532,6 +593,7 @@ def step(
532593
model_output: torch.FloatTensor,
533594
timestep: int,
534595
sample: torch.FloatTensor,
596+
generator=None,
535597
return_dict: bool = True,
536598
) -> Union[SchedulerOutput, Tuple]:
537599
"""
@@ -574,12 +636,21 @@ def step(
574636
self.model_outputs[i] = self.model_outputs[i + 1]
575637
self.model_outputs[-1] = model_output
576638

639+
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
640+
noise = randn_tensor(
641+
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
642+
)
643+
else:
644+
noise = None
645+
577646
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
578-
prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
647+
prev_sample = self.dpm_solver_first_order_update(
648+
model_output, timestep, prev_timestep, sample, noise=noise
649+
)
579650
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
580651
timestep_list = [self.timesteps[step_index - 1], timestep]
581652
prev_sample = self.multistep_dpm_solver_second_order_update(
582-
self.model_outputs, timestep_list, prev_timestep, sample
653+
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
583654
)
584655
else:
585656
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]

Diff for: tests/schedulers/test_scheduler_dpm_multi.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,20 @@ def test_prediction_type(self):
167167
self.check_over_configs(prediction_type=prediction_type)
168168

169169
def test_solver_order_and_type(self):
170-
for algorithm_type in ["dpmsolver", "dpmsolver++"]:
170+
for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
171171
for solver_type in ["midpoint", "heun"]:
172172
for order in [1, 2, 3]:
173173
for prediction_type in ["epsilon", "sample"]:
174-
self.check_over_configs(
175-
solver_order=order,
176-
solver_type=solver_type,
177-
prediction_type=prediction_type,
178-
algorithm_type=algorithm_type,
179-
)
174+
if algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
175+
if order == 3:
176+
continue
177+
else:
178+
self.check_over_configs(
179+
solver_order=order,
180+
solver_type=solver_type,
181+
prediction_type=prediction_type,
182+
algorithm_type=algorithm_type,
183+
)
180184
sample = self.full_loop(
181185
solver_order=order,
182186
solver_type=solver_type,

0 commit comments

Comments
 (0)