File tree Expand file tree Collapse file tree 2 files changed +10
-2
lines changed
pipelines/versatile_diffusion Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -659,7 +659,7 @@ def forward(
659
659
660
660
t_emb = self .time_proj (timesteps )
661
661
662
- # timesteps does not contain any weights and will always return f32 tensors
662
+ # `Timesteps` does not contain any weights and will always return f32 tensors
663
663
# but time_embedding might actually be running in fp16. so we need to cast here.
664
664
# there might be better ways to encapsulate this.
665
665
t_emb = t_emb .to (dtype = self .dtype )
@@ -673,6 +673,10 @@ def forward(
673
673
if self .config .class_embed_type == "timestep" :
674
674
class_labels = self .time_proj (class_labels )
675
675
676
+ # `Timesteps` does not contain any weights and will always return f32 tensors
677
+ # there might be better ways to encapsulate this.
678
+ class_labels = class_labels .to (dtype = sample .dtype )
679
+
676
680
class_emb = self .class_embedding (class_labels ).to (dtype = self .dtype )
677
681
678
682
if self .config .class_embeddings_concat :
Original file line number Diff line number Diff line change @@ -756,7 +756,7 @@ def forward(
756
756
757
757
t_emb = self .time_proj (timesteps )
758
758
759
- # timesteps does not contain any weights and will always return f32 tensors
759
+ # `Timesteps` does not contain any weights and will always return f32 tensors
760
760
# but time_embedding might actually be running in fp16. so we need to cast here.
761
761
# there might be better ways to encapsulate this.
762
762
t_emb = t_emb .to (dtype = self .dtype )
@@ -770,6 +770,10 @@ def forward(
770
770
if self .config .class_embed_type == "timestep" :
771
771
class_labels = self .time_proj (class_labels )
772
772
773
+ # `Timesteps` does not contain any weights and will always return f32 tensors
774
+ # there might be better ways to encapsulate this.
775
+ class_labels = class_labels .to (dtype = sample .dtype )
776
+
773
777
class_emb = self .class_embedding (class_labels ).to (dtype = self .dtype )
774
778
775
779
if self .config .class_embeddings_concat :
You can’t perform that action at this time.
0 commit comments