Commit fb8e6a5
Correct controlnet out of list error (huggingface#3928)
* Correct controlnet out of list error
* Apply suggestions from code review
* correct tests
* correct tests
* fix
* test all
* Apply suggestions from code review
* test all
* test all
* Apply suggestions from code review
* Apply suggestions from code review
* fix more tests
* Fix more
* Apply suggestions from code review
* finish
* Apply suggestions from code review
* Update src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
* finish
1 parent 7e2cc71 commit fb8e6a5

8 files changed: +43 -14 lines changed

pipelines/controlnet/pipeline_controlnet.py
Lines changed: 2 additions & 2 deletions

@@ -947,9 +947,9 @@ def __call__(
 
         # 7.1 Create tensor stating which controlnets to keep
         controlnet_keep = []
-        for i in range(num_inference_steps):
+        for i in range(len(timesteps)):
             keeps = [
-                1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
             controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

pipelines/controlnet/pipeline_controlnet_img2img.py
Lines changed: 2 additions & 2 deletions

@@ -1040,9 +1040,9 @@ def __call__(
 
         # 7.1 Create tensor stating which controlnets to keep
         controlnet_keep = []
-        for i in range(num_inference_steps):
+        for i in range(len(timesteps)):
             keeps = [
-                1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
             controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

pipelines/controlnet/pipeline_controlnet_inpaint.py
Lines changed: 2 additions & 2 deletions

@@ -1275,9 +1275,9 @@ def __call__(
 
         # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []
-        for i in range(num_inference_steps):
+        for i in range(len(timesteps)):
             keeps = [
-                1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
             controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
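
All three ControlNet pipelines receive the same fix: the scheduler's timestep array is not guaranteed to have exactly `num_inference_steps` entries (the K-DPM2 schedulers, touched by the same PR, interleave intermediate timesteps, and Karras-sigma rounding can drop duplicates), so indexing `controlnet_keep[i]` inside a loop over `timesteps` could run past the end of the list. Keying both the list length and the guidance fractions off `len(timesteps)` keeps them consistent. A minimal sketch of the corrected bookkeeping; the guidance-window values are hypothetical:

# Sketch of the fixed bookkeeping; guidance-window values are made up.
control_guidance_start = [0.0]
control_guidance_end = [0.8]

num_inference_steps = 50
# e.g. a K-DPM2-style schedule with interleaved timesteps: 2 * 50 - 1 entries
timesteps = list(range(99, 0, -1))

controlnet_keep = []
for i in range(len(timesteps)):  # was: range(num_inference_steps)
    keeps = [
        1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

# The denoising loop reads controlnet_keep[i] once per timestep; with 99
# timesteps but only 50 entries, the old code raised IndexError.
assert len(controlnet_keep) == len(timesteps)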

pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
Lines changed: 1 addition & 1 deletion

@@ -374,7 +374,7 @@ def __call__(
                 # predicted_original_sample instead of the noise_pred. So we need to compute the
                 # predicted_original_sample here if we are using a karras style scheduler.
                 if scheduler_is_in_sigma_space:
-                    step_index = (self.scheduler.timesteps == t).nonzero().item()
+                    step_index = (self.scheduler.timesteps == t).nonzero()[0].item()
                     sigma = self.scheduler.sigmas[step_index]
                     noise_pred = latent_model_input - sigma * noise_pred
 
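
The instruct-pix2pix change guards against duplicate timestep values: rounding Karras sigmas to integers can leave the same timestep in the schedule more than once, in which case the equality match returns several indices and `.item()` raises. Taking the first match keeps the lookup well-defined. A minimal sketch with made-up timestep values:

import torch

# Hypothetical schedule in which Karras rounding produced a duplicate value.
timesteps = torch.tensor([999, 845, 845, 620, 301])
t = 845

matches = (timesteps == t).nonzero()  # shape (2, 1): two rows matched
# matches.item() would raise "only one element tensors can be converted
# to Python scalars"; indexing [0] picks the first occurrence instead.
step_index = matches[0].item()
assert step_index == 1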

schedulers/scheduling_deis_multistep.py
Lines changed: 14 additions & 1 deletion

@@ -103,7 +103,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final (`bool`, default `True`):
             whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
             find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10.
-
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]

@@ -125,6 +128,7 @@ def __init__(
         algorithm_type: str = "deis",
         solver_type: str = "logrho",
         lower_order_final: bool = True,
+        use_karras_sigmas: Optional[bool] = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -188,6 +192,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             .astype(np.int64)
         )
 
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        if self.config.use_karras_sigmas:
+            log_sigmas = np.log(sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = np.flip(timesteps).copy().astype(np.int64)
+
+        self.sigmas = torch.from_numpy(sigmas)
+
         # when num_inference_steps == num_train_timesteps, we can end up with
         # duplicates in timesteps.
         _, unique_indices = np.unique(timesteps, return_index=True)
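
The `_convert_to_karras` helper these schedulers call implements Equation (5) of Karras et al. (2022): sigmas are spaced by linear interpolation between the extreme noise levels in 1/ρ-power space. A standalone sketch of that spacing, assuming ρ = 7.0 as in the paper; the function name and signature here are illustrative, not the scheduler's private API:

import numpy as np

def karras_sigmas(sigma_min: float, sigma_max: float, n: int, rho: float = 7.0) -> np.ndarray:
    # Eq. (5) of https://arxiv.org/pdf/2206.00364.pdf: interpolate linearly
    # between sigma_max and sigma_min in 1/rho-power space, then invert.
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# e.g. 10 sigmas descending from ~14.61 to ~0.029 (typical SD 1.x extremes)
print(karras_sigmas(0.0292, 14.6146, 10))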

schedulers/scheduling_dpmsolver_multistep.py
Lines changed: 4 additions & 3 deletions

@@ -203,7 +203,6 @@ def __init__(
         self.timesteps = torch.from_numpy(timesteps)
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
-        self.use_karras_sigmas = use_karras_sigmas
 
     def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         """

@@ -225,13 +224,15 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
             .astype(np.int64)
         )
 
-        if self.use_karras_sigmas:
-            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        if self.config.use_karras_sigmas:
             log_sigmas = np.log(sigmas)
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             timesteps = np.flip(timesteps).copy().astype(np.int64)
 
+        self.sigmas = torch.from_numpy(sigmas)
+
         # when num_inference_steps == num_train_timesteps, we can end up with
         # duplicates in timesteps.
         _, unique_indices = np.unique(timesteps, return_index=True)
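
Both DPM-Solver schedulers now read the flag from `self.config`, where the registered `__init__` arguments already live, instead of a duplicate instance attribute, and `self.sigmas` is populated on every `set_timesteps` call rather than only on the Karras path, so pipelines can index `scheduler.sigmas` regardless of the flag. A usage sketch for enabling the flag on an existing pipeline; the model id is an example:

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Example checkpoint; any Stable Diffusion model works the same way.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# from_config keeps the existing scheduler settings and overrides the flag.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True
)
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]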

schedulers/scheduling_dpmsolver_singlestep.py
Lines changed: 4 additions & 3 deletions

@@ -202,7 +202,6 @@ def __init__(
         self.model_outputs = [None] * solver_order
         self.sample = None
         self.order_list = self.get_order_list(num_train_timesteps)
-        self.use_karras_sigmas = use_karras_sigmas
 
     def get_order_list(self, num_inference_steps: int) -> List[int]:
         """

@@ -259,13 +258,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             .astype(np.int64)
         )
 
-        if self.use_karras_sigmas:
-            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        if self.config.use_karras_sigmas:
             log_sigmas = np.log(sigmas)
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             timesteps = np.flip(timesteps).copy().astype(np.int64)
 
+        self.sigmas = torch.from_numpy(sigmas)
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
         self.model_outputs = [None] * self.config.solver_order
         self.sample = None

schedulers/scheduling_unipc_multistep.py
Lines changed: 14 additions & 0 deletions

@@ -117,6 +117,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             by disable the corrector at the first few steps (e.g., disable_corrector=[0])
         solver_p (`SchedulerMixin`, default `None`):
             can be any other scheduler. If specified, the algorithm will become solver_p + UniC.
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]

@@ -140,6 +144,7 @@ def __init__(
         lower_order_final: bool = True,
         disable_corrector: List[int] = [],
         solver_p: SchedulerMixin = None,
+        use_karras_sigmas: Optional[bool] = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -201,6 +206,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             .astype(np.int64)
         )
 
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        if self.config.use_karras_sigmas:
+            log_sigmas = np.log(sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = np.flip(timesteps).copy().astype(np.int64)
+
+        self.sigmas = torch.from_numpy(sigmas)
+
         # when num_inference_steps == num_train_timesteps, we can end up with
         # duplicates in timesteps.
         _, unique_indices = np.unique(timesteps, return_index=True)
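
The dedup step visible in the context lines above ties the scheduler changes back to the pipeline fixes: after rounding Karras sigmas to integers, repeated timesteps are removed, which is one way the final schedule ends up with a different length than `num_inference_steps`. A minimal sketch of the pattern with made-up values, assuming the schedulers re-sort the first-occurrence indices to preserve descending order:

import numpy as np

# Hypothetical rounded schedule containing duplicates.
timesteps = np.array([999, 845, 845, 620, 620, 301])

# np.unique sorts ascending, so keep the first index of each value and
# re-sort the indices to restore the original descending order.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]

print(timesteps)  # [999 845 620 301]: 4 steps where 6 were requested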
