Skip to content

Commit e405264

Browse files
[scheduler] fix some scheduler dtype error (#2992)
Co-authored-by: wangguan <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2494731 commit e405264

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def set_timesteps(
201201
else:
202202
timesteps = torch.from_numpy(timesteps).to(device)
203203

204-
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
204+
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
205205
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
206206

207207
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])

src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def set_timesteps(
190190
timesteps = torch.from_numpy(timesteps).to(device)
191191

192192
# interpolate timesteps
193-
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
193+
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
194194
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
195195

196196
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])

0 commit comments

Comments
 (0)