Commit f0b26ef

class labels timestep embeddings projection dtype cast
This mimics the dtype cast for the standard time embeddings
1 parent 703307e commit f0b26ef
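
The fix follows the pattern already used for the standard time embeddings: the sinusoidal projection (`Timesteps`) has no learnable weights, so it always produces float32 output, while the embedding MLP that consumes it may have been cast to fp16, and feeding float32 activations into fp16 weights raises a dtype mismatch. Below is a minimal, self-contained sketch of that pattern; the toy `sinusoidal_projection` / `ClassEmbedding` names and dimensions are illustrative assumptions, not the diffusers implementation.

import math

import torch
from torch import nn


def sinusoidal_projection(values: torch.Tensor, dim: int = 320) -> torch.Tensor:
    # Mimics `Timesteps`: there are no learnable weights here, so the output
    # is float32 no matter what dtype the surrounding model was cast to.
    half_dim = dim // 2
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half_dim, dtype=torch.float32, device=values.device) / half_dim
    )
    args = values.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


class ClassEmbedding(nn.Module):
    # Stand-in for the embedding MLP used when class_embed_type == "timestep".
    def __init__(self, in_dim: int = 320, out_dim: int = 1280):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# fp16 matmul may be unavailable on CPU in older PyTorch builds, so prefer CUDA if present.
device = "cuda" if torch.cuda.is_available() else "cpu"
class_embedding = ClassEmbedding().to(device, dtype=torch.float16).eval()

class_labels = torch.tensor([0, 7, 3], device=device)
proj = sinusoidal_projection(class_labels)                       # float32: the projection has no weights
proj = proj.to(dtype=next(class_embedding.parameters()).dtype)   # the cast this commit adds
class_emb = class_embedding(proj)
print(proj.dtype, class_emb.dtype)                               # torch.float16 torch.float16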

File tree

2 files changed: +10 -2 lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 5 additions & 1 deletion
@@ -648,7 +648,7 @@ def forward(
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -662,6 +662,10 @@ def forward(
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=self.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 5 additions & 1 deletion
@@ -745,7 +745,7 @@ def forward(
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -759,6 +759,10 @@ def forward(
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=self.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
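
With the cast in place, the class-label path runs end to end when the whole UNet is in half precision; previously the float32 projection of `class_labels` would hit the fp16 class-embedding MLP directly. A usage-style sketch follows, assuming a small test-style configuration (the block sizes, cross-attention width, and tensor shapes below are illustrative assumptions, not part of the commit):

import torch
from diffusers import UNet2DConditionModel

# fp16 conv/matmul may be unavailable on CPU in older PyTorch builds,
# so prefer a CUDA device when one is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    class_embed_type="timestep",  # class labels go through `Timesteps` + an embedding MLP
).to(device, dtype=torch.float16).eval()

sample = torch.randn(1, 4, 32, 32, device=device, dtype=torch.float16)
encoder_hidden_states = torch.randn(1, 77, 32, device=device, dtype=torch.float16)
class_labels = torch.tensor([3], device=device)  # projected to f32 by `Timesteps`, then cast

with torch.no_grad():
    out = unet(
        sample,
        timestep=10,
        encoder_hidden_states=encoder_hidden_states,
        class_labels=class_labels,
    ).sample

print(out.dtype)  # torch.float16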
