@@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
 
     constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half,  V_rows_per_thread>();
 #else
     constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
     const int ic0 = blockIdx.x*ncols; // Index of the Q/QKV column to work on.
 
@@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
 
     constexpr int ne_KQ      = ncols*D;
     constexpr int ne_combine = nwarps*V_cols_per_iter*D;
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     half2  VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
     __shared__ half  KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
 #else
     float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
     __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
     float KQ_max[ncols];
     float KQ_sum[ncols];
@@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
     }
 
     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     half2  Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
 #else
     float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
     int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     if constexpr (Q_q8_1) {
@@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
 
         __syncthreads();
     } else {
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
         const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
                 Q_reg[j][k].y *= scale;
             }
         }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
     }
 
     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
@@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
             KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
             KQ[j*nthreads + tid] = KQ_reg[j];
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
             for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
                 VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                 VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
             }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
 
 #ifndef GGML_USE_HIP
@@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
         for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
             const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             half2 KQ_k[ncols];
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
@@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
                     }
                 }
             }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
 
@@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
 
             KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
             for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
                 VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                 VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
             }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
 
@@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
         const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
         KQ_max[j_VKQ] = kqmax_new;
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
         half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
             + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
 
@@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
             ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
             ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
         }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
         KQ_sum[j_VKQ] *= kqmax_scale;
         KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
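The hunks above all make the same change: the compile-time switch that picks half2 versus float2 for the Q registers, the VKQ accumulators, and the shared KQ buffer is now keyed on V_DOT2_F32_F16_AVAILABLE (presumably named after a v_dot2_f32_f16-style fused dot instruction) instead of FAST_FP16_AVAILABLE. The sketch below is not code from this patch; it only illustrates that guard pattern, and the acc_t alias, acc_scale() helper, and rescale kernel are invented names for illustration.

// sketch.cu -- illustrative only; every identifier except the macro is hypothetical.
#include <cuda_fp16.h>

#ifdef V_DOT2_F32_F16_AVAILABLE
typedef half2  acc_t;                        // packed f16 accumulator, two values per 32-bit register
static __device__ __forceinline__ void acc_scale(acc_t & acc, const float s) {
    acc = __hmul2(acc, make_half2(s, s));    // one packed f16 multiply, mirroring the half2 path above
}
#else
typedef float2 acc_t;                        // f32 fallback accumulator
static __device__ __forceinline__ void acc_scale(acc_t & acc, const float s) {
    acc.x *= s;                              // scale the two lanes separately, mirroring the .x/.y path above
    acc.y *= s;
}
#endif // V_DOT2_F32_F16_AVAILABLE

// Toy kernel: rescale an array of accumulators, analogous to the KQ_max_scale updates in the hunks above.
static __global__ void rescale(acc_t * vkq, const float scale, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        acc_scale(vkq[i], scale);
    }
}

Either branch keeps the surrounding logic identical; only the accumulator type and the per-element scaling code differ, which is why the patch can swap the guard macro without touching anything outside the #ifdef/#endif lines.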