@@ -307,6 +307,41 @@ struct ReciprocalFunctor : public BaseActivationFunctor<T> {
307307 }
308308};
309309
310+ template <typename T>
311+ struct Reciprocal {
312+ HOSTDEVICE ComplexType<T> operator ()(const ComplexType<T>& val) const {
313+ auto both_inf = [](T real, T imag) {
314+ return (std::isinf (real) && std::isinf (imag));
315+ };
316+
317+ auto either_inf = [](T real, T imag) {
318+ return std::isinf (real) || std::isinf (imag);
319+ };
320+
321+ auto either_nan = [](T real, T imag) {
322+ return std::isnan (real) || std::isnan (imag);
323+ };
324+ if (either_nan (val.real , val.imag ) || both_inf (val.real , val.imag )) {
325+ // If either is Nan or both are infinite, return {nan, nan}
326+ return ComplexType<T>(std::numeric_limits<T>::quiet_NaN (),
327+ std::numeric_limits<T>::quiet_NaN ());
328+ } else if (either_inf (val.real , val.imag )) {
329+ // If either is Inf, return {0, 0}
330+ return ComplexType<T>{static_cast <T>(0 ), static_cast <T>(0 )};
331+ }
332+ return static_cast <ComplexType<T>>(1.0 ) / val;
333+ }
334+ };
335+
336+ template <typename T>
337+ struct ReciprocalFunctor <ComplexType<T>>
338+ : public BaseActivationFunctor<ComplexType<T>> {
339+ template <typename Device, typename X, typename Out>
340+ void operator ()(Device d, X x, Out out) const {
341+ out.device (d) = x.unaryExpr (Reciprocal<T>());
342+ }
343+ };
344+
310345template <typename T>
311346struct ReciprocalGradFunctor : public BaseActivationFunctor <T> {
312347 template <typename Device,
@@ -3607,6 +3642,37 @@ struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
36073642 }
36083643};
36093644
3645+ template <typename T>
3646+ struct CudaReciprocalFunctor <ComplexType<T>>
3647+ : public BaseActivationFunctor<ComplexType<T>> {
3648+ __device__ __forceinline__ ComplexType<T> operator ()(
3649+ const ComplexType<T> x) const {
3650+ auto both_inf = [](T real, T imag) {
3651+ return (::isinf (real) && ::isinf (imag));
3652+ };
3653+
3654+ auto either_inf = [](T real, T imag) {
3655+ return ::isinf (real) || ::isinf (imag);
3656+ };
3657+
3658+ auto either_nan = [](T real, T imag) {
3659+ return ::isnan (real) || ::isnan (imag);
3660+ };
3661+ if (either_nan (x.real , x.imag ) || both_inf (x.real , x.imag )) {
3662+ // If either is Nan or both are infinite, return {nan, nan}
3663+ if constexpr (std::is_same<T, float >::value) {
3664+ return ComplexType<T>(nanf (" " ), nanf (" " ));
3665+ } else if constexpr (std::is_same<T, double >::value) {
3666+ return ComplexType<T>(nan (" " ), nan (" " ));
3667+ }
3668+ } else if (either_inf (x.real , x.imag )) {
3669+ // If either is Inf, return {0, 0}
3670+ return ComplexType<T>(static_cast <T>(0 ), static_cast <T>(0 ));
3671+ }
3672+ return static_cast <ComplexType<T>>(1.0 ) / x;
3673+ }
3674+ };
3675+
36103676template <typename T>
36113677struct CudaReciprocalGradFunctor : public BaseActivationFunctor <T> {
36123678 using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
0 commit comments