Skip to content

Commit 2e60a2c

Browse files
zhengshengningwanglezz
authored andcommitted
[Precision Depth Alignment] paddle.sin and paddle.cos aligns with torch precision. (PaddlePaddle#75503)
* accuracy_stable_sin * accuracy_stable_cos
1 parent 19a45b5 commit 2e60a2c

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3518,7 +3518,11 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
35183518
const T arg_x) const {
35193519
MPType dout = static_cast<MPType>(arg_dout);
35203520
MPType x = static_cast<MPType>(arg_x);
3521-
return static_cast<T>(-dout * sin(x));
3521+
if constexpr (std::is_same<T, phi::float16>::value) {
3522+
return static_cast<T>(-arg_dout * static_cast<T>(sin(x)));
3523+
} else {
3524+
return static_cast<T>(-dout * sin(x));
3525+
}
35223526
}
35233527

35243528
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
@@ -3853,7 +3857,11 @@ struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
38533857
const T arg_x) const {
38543858
MPType dout = static_cast<MPType>(arg_dout);
38553859
MPType x = static_cast<MPType>(arg_x);
3856-
return static_cast<T>(dout * cos(x));
3860+
if constexpr (std::is_same<T, phi::float16>::value) {
3861+
return static_cast<T>(arg_dout * static_cast<T>(cos(x)));
3862+
} else {
3863+
return static_cast<T>(dout * cos(x));
3864+
}
38573865
}
38583866

38593867
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }

0 commit comments

Comments
 (0)