File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed
Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -3518,7 +3518,11 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
35183518 const T arg_x) const {
35193519 MPType dout = static_cast <MPType>(arg_dout);
35203520 MPType x = static_cast <MPType>(arg_x);
3521- return static_cast <T>(-dout * sin (x));
3521+ if constexpr (std::is_same<T, phi::float16>::value) {
3522+ return static_cast <T>(-arg_dout * static_cast <T>(sin (x)));
3523+ } else {
3524+ return static_cast <T>(-dout * sin (x));
3525+ }
35223526 }
35233527
35243528 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
@@ -3853,7 +3857,11 @@ struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
38533857 const T arg_x) const {
38543858 MPType dout = static_cast <MPType>(arg_dout);
38553859 MPType x = static_cast <MPType>(arg_x);
3856- return static_cast <T>(dout * cos (x));
3860+ if constexpr (std::is_same<T, phi::float16>::value) {
3861+ return static_cast <T>(arg_dout * static_cast <T>(cos (x)));
3862+ } else {
3863+ return static_cast <T>(dout * cos (x));
3864+ }
38573865 }
38583866
38593867 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
You can’t perform that action at this time.
0 commit comments