diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index b2814356939b..7a6314d00633 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -648,7 +648,7 @@ def forward(
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -662,6 +662,10 @@ def forward(
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 4377be1181a8..80eb40dc99bf 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -745,7 +745,7 @@ def forward(
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -759,6 +759,10 @@ def forward(
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
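
For context on why the added cast is needed, here is a minimal sketch of the dtype mismatch the patch addresses. It assumes `diffusers.models.embeddings.Timesteps`; the channel count (320), the class ids, and the `sample_dtype` stand-in below are illustrative values, not taken from the patch or the actual UNet configuration.

```python
import torch
from diffusers.models.embeddings import Timesteps

# `Timesteps` is a purely functional sinusoidal projection with no parameters,
# so casting the model to fp16 never touches it and its output stays float32.
time_proj = Timesteps(320, True, 0)  # (num_channels, flip_sin_to_cos, downscale_freq_shift)

class_labels = torch.tensor([3, 7])  # integer class ids, as passed to forward()
projected = time_proj(class_labels)
print(projected.dtype)               # torch.float32, even if the surrounding UNet runs in fp16

# In a half-precision UNet the class embedding weights are fp16, so feeding the
# f32 projection into it triggers a dtype-mismatch error on most backends.
# The patch casts the projection to the sample's dtype first:
sample_dtype = torch.float16         # stand-in for `sample.dtype` inside forward()
projected = projected.to(dtype=sample_dtype)
print(projected.dtype)               # torch.float16, now matching the class embedding
```

The change mirrors the existing handling of `t_emb`, which is already cast before entering `time_embedding`; the class-label path simply casts to `sample.dtype` before the projection reaches `self.class_embedding`.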