4040#endif // __CUDACC__
4141#endif // CINN_WITH_CUDA
4242
// HIP support: pull in the HIP runtime and, when compiling with hipcc
// (__HIPCC__ defined), the HIP half-precision header. CINN_HIP_FP16 is the
// feature macro the rest of this header tests to enable `half` interop.
#ifdef CINN_WITH_HIP
#include <hip/hip_runtime.h>
#if defined(__HIPCC__)
// Guarded so that a definition already supplied by the toolchain or the build
// flags is not redefined (an unguarded redefinition triggers a
// macro-redefined diagnostic).
#ifndef __HIP_PLATFORM_AMD__
#define __HIP_PLATFORM_AMD__
#endif
#include <hip/hip_fp16.h>
#define CINN_HIP_FP16
#endif  // __HIPCC__
#endif  // CINN_WITH_HIP
51+
4352#ifdef __cplusplus
4453#ifndef _WIN32
// Function-like macro: there must be no whitespace between the macro name and
// the opening parenthesis, otherwise it becomes an object-like macro and
// `CINN_ALIGN(2)` would not expand to an alignment attribute.
#define CINN_ALIGN(x) __attribute__((aligned(x)))
@@ -83,9 +92,9 @@ struct CINN_ALIGN(2) float16 {
8392 ~float16 () = default ;
8493
8594// Constructors
86- #ifdef CINN_CUDA_FP16
95+ #if defined( CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
8796 __host__ __device__ inline explicit float16 (const half& h) {
88- #if ( CUDA_VERSION >= 9000)
97+ #if defined(CINN_CUDA_FP16) && ( CUDA_VERSION >= 9000) || defined(CINN_HIP_FP16 )
8998 x = reinterpret_cast <__half_raw*>(const_cast <half*>(&h))->x ;
9099#else
91100 x = h.x ;
@@ -94,7 +103,9 @@ struct CINN_ALIGN(2) float16 {
94103#endif // CINN_CUDA_FP16
95104
96105 __host__ __device__ inline explicit float16 (float val) {
97- #if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
106+ #if defined(CINN_CUDA_FP16) && \
107+ (defined (__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 ) || \
108+ defined (CINN_HIP_FP16)
98109 half tmp = __float2half (val);
99110 x = *reinterpret_cast <uint16_t *>(&tmp);
100111
@@ -129,9 +140,9 @@ struct CINN_ALIGN(2) float16 {
129140 : x (float16 (static_cast <float >(val)).x ) {}
130141
131142// Assignment operators
132- #ifdef CINN_CUDA_FP16
143+ #if defined( CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
133144 __host__ __device__ inline float16& operator =(const half& rhs) {
134- #if CUDA_VERSION >= 9000
145+ #if CUDA_VERSION >= 9000 || defined(CINN_HIP_FP16)
135146 x = reinterpret_cast <__half_raw*>(const_cast <half*>(&rhs))->x ;
136147#else
137148 x = rhs.x ;
@@ -196,9 +207,9 @@ struct CINN_ALIGN(2) float16 {
196207 }
197208
198209// Conversion operators
199- #ifdef CINN_CUDA_FP16
210+ #if defined( CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
200211 __host__ __device__ inline half to_half () const {
201- #if CUDA_VERSION >= 9000
212+ #if CUDA_VERSION >= 9000 || defined(CINN_HIP_FP16)
202213 __half_raw h;
203214 h.x = x;
204215 return half (h);
@@ -211,7 +222,9 @@ struct CINN_ALIGN(2) float16 {
211222#endif // CINN_CUDA_FP16
212223
213224 __host__ __device__ inline operator float () const {
214- #if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
225+ #if defined(CINN_CUDA_FP16) && \
226+ (defined (__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 ) || \
227+ defined (CINN_HIP_FP16)
215228 half tmp = *reinterpret_cast <const half*>(this );
216229 return __half2float (tmp);
217230
@@ -344,9 +357,9 @@ struct CINN_ALIGN(4) float162 {
344357// CUDA 9.0 regarding the half data type.
// ROCm provides built-in arithmetic operators for half unless
// __HIP_NO_HALF_OPERATORS__ is defined.
347- #if defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000
360+ #if ( defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000) || defined(CINN_HIP_FP16)
348361__device__ inline half operator +(const half& a, const half& b) {
349- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
362+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
350363 return __hadd (a, b);
351364#else
352365 float res = static_cast <float >(float16 (a)) + static_cast <float >(float16 (b));
@@ -355,7 +368,7 @@ __device__ inline half operator+(const half& a, const half& b) {
355368}
356369
357370__device__ inline half operator -(const half& a, const half& b) {
358- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
371+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
359372 return __hsub (a, b);
360373#else
361374 float res = static_cast <float >(float16 (a)) - static_cast <float >(float16 (b));
@@ -364,7 +377,7 @@ __device__ inline half operator-(const half& a, const half& b) {
364377}
365378
366379__device__ inline half operator *(const half& a, const half& b) {
367- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
380+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
368381 return __hmul (a, b);
369382#else
370383 float res = static_cast <float >(float16 (a)) * static_cast <float >(float16 (b));
@@ -373,7 +386,7 @@ __device__ inline half operator*(const half& a, const half& b) {
373386}
374387
375388__device__ inline half operator /(const half& a, const half& b) {
376- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
389+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
377390 float num = __half2float (a);
378391 float denom = __half2float (b);
379392 return __float2half (num / denom);
@@ -384,14 +397,15 @@ __device__ inline half operator/(const half& a, const half& b) {
384397}
385398
386399__device__ inline half operator -(const half& a) {
387- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
400+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
388401 return __hneg (a);
389402#else
390403 float res = -static_cast <float >(float16 (a));
391404 return float16 (res).to_half ();
392405#endif
393406}
394407
408+ #ifndef CINN_WITH_HIP
395409__device__ inline half& operator +=(half& a, const half& b) { // NOLINT
396410 a = a + b;
397411 return a;
@@ -411,49 +425,50 @@ __device__ inline half& operator/=(half& a, const half& b) { // NOLINT
411425 a = a / b;
412426 return a;
413427}
428+ #endif
414429
415430__device__ inline bool operator ==(const half& a, const half& b) {
416- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
431+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
417432 return __heq (a, b);
418433#else
419434 return static_cast <float >(float16 (a)) == static_cast <float >(float16 (b));
420435#endif
421436}
422437
423438__device__ inline bool operator !=(const half& a, const half& b) {
424- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
439+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
425440 return __hne (a, b);
426441#else
427442 return static_cast <float >(float16 (a)) != static_cast <float >(float16 (b));
428443#endif
429444}
430445
431446__device__ inline bool operator <(const half& a, const half& b) {
432- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
447+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
433448 return __hlt (a, b);
434449#else
435450 return static_cast <float >(float16 (a)) < static_cast <float >(float16 (b));
436451#endif
437452}
438453
439454__device__ inline bool operator <=(const half& a, const half& b) {
440- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
455+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
441456 return __hle (a, b);
442457#else
443458 return static_cast <float >(float16 (a)) <= static_cast <float >(float16 (b));
444459#endif
445460}
446461
447462__device__ inline bool operator >(const half& a, const half& b) {
448- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
463+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
449464 return __hgt (a, b);
450465#else
451466 return static_cast <float >(float16 (a)) > static_cast <float >(float16 (b));
452467#endif
453468}
454469
455470__device__ inline bool operator >=(const half& a, const half& b) {
456- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
471+ #if ( defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
457472 return __hge (a, b);
458473#else
459474 return static_cast <float >(float16 (a)) >= static_cast <float >(float16 (b));
@@ -465,7 +480,9 @@ __device__ inline bool operator>=(const half& a, const half& b) {
465480// Arithmetic operators for float16 on GPU
466481__host__ __device__ inline float16 operator +(const float16& a,
467482 const float16& b) {
468- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
483+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
484+ __CUDA_ARCH__ >= 530 ) || \
485+ defined (CINN_HIP_FP16)
469486 return float16 (__hadd (a.to_half (), b.to_half ()));
470487#else
471488 return float16 (static_cast <float >(a) + static_cast <float >(b));
@@ -474,7 +491,9 @@ __host__ __device__ inline float16 operator+(const float16& a,
474491
475492__host__ __device__ inline float16 operator -(const float16& a,
476493 const float16& b) {
477- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
494+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
495+ __CUDA_ARCH__ >= 530 ) || \
496+ defined (CINN_HIP_FP16)
478497 return float16 (__hsub (a.to_half (), b.to_half ()));
479498#else
480499 return float16 (static_cast <float >(a) - static_cast <float >(b));
@@ -483,7 +502,9 @@ __host__ __device__ inline float16 operator-(const float16& a,
483502
484503__host__ __device__ inline float16 operator *(const float16& a,
485504 const float16& b) {
486- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
505+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
506+ __CUDA_ARCH__ >= 530 ) || \
507+ defined (CINN_HIP_FP16)
487508 return float16 (__hmul (a.to_half (), b.to_half ()));
488509#else
489510 return float16 (static_cast <float >(a) * static_cast <float >(b));
@@ -492,7 +513,9 @@ __host__ __device__ inline float16 operator*(const float16& a,
492513
493514__host__ __device__ inline float16 operator /(const float16& a,
494515 const float16& b) {
495- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
516+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
517+ __CUDA_ARCH__ >= 530 ) || \
518+ defined (CINN_HIP_FP16)
496519 // TODO(kexinzhao): check which cuda version starts to support __hdiv
497520 float num = __half2float (a.to_half ());
498521 float denom = __half2float (b.to_half ());
@@ -503,7 +526,9 @@ __host__ __device__ inline float16 operator/(const float16& a,
503526}
504527
505528__host__ __device__ inline float16 operator -(const float16& a) {
506- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
529+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
530+ __CUDA_ARCH__ >= 530 ) || \
531+ defined (CINN_HIP_FP16)
507532 return float16 (__hneg (a.to_half ()));
508533#else
509534 float16 res;
@@ -537,47 +562,59 @@ __host__ __device__ inline float16& operator/=(float16& a, // NOLINT
537562}
538563
539564__host__ __device__ inline bool operator ==(const float16& a, const float16& b) {
540- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
565+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
566+ __CUDA_ARCH__ >= 530 ) || \
567+ defined (CINN_HIP_FP16)
541568 return __heq (a.to_half (), b.to_half ());
542569#else
543570 return static_cast <float >(a) == static_cast <float >(b);
544571#endif
545572}
546573
547574__host__ __device__ inline bool operator !=(const float16& a, const float16& b) {
548- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
575+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
576+ __CUDA_ARCH__ >= 530 ) || \
577+ defined (CINN_HIP_FP16)
549578 return __hne (a.to_half (), b.to_half ());
550579#else
551580 return static_cast <float >(a) != static_cast <float >(b);
552581#endif
553582}
554583
555584__host__ __device__ inline bool operator <(const float16& a, const float16& b) {
556- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
585+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
586+ __CUDA_ARCH__ >= 530 ) || \
587+ defined (CINN_HIP_FP16)
557588 return __hlt (a.to_half (), b.to_half ());
558589#else
559590 return static_cast <float >(a) < static_cast <float >(b);
560591#endif
561592}
562593
563594__host__ __device__ inline bool operator <=(const float16& a, const float16& b) {
564- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
595+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
596+ __CUDA_ARCH__ >= 530 ) || \
597+ defined (CINN_HIP_FP16)
565598 return __hle (a.to_half (), b.to_half ());
566599#else
567600 return static_cast <float >(a) <= static_cast <float >(b);
568601#endif
569602}
570603
571604__host__ __device__ inline bool operator >(const float16& a, const float16& b) {
572- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
605+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
606+ __CUDA_ARCH__ >= 530 ) || \
607+ defined (CINN_HIP_FP16)
573608 return __hgt (a.to_half (), b.to_half ());
574609#else
575610 return static_cast <float >(a) > static_cast <float >(b);
576611#endif
577612}
578613
579614__host__ __device__ inline bool operator >=(const float16& a, const float16& b) {
580- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
615+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
616+ __CUDA_ARCH__ >= 530 ) || \
617+ defined (CINN_HIP_FP16)
581618 return __hge (a.to_half (), b.to_half ());
582619#else
583620 return static_cast <float >(a) >= static_cast <float >(b);
@@ -592,7 +629,9 @@ __host__ __device__ inline float16 raw_uint16_to_float16(uint16_t a) {
592629}
593630
594631__host__ __device__ inline bool (isnan)(const float16& a) {
595- #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
632+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
633+ __CUDA_ARCH__ >= 530 ) || \
634+ defined (CINN_HIP_FP16)
596635 return __hisnan (a.to_half ());
597636#else
598637 return (a.x & 0x7fff ) > 0x7c00 ;
@@ -608,7 +647,9 @@ __host__ __device__ inline bool(isfinite)(const float16& a) {
608647}
609648
610649__host__ __device__ inline float16 (abs)(const float16& a) {
611- #if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
650+ #if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
651+ __CUDA_ARCH__ >= 530 ) || \
652+ defined (CINN_HIP_FP16)
612653 return static_cast <float16>(__habs (a.to_half ()));
613654#else
614655 return static_cast <float16>(fabsf (static_cast <float >(a)));
@@ -670,4 +711,44 @@ __host__ __device__ inline cinn::common::float16 min(
670711}
671712#endif // __cplusplus && CINN_CUDA_FP16
672713
// Note: HIP does not support half-float shuffles. The float16 overloads below
// therefore convert the value to float, call the float shuffle with the same
// lane arguments, and convert the result back to float16.
#if defined(CINN_HIP_FP16)
// float16 overload of __shfl: forwards to the float __shfl with the given
// srcLane and width.
__device__ inline cinn::common::float16 __shfl(cinn::common::float16 var,
                                               int srcLane,
                                               int width = warpSize) {
  return cinn::common::float16(__shfl(static_cast<float>(var), srcLane, width));
}

// float16 overload of __shfl_up: forwards to the float __shfl_up with the
// given delta and width.
__device__ inline cinn::common::float16 __shfl_up(cinn::common::float16 var,
                                                  unsigned int delta,
                                                  int width = warpSize) {
  return cinn::common::float16(
      __shfl_up(static_cast<float>(var), delta, width));
}

// float16 overload of __shfl_down: forwards to the float __shfl_down with the
// given delta and width.
__device__ inline cinn::common::float16 __shfl_down(cinn::common::float16 var,
                                                    unsigned int delta,
                                                    int width = warpSize) {
  return cinn::common::float16(
      __shfl_down(static_cast<float>(var), delta, width));
}

// float16 overload of __shfl_xor: forwards to the float __shfl_xor with the
// given laneMask and width.
__device__ inline cinn::common::float16 __shfl_xor(cinn::common::float16 var,
                                                   int laneMask,
                                                   int width = warpSize) {
  return cinn::common::float16(
      __shfl_xor(static_cast<float>(var), laneMask, width));
}

// Returns `a` when a > b, otherwise `b` (so `b` is returned when the operands
// compare equal or the comparison is false).
__host__ __device__ inline cinn::common::float16 max(
    const cinn::common::float16& a, const cinn::common::float16& b) {
  return a > b ? a : b;
}

// Returns `a` when a < b, otherwise `b` (so `b` is returned when the operands
// compare equal or the comparison is false).
__host__ __device__ inline cinn::common::float16 min(
    const cinn::common::float16& a, const cinn::common::float16& b) {
  return a < b ? a : b;
}
#endif  // CINN_HIP_FP16
753+
673754#endif // CINN_COMMON_FLOAT16_H
0 commit comments