@@ -483,8 +483,123 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
483483
484484#endif // PADDLE_CUDA_FP16
485485
486- // Arithmetic operators on ARMv8.2-A CPU
487- #if defined(PADDLE_WITH_NATIVE_FP16)
486+ // Arithmetic operators for float16 on GPU
487+ #if defined(PADDLE_CUDA_FP16)
488+ HOSTDEVICE inline float16 operator +(const float16& a, const float16& b) {
489+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
490+ return float16 (__hadd (half (a), half (b)));
491+ #else
492+ return float16 (float (a) + float (b));
493+ #endif
494+ }
495+
496+ HOSTDEVICE inline float16 operator -(const float16& a, const float16& b) {
497+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
498+ return float16 (__hsub (half (a), half (b)));
499+ #else
500+ return float16 (float (a) - float (b));
501+ #endif
502+ }
503+
504+ HOSTDEVICE inline float16 operator *(const float16& a, const float16& b) {
505+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
506+ return float16 (__hmul (half (a), half (b)));
507+ #else
508+ return float16 (float (a) * float (b));
509+ #endif
510+ }
511+
512+ HOSTDEVICE inline float16 operator /(const float16& a, const float16& b) {
513+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
514+ // TODO(kexinzhao): check which cuda version starts to support __hdiv
515+ float num = __half2float (half (a));
516+ float denom = __half2float (half (b));
517+ return float16 (num / denom);
518+ #else
519+ return float16 (float (a) / float (b));
520+ #endif
521+ }
522+
523+ HOSTDEVICE inline float16 operator -(const float16& a) {
524+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
525+ return float16 (__hneg (half (a)));
526+ #else
527+ float16 res;
528+ res.x = a.x ^ 0x8000 ;
529+ return res;
530+ #endif
531+ }
532+
533+ HOSTDEVICE inline float16& operator +=(float16& a, const float16& b) {
534+ a = a + b;
535+ return a;
536+ }
537+
538+ HOSTDEVICE inline float16& operator -=(float16& a, const float16& b) {
539+ a = a - b;
540+ return a;
541+ }
542+
543+ HOSTDEVICE inline float16& operator *=(float16& a, const float16& b) {
544+ a = a * b;
545+ return a;
546+ }
547+
548+ HOSTDEVICE inline float16& operator /=(float16& a, const float16& b) {
549+ a = a / b;
550+ return a;
551+ }
552+
553+ HOSTDEVICE inline bool operator ==(const float16& a, const float16& b) {
554+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
555+ return __heq (half (a), half (b));
556+ #else
557+ return float (a) == float (b);
558+ #endif
559+ }
560+
561+ HOSTDEVICE inline bool operator !=(const float16& a, const float16& b) {
562+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
563+ return __hne (half (a), half (b));
564+ #else
565+ return float (a) != float (b);
566+ #endif
567+ }
568+
569+ HOSTDEVICE inline bool operator <(const float16& a, const float16& b) {
570+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
571+ return __hlt (half (a), half (b));
572+ #else
573+ return float (a) < float (b);
574+ #endif
575+ }
576+
577+ HOSTDEVICE inline bool operator <=(const float16& a, const float16& b) {
578+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
579+ return __hle (half (a), half (b));
580+ #else
581+ return float (a) <= float (b);
582+ #endif
583+ }
584+
585+ HOSTDEVICE inline bool operator >(const float16& a, const float16& b) {
586+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
587+ return __hgt (half (a), half (b));
588+ #else
589+ return float (a) > float (b);
590+ #endif
591+ }
592+
593+ HOSTDEVICE inline bool operator >=(const float16& a, const float16& b) {
594+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
595+ return __hge (half (a), half (b));
596+ #else
597+ return float (a) >= float (b);
598+ #endif
599+ }
600+
601+ // Arithmetic operators for float16 on ARMv8.2-A CPU
602+ #elif defined(PADDLE_WITH_NATIVE_FP16)
488603HOST inline float16 operator +(const float16& a, const float16& b) {
489604 float16 res;
490605 asm volatile (
@@ -668,71 +783,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
668783 return (res & 0xffff ) != 0 ;
669784}
670785
671- // Arithmetic operators, software emulated on other CPU
786+ // Arithmetic operators for float16, software emulated on other CPUs
672787#else
673- HOSTDEVICE inline float16 operator +(const float16& a, const float16& b) {
788+ HOST inline float16 operator +(const float16& a, const float16& b) {
674789 return float16 (float (a) + float (b));
675790}
676791
677- HOSTDEVICE inline float16 operator -(const float16& a, const float16& b) {
792+ HOST inline float16 operator -(const float16& a, const float16& b) {
678793 return float16 (float (a) - float (b));
679794}
680795
681- HOSTDEVICE inline float16 operator *(const float16& a, const float16& b) {
796+ HOST inline float16 operator *(const float16& a, const float16& b) {
682797 return float16 (float (a) * float (b));
683798}
684799
685- HOSTDEVICE inline float16 operator /(const float16& a, const float16& b) {
800+ HOST inline float16 operator /(const float16& a, const float16& b) {
686801 return float16 (float (a) / float (b));
687802}
688803
689- HOSTDEVICE inline float16 operator -(const float16& a) {
804+ HOST inline float16 operator -(const float16& a) {
690805 float16 res;
691806 res.x = a.x ^ 0x8000 ;
692807 return res;
693808}
694809
695- HOSTDEVICE inline float16& operator +=(float16& a, const float16& b) {
810+ HOST inline float16& operator +=(float16& a, const float16& b) {
696811 a = float16 (float (a) + float (b));
697812 return a;
698813}
699814
700- HOSTDEVICE inline float16& operator -=(float16& a, const float16& b) {
815+ HOST inline float16& operator -=(float16& a, const float16& b) {
701816 a = float16 (float (a) - float (b));
702817 return a;
703818}
704819
705- HOSTDEVICE inline float16& operator *=(float16& a, const float16& b) {
820+ HOST inline float16& operator *=(float16& a, const float16& b) {
706821 a = float16 (float (a) * float (b));
707822 return a;
708823}
709824
710- HOSTDEVICE inline float16& operator /=(float16& a, const float16& b) {
825+ HOST inline float16& operator /=(float16& a, const float16& b) {
711826 a = float16 (float (a) / float (b));
712827 return a;
713828}
714829
715- HOSTDEVICE inline bool operator ==(const float16& a, const float16& b) {
830+ HOST inline bool operator ==(const float16& a, const float16& b) {
716831 return float (a) == float (b);
717832}
718833
719- HOSTDEVICE inline bool operator !=(const float16& a, const float16& b) {
834+ HOST inline bool operator !=(const float16& a, const float16& b) {
720835 return float (a) != float (b);
721836}
722837
723- HOSTDEVICE inline bool operator <(const float16& a, const float16& b) {
838+ HOST inline bool operator <(const float16& a, const float16& b) {
724839 return float (a) < float (b);
725840}
726841
727- HOSTDEVICE inline bool operator <=(const float16& a, const float16& b) {
842+ HOST inline bool operator <=(const float16& a, const float16& b) {
728843 return float (a) <= float (b);
729844}
730845
731- HOSTDEVICE inline bool operator >(const float16& a, const float16& b) {
846+ HOST inline bool operator >(const float16& a, const float16& b) {
732847 return float (a) > float (b);
733848}
734849
735- HOSTDEVICE inline bool operator >=(const float16& a, const float16& b) {
850+ HOST inline bool operator >=(const float16& a, const float16& b) {
736851 return float (a) >= float (b);
737852}
738853#endif
0 commit comments