22
22
23
23
from ..configuration_utils import ConfigMixin , register_to_config
24
24
from ..utils import deprecate , logging
25
+ from ..utils .torch_utils import randn_tensor
25
26
from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin , SchedulerOutput
26
27
27
28
@@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
108
109
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
109
110
`algorithm_type="dpmsolver++"`.
110
111
algorithm_type (`str`, defaults to `dpmsolver++`):
111
- Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The ` dpmsolver` type implements the
112
- algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
113
- implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is
114
- recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in
115
- Stable Diffusion.
112
+ Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde- dpmsolver++`. The `dpmsolver`
113
+ type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
114
+ `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
115
+ paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
116
+ sampling like in Stable Diffusion.
116
117
solver_type (`str`, defaults to `midpoint`):
117
118
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
118
119
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
@@ -186,7 +187,7 @@ def __init__(
186
187
self .init_noise_sigma = 1.0
187
188
188
189
# settings for DPM-Solver
189
- if algorithm_type not in ["dpmsolver" , "dpmsolver++" ]:
190
+ if algorithm_type not in ["dpmsolver" , "dpmsolver++" , "sde-dpmsolver++" ]:
190
191
if algorithm_type == "deis" :
191
192
self .register_to_config (algorithm_type = "dpmsolver++" )
192
193
else :
@@ -197,7 +198,7 @@ def __init__(
197
198
else :
198
199
raise NotImplementedError (f"{ solver_type } is not implemented for { self .__class__ } " )
199
200
200
- if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero" :
201
+ if algorithm_type not in [ "dpmsolver++" , "sde-dpmsolver++" ] and final_sigmas_type == "zero" :
201
202
raise ValueError (
202
203
f"`final_sigmas_type` { final_sigmas_type } is not supported for `algorithm_type` { algorithm_type } . Please chooose `sigma_min` instead."
203
204
)
@@ -493,10 +494,10 @@ def convert_model_output(
493
494
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`" ,
494
495
)
495
496
# DPM-Solver++ needs to solve an integral of the data prediction model.
496
- if self .config .algorithm_type == "dpmsolver++" :
497
+ if self .config .algorithm_type in [ "dpmsolver++" , "sde-dpmsolver++" ] :
497
498
if self .config .prediction_type == "epsilon" :
498
499
# DPM-Solver and DPM-Solver++ only need the "mean" output.
499
- if self .config .variance_type in ["learned_range" ]:
500
+ if self .config .variance_type in ["learned" , " learned_range" ]:
500
501
model_output = model_output [:, :3 ]
501
502
sigma = self .sigmas [self .step_index ]
502
503
alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma )
@@ -517,34 +518,43 @@ def convert_model_output(
517
518
x0_pred = self ._threshold_sample (x0_pred )
518
519
519
520
return x0_pred
521
+
520
522
# DPM-Solver needs to solve an integral of the noise prediction model.
521
523
elif self .config .algorithm_type == "dpmsolver" :
522
524
if self .config .prediction_type == "epsilon" :
523
525
# DPM-Solver and DPM-Solver++ only need the "mean" output.
524
- if self .config .variance_type in ["learned_range" ]:
525
- model_output = model_output [:, :3 ]
526
- return model_output
526
+ if self .config .variance_type in ["learned" , "learned_range" ]:
527
+ epsilon = model_output [:, :3 ]
528
+ else :
529
+ epsilon = model_output
527
530
elif self .config .prediction_type == "sample" :
528
531
sigma = self .sigmas [self .step_index ]
529
532
alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma )
530
533
epsilon = (sample - alpha_t * model_output ) / sigma_t
531
- return epsilon
532
534
elif self .config .prediction_type == "v_prediction" :
533
535
sigma = self .sigmas [self .step_index ]
534
536
alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma )
535
537
epsilon = alpha_t * model_output + sigma_t * sample
536
- return epsilon
537
538
else :
538
539
raise ValueError (
539
540
f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample`, or"
540
541
" `v_prediction` for the DPMSolverSinglestepScheduler."
541
542
)
542
543
544
+ if self .config .thresholding :
545
+ alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
546
+ x0_pred = (sample - sigma_t * epsilon ) / alpha_t
547
+ x0_pred = self ._threshold_sample (x0_pred )
548
+ epsilon = (sample - alpha_t * x0_pred ) / sigma_t
549
+
550
+ return epsilon
551
+
543
552
def dpm_solver_first_order_update (
544
553
self ,
545
554
model_output : torch .Tensor ,
546
555
* args ,
547
556
sample : torch .Tensor = None ,
557
+ noise : Optional [torch .Tensor ] = None ,
548
558
** kwargs ,
549
559
) -> torch .Tensor :
550
560
"""
@@ -594,13 +604,21 @@ def dpm_solver_first_order_update(
594
604
x_t = (sigma_t / sigma_s ) * sample - (alpha_t * (torch .exp (- h ) - 1.0 )) * model_output
595
605
elif self .config .algorithm_type == "dpmsolver" :
596
606
x_t = (alpha_t / alpha_s ) * sample - (sigma_t * (torch .exp (h ) - 1.0 )) * model_output
607
+ elif self .config .algorithm_type == "sde-dpmsolver++" :
608
+ assert noise is not None
609
+ x_t = (
610
+ (sigma_t / sigma_s * torch .exp (- h )) * sample
611
+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * model_output
612
+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
613
+ )
597
614
return x_t
598
615
599
616
def singlestep_dpm_solver_second_order_update (
600
617
self ,
601
618
model_output_list : List [torch .Tensor ],
602
619
* args ,
603
620
sample : torch .Tensor = None ,
621
+ noise : Optional [torch .Tensor ] = None ,
604
622
** kwargs ,
605
623
) -> torch .Tensor :
606
624
"""
@@ -688,6 +706,22 @@ def singlestep_dpm_solver_second_order_update(
688
706
- (sigma_t * (torch .exp (h ) - 1.0 )) * D0
689
707
- (sigma_t * ((torch .exp (h ) - 1.0 ) / h - 1.0 )) * D1
690
708
)
709
+ elif self .config .algorithm_type == "sde-dpmsolver++" :
710
+ assert noise is not None
711
+ if self .config .solver_type == "midpoint" :
712
+ x_t = (
713
+ (sigma_t / sigma_s1 * torch .exp (- h )) * sample
714
+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D0
715
+ + 0.5 * (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D1
716
+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
717
+ )
718
+ elif self .config .solver_type == "heun" :
719
+ x_t = (
720
+ (sigma_t / sigma_s1 * torch .exp (- h )) * sample
721
+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D0
722
+ + (alpha_t * ((1.0 - torch .exp (- 2.0 * h )) / (- 2.0 * h ) + 1.0 )) * D1
723
+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
724
+ )
691
725
return x_t
692
726
693
727
def singlestep_dpm_solver_third_order_update (
@@ -800,6 +834,7 @@ def singlestep_dpm_solver_update(
800
834
* args ,
801
835
sample : torch .Tensor = None ,
802
836
order : int = None ,
837
+ noise : Optional [torch .Tensor ] = None ,
803
838
** kwargs ,
804
839
) -> torch .Tensor :
805
840
"""
@@ -848,9 +883,9 @@ def singlestep_dpm_solver_update(
848
883
)
849
884
850
885
if order == 1 :
851
- return self .dpm_solver_first_order_update (model_output_list [- 1 ], sample = sample )
886
+ return self .dpm_solver_first_order_update (model_output_list [- 1 ], sample = sample , noise = noise )
852
887
elif order == 2 :
853
- return self .singlestep_dpm_solver_second_order_update (model_output_list , sample = sample )
888
+ return self .singlestep_dpm_solver_second_order_update (model_output_list , sample = sample , noise = noise )
854
889
elif order == 3 :
855
890
return self .singlestep_dpm_solver_third_order_update (model_output_list , sample = sample )
856
891
else :
@@ -894,6 +929,7 @@ def step(
894
929
model_output : torch .Tensor ,
895
930
timestep : int ,
896
931
sample : torch .Tensor ,
932
+ generator = None ,
897
933
return_dict : bool = True ,
898
934
) -> Union [SchedulerOutput , Tuple ]:
899
935
"""
@@ -929,6 +965,13 @@ def step(
929
965
self .model_outputs [i ] = self .model_outputs [i + 1 ]
930
966
self .model_outputs [- 1 ] = model_output
931
967
968
+ if self .config .algorithm_type == "sde-dpmsolver++" :
969
+ noise = randn_tensor (
970
+ model_output .shape , generator = generator , device = model_output .device , dtype = model_output .dtype
971
+ )
972
+ else :
973
+ noise = None
974
+
932
975
order = self .order_list [self .step_index ]
933
976
934
977
# For img2img denoising might start with order>1 which is not possible
@@ -940,9 +983,11 @@ def step(
940
983
if order == 1 :
941
984
self .sample = sample
942
985
943
- prev_sample = self .singlestep_dpm_solver_update (self .model_outputs , sample = self .sample , order = order )
986
+ prev_sample = self .singlestep_dpm_solver_update (
987
+ self .model_outputs , sample = self .sample , order = order , noise = noise
988
+ )
944
989
945
- # upon completion increase step index by one
990
+ # upon completion increase step index by one, noise=noise
946
991
self ._step_index += 1
947
992
948
993
if not return_dict :
0 commit comments