Commit 6aa7951

Merge branch 'cuda-fa-vec-fix-overflow-2' into concedo_experimental
2 parents eda4a31 + b13fcf8

3 files changed, +24 -20 lines

ggml/src/ggml-cuda/common.cuh

Lines changed: 6 additions & 2 deletions
@@ -563,8 +563,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
     acc += v.y*u.y;
 }
 
-static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
 #if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
+#define V_DOT2_F32_F16_AVAILABLE
+#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
+#ifdef V_DOT2_F32_F16_AVAILABLE
     asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
 #else
 #ifdef FAST_FP16_AVAILABLE
@@ -576,7 +580,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
     acc += tmpv.x * tmpu.x;
     acc += tmpv.y * tmpu.y;
 #endif // FAST_FP16_AVAILABLE
-#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
+#endif // V_DOT2_F32_F16_AVAILABLE
 }
 
 static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
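
Note on the hunk above: the new V_DOT2_F32_F16_AVAILABLE macro names the hardware capability that was previously implied by the HIP architecture list. v_dot2_f32_f16 is a packed fp16 dot product whose running sum is kept in fp32. As a hedged reference sketch (my illustration, mirroring the fp32 fallback in this same hunk; the function name is hypothetical), its semantics are:

static __device__ __forceinline__ void mad_half2_into_f32(float & acc, const half2 v, const half2 u) {
    // Widen both packed fp16 operands to fp32, then accumulate in fp32.
    // This is the result v_dot2_f32_f16 produces in a single instruction
    // on the listed AMD GPUs, with no fp16 intermediate that could overflow.
    const float2 vf = __half22float2(v);
    const float2 uf = __half22float2(u);
    acc += vf.x*uf.x;
    acc += vf.y*uf.y;
}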

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 2 additions & 2 deletions
@@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
         ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
 #pragma unroll
         for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
 #else
             ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
-#endif // FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
 
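
Either way, sum here stays a float (it feeds the float-accumulator ggml_cuda_mad overload patched in common.cuh); the macro only decides whether the K and Q pairs reach it as packed half2, lowered to the fp32-accumulating hardware dot, or are widened to float2 first. A minimal sketch of the compile-time dispatch pattern this commit applies across the flash-attention files (the typedef name is mine, not from the patch):

#ifdef V_DOT2_F32_F16_AVAILABLE
typedef half2  kq_pair_t; // packed fp16 operands; the dot accumulates in fp32 in hardware
#else
typedef float2 kq_pair_t; // fp32 operands; ggml_cuda_mad falls back to two fp32 FMAs
#endif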

ggml/src/ggml-cuda/fattn-vec.cuh

Lines changed: 16 additions & 16 deletions
@@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
 
     constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
 #else
     constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
@@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
 
     constexpr int ne_KQ = ncols*D;
     constexpr int ne_combine = nwarps*V_cols_per_iter*D;
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
     __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
 #else
     float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
     __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
     float KQ_max[ncols];
     float KQ_sum[ncols];
@@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
     }
 
     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
 #else
     float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
     int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     if constexpr (Q_q8_1) {
@@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
 
         __syncthreads();
     } else {
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
         const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
                 Q_reg[j][k].y *= scale;
             }
         }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
     }
 
     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
@@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
             KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
             KQ[j*nthreads + tid] = KQ_reg[j];
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
             for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
                 VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                 VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
             }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
 
 #ifndef GGML_USE_HIP
@@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
         for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
             const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             half2 KQ_k[ncols];
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
@@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
                     }
                 }
             }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
 
@@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
 
             KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
             for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
                 VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                 VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
             }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
 
@@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
         const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
         KQ_max[j_VKQ] = kqmax_new;
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
         half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
             + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
 
@@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
             ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
             ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
         }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
         KQ_sum[j_VKQ] *= kqmax_scale;
         KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
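
The branch name ('cuda-fa-vec-fix-overflow-2') points at the motivation: under the old FAST_FP16_AVAILABLE gate, the KQ scratch buffer and VKQ accumulators were kept in fp16 even on hardware without an fp32-accumulating dot, and fp16 tops out at 65504. A standalone repro sketch of that failure class (my illustration, not code from the commit):

#include <cstdio>
#include <cuda_fp16.h>

// The same running sum is exact in fp32 but saturates to +inf once it
// passes the fp16 maximum of 65504: the overflow class gated off here.
__global__ void fp16_overflow_demo(float * out) {
    float acc_f32 = 0.0f;
    float acc_f16 = 0.0f; // round-tripped through __half after every add
    for (int i = 0; i < 32; ++i) {
        acc_f32 += 4096.0f;
        acc_f16 = __half2float(__float2half(acc_f16 + 4096.0f));
    }
    out[0] = acc_f32; // 131072
    out[1] = acc_f16; // inf
}

int main() {
    float * d_out;
    cudaMalloc(&d_out, 2*sizeof(float));
    fp16_overflow_demo<<<1, 1>>>(d_out);
    float h_out[2];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("fp32 sum = %g, fp16 sum = %g\n", h_out[0], h_out[1]); // expected: 131072, inf
    cudaFree(d_out);
    return 0;
}

With the macro swap, the fp16 register and shared-memory layout in flash_attn_ext_vec is only compiled where the KQ dot itself accumulates in fp32.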
