Skip to content

Commit 6b610ee

Browse files
authored
[Complex] Fix wrong reciprocal result when meets complex dtype (PaddlePaddle#74290)
* fix reciprocal complex * refine device code with nan(") and nanf(")
1 parent 4d4fdc4 commit 6b610ee

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,41 @@ struct ReciprocalFunctor : public BaseActivationFunctor<T> {
307307
}
308308
};
309309

310+
template <typename T>
311+
struct Reciprocal {
312+
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
313+
auto both_inf = [](T real, T imag) {
314+
return (std::isinf(real) && std::isinf(imag));
315+
};
316+
317+
auto either_inf = [](T real, T imag) {
318+
return std::isinf(real) || std::isinf(imag);
319+
};
320+
321+
auto either_nan = [](T real, T imag) {
322+
return std::isnan(real) || std::isnan(imag);
323+
};
324+
if (either_nan(val.real, val.imag) || both_inf(val.real, val.imag)) {
325+
// If either is Nan or both are infinite, return {nan, nan}
326+
return ComplexType<T>(std::numeric_limits<T>::quiet_NaN(),
327+
std::numeric_limits<T>::quiet_NaN());
328+
} else if (either_inf(val.real, val.imag)) {
329+
// If either is Inf, return {0, 0}
330+
return ComplexType<T>{static_cast<T>(0), static_cast<T>(0)};
331+
}
332+
return static_cast<ComplexType<T>>(1.0) / val;
333+
}
334+
};
335+
336+
template <typename T>
337+
struct ReciprocalFunctor<ComplexType<T>>
338+
: public BaseActivationFunctor<ComplexType<T>> {
339+
template <typename Device, typename X, typename Out>
340+
void operator()(Device d, X x, Out out) const {
341+
out.device(d) = x.unaryExpr(Reciprocal<T>());
342+
}
343+
};
344+
310345
template <typename T>
311346
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
312347
template <typename Device,
@@ -3607,6 +3642,37 @@ struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
36073642
}
36083643
};
36093644

3645+
template <typename T>
3646+
struct CudaReciprocalFunctor<ComplexType<T>>
3647+
: public BaseActivationFunctor<ComplexType<T>> {
3648+
__device__ __forceinline__ ComplexType<T> operator()(
3649+
const ComplexType<T> x) const {
3650+
auto both_inf = [](T real, T imag) {
3651+
return (::isinf(real) && ::isinf(imag));
3652+
};
3653+
3654+
auto either_inf = [](T real, T imag) {
3655+
return ::isinf(real) || ::isinf(imag);
3656+
};
3657+
3658+
auto either_nan = [](T real, T imag) {
3659+
return ::isnan(real) || ::isnan(imag);
3660+
};
3661+
if (either_nan(x.real, x.imag) || both_inf(x.real, x.imag)) {
3662+
// If either is Nan or both are infinite, return {nan, nan}
3663+
if constexpr (std::is_same<T, float>::value) {
3664+
return ComplexType<T>(nanf(""), nanf(""));
3665+
} else if constexpr (std::is_same<T, double>::value) {
3666+
return ComplexType<T>(nan(""), nan(""));
3667+
}
3668+
} else if (either_inf(x.real, x.imag)) {
3669+
// If either is Inf, return {0, 0}
3670+
return ComplexType<T>(static_cast<T>(0), static_cast<T>(0));
3671+
}
3672+
return static_cast<ComplexType<T>>(1.0) / x;
3673+
}
3674+
};
3675+
36103676
template <typename T>
36113677
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
36123678
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

test/legacy_test/test_activation_op.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3952,6 +3952,26 @@ def init_shape(self):
39523952
self.shape = []
39533953

39543954

3955+
class TestReciprocalComplex(unittest.TestCase):
3956+
def test_reciprocal_complex(self):
3957+
for place in get_places():
3958+
x_np = np.array(
3959+
[
3960+
complex(float('inf'), 0),
3961+
complex(0, float('inf')),
3962+
complex(float('inf'), float('inf')),
3963+
complex(0, float('nan')),
3964+
complex(0, 1),
3965+
],
3966+
dtype=np.complex64,
3967+
)
3968+
res_np = np.reciprocal(x_np)
3969+
with paddle.base.dygraph.guard(place):
3970+
x = paddle.to_tensor(x_np, dtype='complex64', place=place)
3971+
res = paddle.reciprocal(x)
3972+
np.testing.assert_allclose(res.numpy(), res_np)
3973+
3974+
39553975
class TestLog(TestActivation):
39563976
def setUp(self):
39573977
self.op_type = "log"

0 commit comments

Comments
 (0)