@@ -15,8 +15,8 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
-        float * __restrict__ dst,
-        half2 * __restrict__ dst_meta,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -180,7 +180,7 @@ static __global__ void flash_attn_vec_ext_f16(
     if (parallel_blocks == 1 || tid != 0) {
         return;
     }
-    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum);
+    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
 #else
     NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
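
Each partial block of `flash_attn_vec_ext_f16` now emits its row maximum and row sum as a single fp32 pair. The index expression above flattens a `[rows][gridDim.y][parallel_blocks]` layout; a minimal sketch of that addressing, with hypothetical names not taken from the tree:

```cuda
// Hypothetical helper mirroring the dst_meta indexing in the hunk above:
// row ic, grid slot iy (= blockIdx.y), partial block ip. The innermost
// dimension is the partial block, so the combine kernel can read all
// parallel_blocks entries for one row contiguously.
__host__ __device__ static int dst_meta_index(
        const int ic, const int iy, const int ip,
        const int gridDim_y, const int parallel_blocks) {
    return ic*gridDim_y*parallel_blocks + iy*parallel_blocks + ip;
}
```

With `make_float2(kqmax, kqsum)` stored at that offset, the later combine pass can renormalize the partial results without touching fp16 at all.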
@@ -194,8 +194,8 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
-        float * __restrict__ dst,
-        half2 * __restrict__ dst_meta,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -555,13 +555,13 @@ static __global__ void flash_attn_ext_f16(
             continue;
         }

-        half2 dst_meta_val;
+        float2 dst_meta_val;
         if (std::is_same<KQ_acc_t, float>::value) {
-            reinterpret_cast<half&>(dst_meta_val.x) = KQ_max_f[j0/nwarps];
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
         } else {
-            dst_meta_val = KQ_max_h2[j0/nwarps];
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
-        reinterpret_cast<half&>(dst_meta_val.y) = KQ_rowsum_j;
+        dst_meta_val.y = KQ_rowsum_j;
         dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
     }
 #else
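
The `std::is_same` branch keeps both accumulator flavors working: with `KQ_acc_t == float` the row maximum is already fp32, while the `half2` path widens one 16-bit lane via `__low2float`. The same dispatch written with C++17 `if constexpr` (an illustrative alternative, not the patch's code, and it assumes compiling with `-std=c++17`):

```cuda
#include <cuda_fp16.h>
#include <type_traits>

// Widen the per-row score maximum to fp32, whatever the accumulator type.
// With if constexpr only the taken branch is instantiated, so the float
// branch never has to compile a __low2float call and vice versa.
template <typename KQ_acc_t>
__device__ static float kq_max_as_f32(const KQ_acc_t v) {
    if constexpr (std::is_same<KQ_acc_t, float>::value) {
        return v;              // fp32 accumulator: use directly
    } else {
        return __low2float(v); // half2 accumulator: row maximum in the low lane
    }
}
```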
@@ -572,8 +572,8 @@ static __global__ void flash_attn_ext_f16(
 template<int D, int parallel_blocks> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_combine_results(
-        const float * __restrict__ VKQ_parts,
-        const half2 * __restrict__ VKQ_meta,
+        const float  * __restrict__ VKQ_parts,
+        const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst) {
 #if FP16_AVAILABLE
     VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
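
For reference, the arithmetic the combine kernel in the hunk below performs: each of the $P$ = `parallel_blocks` partial results carries its slice's score maximum and softmax denominator (the `float2` metadata) plus an unnormalized partial output in `VKQ_parts`. Writing these as $m_l$, $s_l$ and $N_l$ for partial block $l$, the kernel computes

$$
\mathrm{dst} \;=\; \frac{\sum_{l=0}^{P-1} e^{\,m_l - m}\, N_l}{\sum_{l=0}^{P-1} e^{\,m_l - m}\, s_l},
\qquad m = \max_{l} m_l ,
$$

which equals the softmax-weighted output a single block over the whole KV range would have produced.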
@@ -583,30 +583,30 @@ static __global__ void flash_attn_combine_results(
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);

-    __shared__ half2 meta[parallel_blocks];
-    if (tid < parallel_blocks) {
-        meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid];
+    __shared__ float2 meta[parallel_blocks];
+    if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *) VKQ_meta)[blockIdx.y*(2*parallel_blocks) + tid];
     }

     __syncthreads();

-    half kqmax = __low2half(meta[0]);
+    float kqmax = meta[0].x;
 #pragma unroll
     for (int l = 1; l < parallel_blocks; ++l) {
-        kqmax = __hmax(kqmax, __low2half(meta[l]));
+        kqmax = max(kqmax, meta[l].x);
     }

     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
 #pragma unroll
     for (int l = 0; l < parallel_blocks; ++l) {
-        const half diff = __low2half(meta[l]) - kqmax;
-        float KQ_max_scale = hexp(diff);
-        const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD));
+        const float diff = meta[l].x - kqmax;
+        const float KQ_max_scale = expf(diff);
+        const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint *) &KQ_max_scale) &= ftz_mask;

         VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
-        VKQ_denominator += KQ_max_scale * __high2float(meta[l]);
+        VKQ_denominator += KQ_max_scale * meta[l].y;
     }

     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
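
The `ftz_mask` line above is a branchless flush-to-zero: `diff > SOFTMAX_FTZ_THRESHOLD` evaluates to 0 or 1, so the mask is either all-zero or all-one bits, and the AND either leaves `expf(diff)` intact or clears its bit pattern to +0.0f with no divergent branch. A standalone sketch of the trick; the threshold value here is an assumption for illustration, since the diff does not show the library's constant:

```cuda
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <cmath>

// Hypothetical threshold for illustration; the patch uses the library's
// SOFTMAX_FTZ_THRESHOLD constant, whose value is not shown in this diff.
constexpr float FTZ_THRESHOLD = -20.0f;

// Branchless flush-to-zero: exp(diff) for diff above the threshold,
// exactly 0.0f otherwise. The bool comparison is 0 or 1, so the mask is
// either all zeros or all ones; ANDing clears the float to +0.0f.
static float exp_ftz(const float diff) {
    float scale = std::exp(diff);
    const uint32_t mask = 0xFFFFFFFFu * (diff > FTZ_THRESHOLD);
    uint32_t bits;
    memcpy(&bits, &scale, sizeof(bits));
    bits &= mask;
    memcpy(&scale, &bits, sizeof(scale));
    return scale;
}

int main() {
    printf("%g %g\n", exp_ftz(-1.0f), exp_ftz(-30.0f)); // ~0.367879 and 0
    return 0;
}
```

On the GPU this avoids warp divergence and guarantees that negligible scales contribute exactly nothing to the numerator and denominator sums.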
@@ -643,8 +643,8 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
         ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
-    ggml_cuda_pool_alloc<float> dst_tmp(pool);
-    ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -694,8 +694,8 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
         ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
-    ggml_cuda_pool_alloc<float> dst_tmp(pool);
-    ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
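
A plausible motivation for carrying the metadata in fp32 (the diff itself does not state one): a softmax row sum accumulated in half stalls once it reaches 2048, where the fp16 ulp grows to 2.0 and further unit-sized addends round away entirely, and the absolute ceiling of 65504 follows soon after. A toy host-side demonstration, not from the patch:

```cuda
#include <cuda_fp16.h>
#include <cstdio>

int main() {
    float sum_f32 = 0.0f;
    half  sum_f16 = __float2half(0.0f);
    for (int i = 0; i < 100000; ++i) {
        sum_f32 += 1.0f;
        // Round-trip through fp32 for host-side arithmetic; once the fp16
        // sum reaches 2048, adding 1.0 rounds away and the sum stalls.
        sum_f16 = __float2half(__half2float(sum_f16) + 1.0f);
    }
    printf("fp32 sum: %.0f, fp16 sum: %.0f\n", sum_f32, __half2float(sum_f16));
    // prints: fp32 sum: 100000, fp16 sum: 2048
    return 0;
}
```

Keeping `kqsum` and the partial maxima in `float2` sidesteps both limits at the cost of doubling a metadata buffer that is tiny compared with `dst_tmp`.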