Skip to content

[Cont'd] Add the SDE variant of DPM-Solver and DPM-Solver++ to DPM Single Step #8269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
dd8c2d9
Add the SDE variant of DPM-Solver and DPM-Solver++ to DPM Single Step
cmdr2 Jul 25, 2023
6d45889
Merge remote-tracking branch 'cmdr2/sde-solver' into sde-dpmsolver-si…
tolgacangoz May 24, 2024
1b03f42
Fix `noise` parameter
tolgacangoz May 24, 2024
30bf51a
singlestep seems to prefer s1, s2 in 2., 3. orders rather than s0 lik…
tolgacangoz May 24, 2024
a21e518
refactor: Move randn_tensor import to torch_utils module
tolgacangoz May 24, 2024
894ffb0
make style
tolgacangoz May 24, 2024
ea7db0d
Refactor deprecation warning and error handling logic on `algorithm_t…
tolgacangoz May 24, 2024
4123c86
Fix typos
tolgacangoz May 24, 2024
0648ee2
Revert "Fix typos"
tolgacangoz May 24, 2024
9594a86
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz May 28, 2024
b94e859
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz May 31, 2024
4609366
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jun 18, 2024
ea4a1df
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jun 18, 2024
c228fc3
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jun 18, 2024
cfbdc19
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jun 19, 2024
cda2e6c
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jun 20, 2024
552e19a
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jun 23, 2024
16acba1
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jun 24, 2024
16c07fd
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jun 28, 2024
7070f02
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jul 8, 2024
7406ee6
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jul 11, 2024
b72be84
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jul 11, 2024
d59a5ea
Remove `sde-dpmsolver`
tolgacangoz Jul 11, 2024
bb236b8
Merge branch 'main' into sde-dpmsolver-single-step
tolgacangoz Jul 11, 2024
53f4b89
Merge branch 'sde-dpmsolver-single-step' of github.com:tolgacangoz/di…
tolgacangoz Jul 11, 2024
890e93f
chore: Update deprecation message for algorithm_type `dpmsolver`
tolgacangoz Jul 11, 2024
937689c
Refactor DPMSolverSinglestepSchedulerTest to exclude sde-dpmsolver fr…
tolgacangoz Jul 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 63 additions & 18 deletions src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


Expand Down Expand Up @@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `dpmsolver++`):
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is
recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in
Stable Diffusion.
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, or `sde-dpmsolver++`. The `dpmsolver`
type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
`dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
paper. The `sde-dpmsolver++` type is the SDE variant of `dpmsolver++`, which adds freshly sampled noise at
every step. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
sampling like in Stable Diffusion.
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
Expand Down Expand Up @@ -186,7 +187,7 @@ def __init__(
self.init_noise_sigma = 1.0

# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
Expand All @@ -197,7 +198,7 @@ def __init__(
else:
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
raise ValueError(
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
)
Expand Down Expand Up @@ -493,10 +494,10 @@ def convert_model_output(
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
Expand All @@ -517,34 +518,43 @@ def convert_model_output(
x0_pred = self._threshold_sample(x0_pred)

return x0_pred

# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
return model_output
if self.config.variance_type in ["learned", "learned_range"]:
epsilon = model_output[:, :3]
else:
epsilon = model_output
elif self.config.prediction_type == "sample":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverSinglestepScheduler."
)

if self.config.thresholding:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t

return epsilon

def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -594,13 +604,21 @@ def dpm_solver_first_order_update(
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t

def singlestep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -688,6 +706,22 @@ def singlestep_dpm_solver_second_order_update(
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s1 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t

def singlestep_dpm_solver_third_order_update(
Expand Down Expand Up @@ -800,6 +834,7 @@ def singlestep_dpm_solver_update(
*args,
sample: torch.Tensor = None,
order: int = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -848,9 +883,9 @@ def singlestep_dpm_solver_update(
)

if order == 1:
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise)
elif order == 2:
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
elif order == 3:
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
else:
Expand Down Expand Up @@ -894,6 +929,7 @@ def step(
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Expand Down Expand Up @@ -929,6 +965,13 @@ def step(
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output

if self.config.algorithm_type == "sde-dpmsolver++":
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
else:
noise = None

order = self.order_list[self.step_index]

# For img2img denoising might start with order>1 which is not possible
Expand All @@ -940,9 +983,11 @@ def step(
if order == 1:
self.sample = sample

prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
prev_sample = self.singlestep_dpm_solver_update(
self.model_outputs, sample=self.sample, order=order, noise=noise
)

# upon completion increase step index by one
# upon completion increase step index by one
self._step_index += 1

if not return_dict:
Expand Down
18 changes: 11 additions & 7 deletions tests/schedulers/test_scheduler_dpm_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,16 +194,20 @@ def test_prediction_type(self):
self.check_over_configs(prediction_type=prediction_type)

def test_solver_order_and_type(self):
for algorithm_type in ["dpmsolver", "dpmsolver++"]:
for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
for solver_type in ["midpoint", "heun"]:
for order in [1, 2, 3]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
if algorithm_type == "sde-dpmsolver++":
if order == 3:
continue
else:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
sample = self.full_loop(
solver_order=order,
solver_type=solver_type,
Expand Down
Loading