Skip to content

Commit 526827c

Browse files
authored
Fix scheduler type mismatch (#3041)
When doing generation manually and using guidance_scale as a static argument.
1 parent cb63feb commit 526827c

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ def _generate(
245245
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
246246
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
247247

248+
# Ensure model output will be `float32` before going into the scheduler
249+
guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
250+
248251
latents_shape = (
249252
batch_size,
250253
self.unet.config.in_channels,

0 commit comments

Comments
 (0)