File tree Expand file tree Collapse file tree 2 files changed +10
-2
lines changed
pipelines/versatile_diffusion Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -648,7 +648,7 @@ def forward(
648
648
649
649
t_emb = self .time_proj (timesteps )
650
650
651
- # timesteps does not contain any weights and will always return f32 tensors
651
+ # `Timesteps` does not contain any weights and will always return f32 tensors
652
652
# but time_embedding might actually be running in fp16. so we need to cast here.
653
653
# there might be better ways to encapsulate this.
654
654
t_emb = t_emb .to (dtype = self .dtype )
@@ -662,6 +662,10 @@ def forward(
662
662
if self .config .class_embed_type == "timestep" :
663
663
class_labels = self .time_proj (class_labels )
664
664
665
+ # `Timesteps` does not contain any weights and will always return f32 tensors
666
+ # there might be better ways to encapsulate this.
667
+ class_labels = class_labels .to (dtype = self .dtype )
668
+
665
669
class_emb = self .class_embedding (class_labels ).to (dtype = self .dtype )
666
670
667
671
if self .config .class_embeddings_concat :
Original file line number Diff line number Diff line change @@ -745,7 +745,7 @@ def forward(
745
745
746
746
t_emb = self .time_proj (timesteps )
747
747
748
- # timesteps does not contain any weights and will always return f32 tensors
748
+ # `Timesteps` does not contain any weights and will always return f32 tensors
749
749
# but time_embedding might actually be running in fp16. so we need to cast here.
750
750
# there might be better ways to encapsulate this.
751
751
t_emb = t_emb .to (dtype = self .dtype )
@@ -759,6 +759,10 @@ def forward(
759
759
if self .config .class_embed_type == "timestep" :
760
760
class_labels = self .time_proj (class_labels )
761
761
762
+ # `Timesteps` does not contain any weights and will always return f32 tensors
763
+ # there might be better ways to encapsulate this.
764
+ class_labels = class_labels .to (dtype = self .dtype )
765
+
762
766
class_emb = self .class_embedding (class_labels ).to (dtype = self .dtype )
763
767
764
768
if self .config .class_embeddings_concat :
You can’t perform that action at this time.
0 commit comments