Skip to content

Commit 5b995f7

Browse files
authored
[CINN] Update cinn/runtime/cuda/float16.h (#75090)
1 parent 382f7f0 commit 5b995f7

File tree

1 file changed

+114
-33
lines changed

1 file changed

+114
-33
lines changed

paddle/cinn/runtime/cuda/float16.h

Lines changed: 114 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@
4040
#endif // __CUDACC__
4141
#endif // CINN_WITH_CUDA
4242

43+
#ifdef CINN_WITH_HIP
44+
#include <hip/hip_runtime.h>
45+
#if defined(__HIPCC__)
46+
#define __HIP_PLATFORM_AMD__
47+
#include <hip/hip_fp16.h>
48+
#define CINN_HIP_FP16
49+
#endif
50+
#endif
51+
4352
#ifdef __cplusplus
4453
#ifndef _WIN32
4554
#define CINN_ALIGN(x) __attribute__((aligned(x)))
@@ -83,9 +92,9 @@ struct CINN_ALIGN(2) float16 {
8392
~float16() = default;
8493

8594
// Constructors
86-
#ifdef CINN_CUDA_FP16
95+
#if defined(CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
8796
__host__ __device__ inline explicit float16(const half& h) {
88-
#if (CUDA_VERSION >= 9000)
97+
#if defined(CINN_CUDA_FP16) && (CUDA_VERSION >= 9000) || defined(CINN_HIP_FP16)
8998
x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
9099
#else
91100
x = h.x;
@@ -94,7 +103,9 @@ struct CINN_ALIGN(2) float16 {
94103
#endif // CINN_CUDA_FP16 || CINN_HIP_FP16
95104

96105
__host__ __device__ inline explicit float16(float val) {
97-
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
106+
#if defined(CINN_CUDA_FP16) && \
107+
(defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \
108+
defined(CINN_HIP_FP16)
98109
half tmp = __float2half(val);
99110
x = *reinterpret_cast<uint16_t*>(&tmp);
100111

@@ -129,9 +140,9 @@ struct CINN_ALIGN(2) float16 {
129140
: x(float16(static_cast<float>(val)).x) {}
130141

131142
// Assignment operators
132-
#ifdef CINN_CUDA_FP16
143+
#if defined(CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
133144
__host__ __device__ inline float16& operator=(const half& rhs) {
134-
#if CUDA_VERSION >= 9000
145+
#if CUDA_VERSION >= 9000 || defined(CINN_HIP_FP16)
135146
x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
136147
#else
137148
x = rhs.x;
@@ -196,9 +207,9 @@ struct CINN_ALIGN(2) float16 {
196207
}
197208

198209
// Conversion operators
199-
#ifdef CINN_CUDA_FP16
210+
#if defined(CINN_CUDA_FP16) || defined(CINN_HIP_FP16)
200211
__host__ __device__ inline half to_half() const {
201-
#if CUDA_VERSION >= 9000
212+
#if CUDA_VERSION >= 9000 || defined(CINN_HIP_FP16)
202213
__half_raw h;
203214
h.x = x;
204215
return half(h);
@@ -211,7 +222,9 @@ struct CINN_ALIGN(2) float16 {
211222
#endif // CINN_CUDA_FP16 || CINN_HIP_FP16
212223

213224
__host__ __device__ inline operator float() const {
214-
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
225+
#if defined(CINN_CUDA_FP16) && \
226+
(defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \
227+
defined(CINN_HIP_FP16)
215228
half tmp = *reinterpret_cast<const half*>(this);
216229
return __half2float(tmp);
217230

@@ -344,9 +357,9 @@ struct CINN_ALIGN(4) float162 {
344357
// CUDA 9.0 regarding the half data type.
345358
// ROCm provides built-in arithmetic operators for half unless
346359
// __HIP_NO_HALF_OPERATORS__ is defined.
347-
#if defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000
360+
#if (defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000) || defined(CINN_HIP_FP16)
348361
__device__ inline half operator+(const half& a, const half& b) {
349-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
362+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
350363
return __hadd(a, b);
351364
#else
352365
float res = static_cast<float>(float16(a)) + static_cast<float>(float16(b));
@@ -355,7 +368,7 @@ __device__ inline half operator+(const half& a, const half& b) {
355368
}
356369

357370
__device__ inline half operator-(const half& a, const half& b) {
358-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
371+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
359372
return __hsub(a, b);
360373
#else
361374
float res = static_cast<float>(float16(a)) - static_cast<float>(float16(b));
@@ -364,7 +377,7 @@ __device__ inline half operator-(const half& a, const half& b) {
364377
}
365378

366379
__device__ inline half operator*(const half& a, const half& b) {
367-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
380+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
368381
return __hmul(a, b);
369382
#else
370383
float res = static_cast<float>(float16(a)) * static_cast<float>(float16(b));
@@ -373,7 +386,7 @@ __device__ inline half operator*(const half& a, const half& b) {
373386
}
374387

375388
__device__ inline half operator/(const half& a, const half& b) {
376-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
389+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
377390
float num = __half2float(a);
378391
float denom = __half2float(b);
379392
return __float2half(num / denom);
@@ -384,14 +397,15 @@ __device__ inline half operator/(const half& a, const half& b) {
384397
}
385398

386399
__device__ inline half operator-(const half& a) {
387-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
400+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
388401
return __hneg(a);
389402
#else
390403
float res = -static_cast<float>(float16(a));
391404
return float16(res).to_half();
392405
#endif
393406
}
394407

408+
#ifndef CINN_WITH_HIP
395409
__device__ inline half& operator+=(half& a, const half& b) { // NOLINT
396410
a = a + b;
397411
return a;
@@ -411,49 +425,50 @@ __device__ inline half& operator/=(half& a, const half& b) { // NOLINT
411425
a = a / b;
412426
return a;
413427
}
428+
#endif
414429

415430
__device__ inline bool operator==(const half& a, const half& b) {
416-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
431+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
417432
return __heq(a, b);
418433
#else
419434
return static_cast<float>(float16(a)) == static_cast<float>(float16(b));
420435
#endif
421436
}
422437

423438
__device__ inline bool operator!=(const half& a, const half& b) {
424-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
439+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
425440
return __hne(a, b);
426441
#else
427442
return static_cast<float>(float16(a)) != static_cast<float>(float16(b));
428443
#endif
429444
}
430445

431446
__device__ inline bool operator<(const half& a, const half& b) {
432-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
447+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
433448
return __hlt(a, b);
434449
#else
435450
return static_cast<float>(float16(a)) < static_cast<float>(float16(b));
436451
#endif
437452
}
438453

439454
__device__ inline bool operator<=(const half& a, const half& b) {
440-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
455+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
441456
return __hle(a, b);
442457
#else
443458
return static_cast<float>(float16(a)) <= static_cast<float>(float16(b));
444459
#endif
445460
}
446461

447462
__device__ inline bool operator>(const half& a, const half& b) {
448-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
463+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
449464
return __hgt(a, b);
450465
#else
451466
return static_cast<float>(float16(a)) > static_cast<float>(float16(b));
452467
#endif
453468
}
454469

455470
__device__ inline bool operator>=(const half& a, const half& b) {
456-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
471+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(CINN_HIP_FP16)
457472
return __hge(a, b);
458473
#else
459474
return static_cast<float>(float16(a)) >= static_cast<float>(float16(b));
@@ -465,7 +480,9 @@ __device__ inline bool operator>=(const half& a, const half& b) {
465480
// Arithmetic operators for float16 on GPU
466481
__host__ __device__ inline float16 operator+(const float16& a,
467482
const float16& b) {
468-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
483+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
484+
__CUDA_ARCH__ >= 530) || \
485+
defined(CINN_HIP_FP16)
469486
return float16(__hadd(a.to_half(), b.to_half()));
470487
#else
471488
return float16(static_cast<float>(a) + static_cast<float>(b));
@@ -474,7 +491,9 @@ __host__ __device__ inline float16 operator+(const float16& a,
474491

475492
__host__ __device__ inline float16 operator-(const float16& a,
476493
const float16& b) {
477-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
494+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
495+
__CUDA_ARCH__ >= 530) || \
496+
defined(CINN_HIP_FP16)
478497
return float16(__hsub(a.to_half(), b.to_half()));
479498
#else
480499
return float16(static_cast<float>(a) - static_cast<float>(b));
@@ -483,7 +502,9 @@ __host__ __device__ inline float16 operator-(const float16& a,
483502

484503
__host__ __device__ inline float16 operator*(const float16& a,
485504
const float16& b) {
486-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
505+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
506+
__CUDA_ARCH__ >= 530) || \
507+
defined(CINN_HIP_FP16)
487508
return float16(__hmul(a.to_half(), b.to_half()));
488509
#else
489510
return float16(static_cast<float>(a) * static_cast<float>(b));
@@ -492,7 +513,9 @@ __host__ __device__ inline float16 operator*(const float16& a,
492513

493514
__host__ __device__ inline float16 operator/(const float16& a,
494515
const float16& b) {
495-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
516+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
517+
__CUDA_ARCH__ >= 530) || \
518+
defined(CINN_HIP_FP16)
496519
// TODO(kexinzhao): check which cuda version starts to support __hdiv
497520
float num = __half2float(a.to_half());
498521
float denom = __half2float(b.to_half());
@@ -503,7 +526,9 @@ __host__ __device__ inline float16 operator/(const float16& a,
503526
}
504527

505528
__host__ __device__ inline float16 operator-(const float16& a) {
506-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
529+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
530+
__CUDA_ARCH__ >= 530) || \
531+
defined(CINN_HIP_FP16)
507532
return float16(__hneg(a.to_half()));
508533
#else
509534
float16 res;
@@ -537,47 +562,59 @@ __host__ __device__ inline float16& operator/=(float16& a, // NOLINT
537562
}
538563

539564
__host__ __device__ inline bool operator==(const float16& a, const float16& b) {
540-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
565+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
566+
__CUDA_ARCH__ >= 530) || \
567+
defined(CINN_HIP_FP16)
541568
return __heq(a.to_half(), b.to_half());
542569
#else
543570
return static_cast<float>(a) == static_cast<float>(b);
544571
#endif
545572
}
546573

547574
__host__ __device__ inline bool operator!=(const float16& a, const float16& b) {
548-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
575+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
576+
__CUDA_ARCH__ >= 530) || \
577+
defined(CINN_HIP_FP16)
549578
return __hne(a.to_half(), b.to_half());
550579
#else
551580
return static_cast<float>(a) != static_cast<float>(b);
552581
#endif
553582
}
554583

555584
__host__ __device__ inline bool operator<(const float16& a, const float16& b) {
556-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
585+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
586+
__CUDA_ARCH__ >= 530) || \
587+
defined(CINN_HIP_FP16)
557588
return __hlt(a.to_half(), b.to_half());
558589
#else
559590
return static_cast<float>(a) < static_cast<float>(b);
560591
#endif
561592
}
562593

563594
__host__ __device__ inline bool operator<=(const float16& a, const float16& b) {
564-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
595+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
596+
__CUDA_ARCH__ >= 530) || \
597+
defined(CINN_HIP_FP16)
565598
return __hle(a.to_half(), b.to_half());
566599
#else
567600
return static_cast<float>(a) <= static_cast<float>(b);
568601
#endif
569602
}
570603

571604
__host__ __device__ inline bool operator>(const float16& a, const float16& b) {
572-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
605+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
606+
__CUDA_ARCH__ >= 530) || \
607+
defined(CINN_HIP_FP16)
573608
return __hgt(a.to_half(), b.to_half());
574609
#else
575610
return static_cast<float>(a) > static_cast<float>(b);
576611
#endif
577612
}
578613

579614
__host__ __device__ inline bool operator>=(const float16& a, const float16& b) {
580-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
615+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
616+
__CUDA_ARCH__ >= 530) || \
617+
defined(CINN_HIP_FP16)
581618
return __hge(a.to_half(), b.to_half());
582619
#else
583620
return static_cast<float>(a) >= static_cast<float>(b);
@@ -592,7 +629,9 @@ __host__ __device__ inline float16 raw_uint16_to_float16(uint16_t a) {
592629
}
593630

594631
__host__ __device__ inline bool(isnan)(const float16& a) {
595-
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
632+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
633+
__CUDA_ARCH__ >= 530) || \
634+
defined(CINN_HIP_FP16)
596635
return __hisnan(a.to_half());
597636
#else
598637
return (a.x & 0x7fff) > 0x7c00;
@@ -608,7 +647,9 @@ __host__ __device__ inline bool(isfinite)(const float16& a) {
608647
}
609648

610649
__host__ __device__ inline float16(abs)(const float16& a) {
611-
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
650+
#if (defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && \
651+
__CUDA_ARCH__ >= 530) || \
652+
defined(CINN_HIP_FP16)
612653
return static_cast<float16>(__habs(a.to_half()));
613654
#else
614655
return static_cast<float16>(fabsf(static_cast<float>(a)));
@@ -670,4 +711,44 @@ __host__ __device__ inline cinn::common::float16 min(
670711
}
671712
#endif // __cplusplus && CINN_CUDA_FP16
672713

714+
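// Note: HIP does not support half-float shuffles, so the float16 overloads below convert the value to float, shuffle it, and convert the result back.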
715+
#if defined(CINN_HIP_FP16)
716+
__device__ inline cinn::common::float16 __shfl(cinn::common::float16 var,
717+
int srcLane,
718+
int width = warpSize) {
719+
return cinn::common::float16(__shfl(static_cast<float>(var), srcLane, width));
720+
}
721+
722+
__device__ inline cinn::common::float16 __shfl_up(cinn::common::float16 var,
723+
unsigned int delta,
724+
int width = warpSize) {
725+
return cinn::common::float16(
726+
__shfl_up(static_cast<float>(var), delta, width));
727+
}
728+
729+
__device__ inline cinn::common::float16 __shfl_down(cinn::common::float16 var,
730+
unsigned int delta,
731+
int width = warpSize) {
732+
return cinn::common::float16(
733+
__shfl_down(static_cast<float>(var), delta, width));
734+
}
735+
736+
__device__ inline cinn::common::float16 __shfl_xor(cinn::common::float16 var,
737+
int laneMask,
738+
int width = warpSize) {
739+
return cinn::common::float16(
740+
__shfl_xor(static_cast<float>(var), laneMask, width));
741+
}
742+
743+
__host__ __device__ inline cinn::common::float16 max(
744+
const cinn::common::float16& a, const cinn::common::float16& b) {
745+
return a > b ? a : b;
746+
}
747+
748+
__host__ __device__ inline cinn::common::float16 min(
749+
const cinn::common::float16& a, const cinn::common::float16& b) {
750+
return a < b ? a : b;
751+
}
752+
#endif // CINN_HIP_FP16
753+
673754
#endif // CINN_COMMON_FLOAT16_H

0 commit comments

Comments
 (0)