21
21
import torch
22
22
23
23
from ..configuration_utils import ConfigMixin , register_to_config
24
+ from ..utils import randn_tensor
24
25
from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin , SchedulerOutput
25
26
26
27
@@ -70,6 +71,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
70
71
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
71
72
stable-diffusion).
72
73
74
+ We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse
75
+ diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the
76
+ second-order `sde-dpmsolver++`.
77
+
73
78
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
74
79
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
75
80
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
@@ -103,10 +108,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
103
108
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
104
109
`algorithm_type="dpmsolver++`.
105
110
algorithm_type (`str`, default `dpmsolver++`):
106
- the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The ` dpmsolver` type implements the
107
- algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
108
- https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
109
- sampling (e.g. stable-diffusion).
111
+ the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde- dpmsolver` or
112
+ `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
113
+ the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use
114
+ `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
110
115
solver_type (`str`, default `midpoint`):
111
116
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
112
117
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
@@ -180,7 +185,7 @@ def __init__(
180
185
self .init_noise_sigma = 1.0
181
186
182
187
# settings for DPM-Solver
183
- if algorithm_type not in ["dpmsolver" , "dpmsolver++" ]:
188
+ if algorithm_type not in ["dpmsolver" , "dpmsolver++" , "sde-dpmsolver" , "sde-dpmsolver++" ]:
184
189
if algorithm_type == "deis" :
185
190
self .register_to_config (algorithm_type = "dpmsolver++" )
186
191
else :
@@ -212,7 +217,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
212
217
"""
213
218
# Clipping the minimum of all lambda(t) for numerical stability.
214
219
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
215
- clipped_idx = torch .searchsorted (torch .flip (self .lambda_t , [0 ]), self .lambda_min_clipped )
220
+ clipped_idx = torch .searchsorted (torch .flip (self .lambda_t , [0 ]), self .config . lambda_min_clipped )
216
221
timesteps = (
217
222
np .linspace (0 , self .config .num_train_timesteps - 1 - clipped_idx , num_inference_steps + 1 )
218
223
.round ()[::- 1 ][:- 1 ]
@@ -338,10 +343,10 @@ def convert_model_output(
338
343
"""
339
344
340
345
# DPM-Solver++ needs to solve an integral of the data prediction model.
341
- if self .config .algorithm_type == "dpmsolver++" :
346
+ if self .config .algorithm_type in [ "dpmsolver++" , "sde-dpmsolver++" ] :
342
347
if self .config .prediction_type == "epsilon" :
343
348
# DPM-Solver and DPM-Solver++ only need the "mean" output.
344
- if self .config .variance_type in ["learned_range" ]:
349
+ if self .config .variance_type in ["learned" , " learned_range" ]:
345
350
model_output = model_output [:, :3 ]
346
351
alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
347
352
x0_pred = (sample - sigma_t * model_output ) / alpha_t
@@ -360,33 +365,42 @@ def convert_model_output(
360
365
x0_pred = self ._threshold_sample (x0_pred )
361
366
362
367
return x0_pred
368
+
363
369
# DPM-Solver needs to solve an integral of the noise prediction model.
364
- elif self .config .algorithm_type == "dpmsolver" :
370
+ elif self .config .algorithm_type in [ "dpmsolver" , "sde-dpmsolver" ] :
365
371
if self .config .prediction_type == "epsilon" :
366
372
# DPM-Solver and DPM-Solver++ only need the "mean" output.
367
- if self .config .variance_type in ["learned_range" ]:
368
- model_output = model_output [:, :3 ]
369
- return model_output
373
+ if self .config .variance_type in ["learned" , "learned_range" ]:
374
+ epsilon = model_output [:, :3 ]
375
+ else :
376
+ epsilon = model_output
370
377
elif self .config .prediction_type == "sample" :
371
378
alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
372
379
epsilon = (sample - alpha_t * model_output ) / sigma_t
373
- return epsilon
374
380
elif self .config .prediction_type == "v_prediction" :
375
381
alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
376
382
epsilon = alpha_t * model_output + sigma_t * sample
377
- return epsilon
378
383
else :
379
384
raise ValueError (
380
385
f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample`, or"
381
386
" `v_prediction` for the DPMSolverMultistepScheduler."
382
387
)
383
388
389
+ if self .config .thresholding :
390
+ alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
391
+ x0_pred = (sample - sigma_t * epsilon ) / alpha_t
392
+ x0_pred = self ._threshold_sample (x0_pred )
393
+ epsilon = (sample - alpha_t * x0_pred ) / sigma_t
394
+
395
+ return epsilon
396
+
384
397
def dpm_solver_first_order_update (
385
398
self ,
386
399
model_output : torch .FloatTensor ,
387
400
timestep : int ,
388
401
prev_timestep : int ,
389
402
sample : torch .FloatTensor ,
403
+ noise : Optional [torch .FloatTensor ] = None ,
390
404
) -> torch .FloatTensor :
391
405
"""
392
406
One step for the first-order DPM-Solver (equivalent to DDIM).
@@ -411,6 +425,20 @@ def dpm_solver_first_order_update(
411
425
x_t = (sigma_t / sigma_s ) * sample - (alpha_t * (torch .exp (- h ) - 1.0 )) * model_output
412
426
elif self .config .algorithm_type == "dpmsolver" :
413
427
x_t = (alpha_t / alpha_s ) * sample - (sigma_t * (torch .exp (h ) - 1.0 )) * model_output
428
+ elif self .config .algorithm_type == "sde-dpmsolver++" :
429
+ assert noise is not None
430
+ x_t = (
431
+ (sigma_t / sigma_s * torch .exp (- h )) * sample
432
+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * model_output
433
+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
434
+ )
435
+ elif self .config .algorithm_type == "sde-dpmsolver" :
436
+ assert noise is not None
437
+ x_t = (
438
+ (alpha_t / alpha_s ) * sample
439
+ - 2.0 * (sigma_t * (torch .exp (h ) - 1.0 )) * model_output
440
+ + sigma_t * torch .sqrt (torch .exp (2 * h ) - 1.0 ) * noise
441
+ )
414
442
return x_t
415
443
416
444
def multistep_dpm_solver_second_order_update (
@@ -419,6 +447,7 @@ def multistep_dpm_solver_second_order_update(
419
447
timestep_list : List [int ],
420
448
prev_timestep : int ,
421
449
sample : torch .FloatTensor ,
450
+ noise : Optional [torch .FloatTensor ] = None ,
422
451
) -> torch .FloatTensor :
423
452
"""
424
453
One step for the second-order multistep DPM-Solver.
@@ -470,6 +499,38 @@ def multistep_dpm_solver_second_order_update(
470
499
- (sigma_t * (torch .exp (h ) - 1.0 )) * D0
471
500
- (sigma_t * ((torch .exp (h ) - 1.0 ) / h - 1.0 )) * D1
472
501
)
502
+ elif self .config .algorithm_type == "sde-dpmsolver++" :
503
+ assert noise is not None
504
+ if self .config .solver_type == "midpoint" :
505
+ x_t = (
506
+ (sigma_t / sigma_s0 * torch .exp (- h )) * sample
507
+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D0
508
+ + 0.5 * (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D1
509
+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
510
+ )
511
+ elif self .config .solver_type == "heun" :
512
+ x_t = (
513
+ (sigma_t / sigma_s0 * torch .exp (- h )) * sample
514
+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D0
515
+ + (alpha_t * ((1.0 - torch .exp (- 2.0 * h )) / (- 2.0 * h ) + 1.0 )) * D1
516
+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
517
+ )
518
+ elif self .config .algorithm_type == "sde-dpmsolver" :
519
+ assert noise is not None
520
+ if self .config .solver_type == "midpoint" :
521
+ x_t = (
522
+ (alpha_t / alpha_s0 ) * sample
523
+ - 2.0 * (sigma_t * (torch .exp (h ) - 1.0 )) * D0
524
+ - (sigma_t * (torch .exp (h ) - 1.0 )) * D1
525
+ + sigma_t * torch .sqrt (torch .exp (2 * h ) - 1.0 ) * noise
526
+ )
527
+ elif self .config .solver_type == "heun" :
528
+ x_t = (
529
+ (alpha_t / alpha_s0 ) * sample
530
+ - 2.0 * (sigma_t * (torch .exp (h ) - 1.0 )) * D0
531
+ - 2.0 * (sigma_t * ((torch .exp (h ) - 1.0 ) / h - 1.0 )) * D1
532
+ + sigma_t * torch .sqrt (torch .exp (2 * h ) - 1.0 ) * noise
533
+ )
473
534
return x_t
474
535
475
536
def multistep_dpm_solver_third_order_update (
@@ -532,6 +593,7 @@ def step(
532
593
model_output : torch .FloatTensor ,
533
594
timestep : int ,
534
595
sample : torch .FloatTensor ,
596
+ generator = None ,
535
597
return_dict : bool = True ,
536
598
) -> Union [SchedulerOutput , Tuple ]:
537
599
"""
@@ -574,12 +636,21 @@ def step(
574
636
self .model_outputs [i ] = self .model_outputs [i + 1 ]
575
637
self .model_outputs [- 1 ] = model_output
576
638
639
+ if self .config .algorithm_type in ["sde-dpmsolver" , "sde-dpmsolver++" ]:
640
+ noise = randn_tensor (
641
+ model_output .shape , generator = generator , device = model_output .device , dtype = model_output .dtype
642
+ )
643
+ else :
644
+ noise = None
645
+
577
646
if self .config .solver_order == 1 or self .lower_order_nums < 1 or lower_order_final :
578
- prev_sample = self .dpm_solver_first_order_update (model_output , timestep , prev_timestep , sample )
647
+ prev_sample = self .dpm_solver_first_order_update (
648
+ model_output , timestep , prev_timestep , sample , noise = noise
649
+ )
579
650
elif self .config .solver_order == 2 or self .lower_order_nums < 2 or lower_order_second :
580
651
timestep_list = [self .timesteps [step_index - 1 ], timestep ]
581
652
prev_sample = self .multistep_dpm_solver_second_order_update (
582
- self .model_outputs , timestep_list , prev_timestep , sample
653
+ self .model_outputs , timestep_list , prev_timestep , sample , noise = noise
583
654
)
584
655
else :
585
656
timestep_list = [self .timesteps [step_index - 2 ], self .timesteps [step_index - 1 ], timestep ]
0 commit comments