Skip to content

Commit 87968de

Browse files
fix KQ FP32 precision for parallel_blocks > 1
1 parent 2f538b9 commit 87968de

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

ggml-cuda/fattn.cu

+24-24
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ static __global__ void flash_attn_vec_ext_f16(
1515
const char * __restrict__ K,
1616
const char * __restrict__ V,
1717
const char * __restrict__ mask,
18-
float * __restrict__ dst,
19-
half2 * __restrict__ dst_meta,
18+
float * __restrict__ dst,
19+
float2 * __restrict__ dst_meta,
2020
const float scale,
2121
const int ne00,
2222
const int ne01,
@@ -180,7 +180,7 @@ static __global__ void flash_attn_vec_ext_f16(
180180
if (parallel_blocks == 1 || tid != 0) {
181181
return;
182182
}
183-
dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum);
183+
dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
184184
#else
185185
NO_DEVICE_CODE;
186186
#endif // FP16_AVAILABLE
@@ -194,8 +194,8 @@ static __global__ void flash_attn_ext_f16(
194194
const char * __restrict__ K,
195195
const char * __restrict__ V,
196196
const char * __restrict__ mask,
197-
float * __restrict__ dst,
198-
half2 * __restrict__ dst_meta,
197+
float * __restrict__ dst,
198+
float2 * __restrict__ dst_meta,
199199
const float scale,
200200
const int ne00,
201201
const int ne01,
@@ -555,13 +555,13 @@ static __global__ void flash_attn_ext_f16(
555555
continue;
556556
}
557557

558-
half2 dst_meta_val;
558+
float2 dst_meta_val;
559559
if (std::is_same<KQ_acc_t, float>::value) {
560-
reinterpret_cast<half&>(dst_meta_val.x) = KQ_max_f[j0/nwarps];
560+
dst_meta_val.x = KQ_max_f[j0/nwarps];
561561
} else {
562-
dst_meta_val = KQ_max_h2[j0/nwarps];
562+
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
563563
}
564-
reinterpret_cast<half&>(dst_meta_val.y) = KQ_rowsum_j;
564+
dst_meta_val.y = KQ_rowsum_j;
565565
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
566566
}
567567
#else
@@ -572,8 +572,8 @@ static __global__ void flash_attn_ext_f16(
572572
template<int D, int parallel_blocks> // D == head size
573573
__launch_bounds__(D, 1)
574574
static __global__ void flash_attn_combine_results(
575-
const float * __restrict__ VKQ_parts,
576-
const half2 * __restrict__ VKQ_meta,
575+
const float * __restrict__ VKQ_parts,
576+
const float2 * __restrict__ VKQ_meta,
577577
float * __restrict__ dst) {
578578
#if FP16_AVAILABLE
579579
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
@@ -583,30 +583,30 @@ static __global__ void flash_attn_combine_results(
583583
const int tid = threadIdx.x;
584584
__builtin_assume(tid < D);
585585

586-
__shared__ half2 meta[parallel_blocks];
587-
if (tid < parallel_blocks) {
588-
meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid];
586+
__shared__ float2 meta[parallel_blocks];
587+
if (tid < 2*parallel_blocks) {
588+
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
589589
}
590590

591591
__syncthreads();
592592

593-
half kqmax = __low2half(meta[0]);
593+
float kqmax = meta[0].x;
594594
#pragma unroll
595595
for (int l = 1; l < parallel_blocks; ++l) {
596-
kqmax = __hmax(kqmax, __low2half(meta[l]));
596+
kqmax = max(kqmax, meta[l].x);
597597
}
598598

599599
float VKQ_numerator = 0.0f;
600600
float VKQ_denominator = 0.0f;
601601
#pragma unroll
602602
for (int l = 0; l < parallel_blocks; ++l) {
603-
const half diff = __low2half(meta[l]) - kqmax;
604-
float KQ_max_scale = hexp(diff);
605-
const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD));
603+
const float diff = meta[l].x - kqmax;
604+
const float KQ_max_scale = expf(diff);
605+
const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
606606
*((uint *) &KQ_max_scale) &= ftz_mask;
607607

608608
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
609-
VKQ_denominator += KQ_max_scale * __high2float(meta[l]);
609+
VKQ_denominator += KQ_max_scale * meta[l].y;
610610
}
611611

612612
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
@@ -643,8 +643,8 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
643643
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
644644
ggml_cuda_pool & pool, cudaStream_t main_stream
645645
) {
646-
ggml_cuda_pool_alloc<float> dst_tmp(pool);
647-
ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
646+
ggml_cuda_pool_alloc<float> dst_tmp(pool);
647+
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
648648

649649
if (parallel_blocks > 1) {
650650
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -694,8 +694,8 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
694694
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
695695
ggml_cuda_pool & pool, cudaStream_t main_stream
696696
) {
697-
ggml_cuda_pool_alloc<float> dst_tmp(pool);
698-
ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
697+
ggml_cuda_pool_alloc<float> dst_tmp(pool);
698+
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
699699

700700
if (parallel_blocks > 1) {
701701
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));

0 commit comments

Comments
 (0)