Skip to content

Commit fc18839

Browse files
class labels timestep embeddings projection dtype cast (#3137)
This mimics the dtype cast for the standard time embeddings
Commit fc18839 (1 parent: f0c74e9)

File tree: 2 files changed, +10 additions, −2 deletions

src/diffusers/models/unet_2d_condition.py

Lines changed: 5 additions & 1 deletion
@@ -659,7 +659,7 @@ def forward(
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -673,6 +673,10 @@ def forward(
         if self.config.class_embed_type == "timestep":
             class_labels = self.time_proj(class_labels)
 
+            # `Timesteps` does not contain any weights and will always return f32 tensors
+            # there might be better ways to encapsulate this.
+            class_labels = class_labels.to(dtype=sample.dtype)
+
         class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
         if self.config.class_embeddings_concat:

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 5 additions & 1 deletion
@@ -756,7 +756,7 @@ def forward(
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -770,6 +770,10 @@ def forward(
         if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)
 
+            # `Timesteps` does not contain any weights and will always return f32 tensors
+            # there might be better ways to encapsulate this.
+            class_labels = class_labels.to(dtype=sample.dtype)
+
         class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
         if self.config.class_embeddings_concat:

Comments (0)