
Commit 07c9a08

Add timestep_spacing and steps_offset to schedulers (#3947)
* Add timestep_spacing to DDPM, LMSDiscrete, PNDM. * Remove spurious line. * More easy schedulers. * Add `linspace` to DDIM * Noise sigma for `trailing`. * Add timestep_spacing to DEISMultistepScheduler. Not sure the range is the way it was intended. * Fix: remove line used to debug. * Support timestep_spacing in DPMSolverMultistep, DPMSolverSDE, UniPC * Fix: convert to numpy. * Use sched. defaults when instantiating from_config For params not present in the original configuration. This makes it possible to switch pipeline schedulers even if they use different timestep_spacing (or any other param). * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * Missing args in DPMSolverMultistep * Test: default args not in config * Style * Fix scheduler name in test * Remove duplicated entries * Add test for solver_type This test currently fails in main. When switching from DEIS to UniPC, solver_type is "logrho" (the default value from DEIS), which gets translated to "bh1" by UniPC. This is different to the default value for UniPC: "bh2". This is where the translation happens: https://github.com/huggingface/diffusers/blob/36d22d0709dc19776e3016fb3392d0f5578b0ab2/src/diffusers/schedulers/scheduling_unipc_multistep.py#L171 * UniPC: use same default for solver_type Fixes a bug when switching from UniPC from another scheduler (i.e., DEIS) that uses a different solver type. The solver is now the same as if we had instantiated the scheduler directly. * do not save use default values * fix more * fix all * fix schedulers * fix more * finish for real * finish for real * flaky tests * Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py * Default steps_offset to 0. * Add missing docstrings * Apply suggestions from code review --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2837d49 commit 07c9a08

23 files changed (+598 -77 lines)

src/diffusers/configuration_utils.py

Lines changed: 15 additions & 1 deletion
@@ -423,6 +423,10 @@ def _get_init_keys(cls):
 
     @classmethod
     def extract_init_dict(cls, config_dict, **kwargs):
+        # Skip keys that were not present in the original config, so default __init__ values were used
+        used_defaults = config_dict.get("_use_default_values", [])
+        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
+
         # 0. Copy origin config dict
         original_dict = dict(config_dict.items())
 

@@ -544,8 +548,9 @@ def to_json_saveable(value):
             return value
 
         config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
-        # Don't save "_ignore_files"
+        # Don't save "_ignore_files" or "_use_default_values"
         config_dict.pop("_ignore_files", None)
+        config_dict.pop("_use_default_values", None)
 
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
 

@@ -599,6 +604,11 @@ def inner_init(self, *args, **kwargs):
                 if k not in ignore and k not in new_kwargs
             }
         )
+
+        # Take note of the parameters that were not present in the loaded config
+        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+            new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs)
+
         new_kwargs = {**config_init_kwargs, **new_kwargs}
         getattr(self, "register_to_config")(**new_kwargs)
         init(self, *args, **init_kwargs)

@@ -643,6 +653,10 @@ def init(self, *args, **kwargs):
             name = fields[i].name
             new_kwargs[name] = arg
 
+        # Take note of the parameters that were not present in the loaded config
+        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+            new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs)
+
         getattr(self, "register_to_config")(**new_kwargs)
         original_init(self, *args, **kwargs)
 
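
The effect of the filtering added to extract_init_dict can be shown with a small standalone sketch; the helper name and example dict below are illustrative only, not part of the library:

# Mirror of the new filtering step: keys the stored config marked as defaulted are
# dropped, so the class being instantiated falls back to its own __init__ defaults.
def drop_defaulted_keys(config_dict):
    used_defaults = config_dict.get("_use_default_values", [])
    return {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

example_config = {
    "num_train_timesteps": 1000,
    "solver_type": "logrho",  # a DEIS default that was never set explicitly
    "_use_default_values": ["solver_type"],
}
print(drop_defaulted_keys(example_config))  # {'num_train_timesteps': 1000}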

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 9 additions & 2 deletions
@@ -302,8 +302,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         self.num_inference_steps = num_inference_steps
 
-        # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
-        if self.config.timestep_spacing == "leading":
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+        if self.config.timestep_spacing == "linspace":
+            timesteps = (
+                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+                .round()[::-1]
+                .copy()
+                .astype(np.int64)
+            )
+        elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3

src/diffusers/schedulers/scheduling_ddim_parallel.py

Lines changed: 9 additions & 2 deletions
@@ -321,8 +321,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         self.num_inference_steps = num_inference_steps
 
-        # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
-        if self.config.timestep_spacing == "leading":
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+        if self.config.timestep_spacing == "linspace":
+            timesteps = (
+                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+                .round()[::-1]
+                .copy()
+                .astype(np.int64)
+            )
+        elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
             # casting to int to avoid issues when num_inference_step is power of 3

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 34 additions & 3 deletions
@@ -114,6 +114,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
         sample_max_value (`float`, default `1.0`):
             the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+        timestep_spacing (`str`, default `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]

@@ -134,6 +141,8 @@ def __init__(
         dynamic_thresholding_ratio: float = 0.995,
         clip_sample_range: float = 1.0,
         sample_max_value: float = 1.0,
+        timestep_spacing: str = "leading",
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -228,11 +237,33 @@ def set_timesteps(
             )
 
         self.num_inference_steps = num_inference_steps
-
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
         self.custom_timesteps = False
 
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+        if self.config.timestep_spacing == "linspace":
+            timesteps = (
+                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+                .round()[::-1]
+                .copy()
+                .astype(np.int64)
+            )
+        elif self.config.timestep_spacing == "leading":
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            timesteps += self.config.steps_offset
+        elif self.config.timestep_spacing == "trailing":
+            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+            timesteps -= 1
+        else:
+            raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+            )
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
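
The three branches added above differ only in where they place integer timesteps on the training range. A self-contained numpy sketch of what they produce, assuming 1000 training timesteps, 10 inference steps, and steps_offset=0:

import numpy as np

num_train_timesteps, num_inference_steps, steps_offset = 1000, 10, 0

# "linspace": evenly spaced over [0, num_train_timesteps - 1], keeps both endpoints
linspace = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)
# -> [999 888 777 666 555 444 333 222 111 0]

# "leading": multiples of the integer step ratio, so the schedule starts at 900, not 999
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + steps_offset
# -> [900 800 700 600 500 400 300 200 100 0]

# "trailing": counts down from num_train_timesteps and shifts by -1, so 999 is kept and 0 is dropped
trailing = np.round(np.arange(num_train_timesteps, 0, -num_train_timesteps / num_inference_steps)).astype(np.int64) - 1
# -> [999 899 799 699 599 499 399 299 199 99]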

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 34 additions & 3 deletions
@@ -116,6 +116,13 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
             (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
         sample_max_value (`float`, default `1.0`):
             the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+        timestep_spacing (`str`, default `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]

@@ -138,6 +145,8 @@ def __init__(
         dynamic_thresholding_ratio: float = 0.995,
         clip_sample_range: float = 1.0,
         sample_max_value: float = 1.0,
+        timestep_spacing: str = "leading",
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -234,11 +243,33 @@ def set_timesteps(
             )
 
         self.num_inference_steps = num_inference_steps
-
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.custom_timesteps = False
 
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+        if self.config.timestep_spacing == "linspace":
+            timesteps = (
+                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+                .round()[::-1]
+                .copy()
+                .astype(np.int64)
+            )
+        elif self.config.timestep_spacing == "leading":
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            timesteps += self.config.steps_offset
+        elif self.config.timestep_spacing == "trailing":
+            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+            timesteps -= 1
+        else:
+            raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+            )
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 33 additions & 6 deletions
@@ -107,6 +107,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
             noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
             of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+        timestep_spacing (`str`, default `"linspace"`):
+            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]

@@ -129,6 +136,8 @@ def __init__(
         solver_type: str = "logrho",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
+        timestep_spacing: str = "linspace",
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -185,12 +194,30 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-        timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
-            .round()[::-1][:-1]
-            .copy()
-            .astype(np.int64)
-        )
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+        if self.config.timestep_spacing == "linspace":
+            timesteps = (
+                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+                .round()[::-1][:-1]
+                .copy()
+                .astype(np.int64)
+            )
+        elif self.config.timestep_spacing == "leading":
+            step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
+            timesteps += self.config.steps_offset
+        elif self.config.timestep_spacing == "trailing":
+            step_ratio = self.config.num_train_timesteps / num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
+            timesteps -= 1
+        else:
+            raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+            )
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         if self.config.use_karras_sigmas:
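
Note that the multistep solvers (DEIS here, DPMSolverMultistep below) lay out num_inference_steps + 1 points and drop the final one, unlike the DDPM-style schedulers above. A short numpy sketch of the "linspace" branch under the same assumed settings (1000 training timesteps, 10 inference steps):

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10

# num_inference_steps + 1 points on [0, num_train_timesteps - 1]; the smallest one
# is discarded, so the returned schedule ends at 100 rather than 0 in this example.
timesteps = (
    np.linspace(0, num_train_timesteps - 1, num_inference_steps + 1)
    .round()[::-1][:-1]
    .copy()
    .astype(np.int64)
)
print(timesteps[0], timesteps[-1], len(timesteps))  # 999 100 10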

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 32 additions & 6 deletions
@@ -134,6 +134,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
             Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
             diffusion ODEs.
+        timestep_spacing (`str`, default `"linspace"`):
+            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]

@@ -158,6 +165,8 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
+        timestep_spacing: str = "linspace",
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -217,12 +226,29 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
         # Clipping the minimum of all lambda(t) for numerical stability.
         # This is critical for cosine (squaredcos_cap_v2) noise schedule.
         clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
-        timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
-            .round()[::-1][:-1]
-            .copy()
-            .astype(np.int64)
-        )
+        last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
+
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+        if self.config.timestep_spacing == "linspace":
+            timesteps = (
+                np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
+            )
+        elif self.config.timestep_spacing == "leading":
+            step_ratio = last_timestep // (num_inference_steps + 1)
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
+            timesteps += self.config.steps_offset
+        elif self.config.timestep_spacing == "trailing":
+            step_ratio = self.config.num_train_timesteps / num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
+            timesteps -= 1
+        else:
+            raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+            )
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         if self.config.use_karras_sigmas:
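
For DPMSolverMultistep, the lambda_min_clipped clipping is folded into last_timestep before any spacing branch runs, so all three strategies operate on the clipped range. A sketch with an assumed clipped_idx (the real value comes from the torch.searchsorted call above):

import numpy as np

# Hypothetical numbers for illustration: suppose clipping removed the last 13 of
# 1000 training timesteps, i.e. clipped_idx = 13.
num_train_timesteps, num_inference_steps, clipped_idx = 1000, 10, 13
last_timestep = num_train_timesteps - clipped_idx  # 987

# The "linspace" branch now spans [0, last_timestep - 1], so sampling starts at
# timestep 986 instead of 999.
timesteps = (
    np.linspace(0, last_timestep - 1, num_inference_steps + 1)
    .round()[::-1][:-1]
    .copy()
    .astype(np.int64)
)
print(timesteps[0], timesteps[-1])  # 986 99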
