@@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
3184
3184
threadgroup_barrier (mem_flags::mem_threadgroup);
3185
3185
3186
3186
{
3187
- half S[Q] = { [0 ... Q-1 ] = 0 .0f };
3188
- half M[Q] = { [0 ... Q-1 ] = -__FLT16_MAX__/2 };
3187
+ float S[Q] = { [0 ... Q-1 ] = 0 .0f };
3188
+ float M[Q] = { [0 ... Q-1 ] = -__FLT16_MAX__/2 };
3189
3189
3190
3190
// thread indices inside the simdgroup
3191
3191
// TODO: see if we can utilize quad-group functions for better performance
@@ -3202,13 +3202,13 @@ kernel void kernel_flash_attn_ext(
3202
3202
3203
3203
const bool has_mask = mask != q;
3204
3204
3205
- half slope = 1 .0f ;
3205
+ float slope = 1 .0f ;
3206
3206
3207
3207
// ALiBi
3208
3208
if (args.max_bias > 0 .0f ) {
3209
3209
const short h = iq2;
3210
3210
3211
- const half base = h < args.n_head_log2 ? args.m0 : args.m1 ;
3211
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1 ;
3212
3212
const short exph = h < args.n_head_log2 ? h + 1 : 2 *(h - args.n_head_log2 ) + 1 ;
3213
3213
3214
3214
slope = pow (base, exph);
@@ -3224,14 +3224,14 @@ kernel void kernel_flash_attn_ext(
3224
3224
3225
3225
if (has_mask) {
3226
3226
// used to detect blocks full of -INF
3227
- half smax = -INFINITY;
3227
+ float smax = -INFINITY;
3228
3228
3229
3229
// load the mask in shared memory
3230
3230
#pragma unroll(Q)
3231
3231
for (short j = 0 ; j < Q; ++j) {
3232
3232
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 );
3233
3233
3234
- const half m = pm[ic + tiisg];
3234
+ const float m = pm[ic + tiisg];
3235
3235
3236
3236
ss[j*TS + C + tiisg] = m;
3237
3237
smax = max (smax, m);
@@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext(
3327
3327
// online softmax
3328
3328
{
3329
3329
for (ushort j = 0 ; j < Q; ++j) {
3330
- const half m = M[j];
3330
+ const float m = M[j];
3331
3331
3332
3332
// scale and apply the logitcap / mask
3333
- half s = ss[j*TS + tiisg]*args.scale ;
3333
+ float s = ss[j*TS + tiisg]*args.scale ;
3334
3334
3335
3335
if (args.logit_softcap != 0 .0f ) {
3336
3336
s = args.logit_softcap *precise::tanh (s);
@@ -3341,8 +3341,8 @@ kernel void kernel_flash_attn_ext(
3341
3341
3342
3342
M[j] = simd_max (max (M[j], s));
3343
3343
3344
- const half ms = exp (m - M[j]);
3345
- const half vs = exp (s - M[j]);
3344
+ const float ms = exp (m - M[j]);
3345
+ const float vs = exp (s - M[j]);
3346
3346
3347
3347
S[j] = S[j]*ms + simd_sum (vs);
3348
3348
@@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext(
3444
3444
3445
3445
// reduce the warps sequentially
3446
3446
for (ushort sg = 1 ; sg < nsg; ++sg) {
3447
- half S = { 0 .0f };
3448
- half M = { -__FLT16_MAX__/2 };
3447
+ float S = { 0 .0f };
3448
+ float M = { -__FLT16_MAX__/2 };
3449
3449
3450
3450
threadgroup_barrier (mem_flags::mem_threadgroup);
3451
3451
@@ -3461,16 +3461,16 @@ kernel void kernel_flash_attn_ext(
3461
3461
// the first simdgroup accumulates the results from the other simdgroups
3462
3462
if (sgitg == 0 ) {
3463
3463
for (short j = 0 ; j < Q; ++j) {
3464
- const half S0 = ss[j*TS + 0 ];
3465
- const half S1 = ss[j*TS + sg*SH + 0 ];
3464
+ const float S0 = ss[j*TS + 0 ];
3465
+ const float S1 = ss[j*TS + sg*SH + 0 ];
3466
3466
3467
- const half M0 = ss[j*TS + 1 ];
3468
- const half M1 = ss[j*TS + sg*SH + 1 ];
3467
+ const float M0 = ss[j*TS + 1 ];
3468
+ const float M1 = ss[j*TS + sg*SH + 1 ];
3469
3469
3470
3470
M = max (M0, M1);
3471
3471
3472
- const half ms0 = exp (M0 - M);
3473
- const half ms1 = exp (M1 - M);
3472
+ const float ms0 = exp (M0 - M);
3473
+ const float ms1 = exp (M1 - M);
3474
3474
3475
3475
S = S0*ms0 + S1*ms1;
3476
3476
@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
3646
3646
constexpr short DV4 = DV/4 ;
3647
3647
constexpr short NW = N_SIMDWIDTH;
3648
3648
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3649
- constexpr short SH = 2 *C; // shared memory per simdgroup
3649
+ constexpr short SH = 4 *C; // shared memory per simdgroup
3650
3650
3651
3651
const short T = DK + nsg*SH; // shared memory size per query in (half)
3652
3652
3653
- // threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3654
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3655
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3656
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3657
- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
3658
- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
3653
+ // threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3654
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3655
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3656
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3657
+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2 * C + Q*DK); // scratch buffer for mask
3658
+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
3659
3659
3660
3660
// store the result for all queries in local memory (the O matrix from the paper)
3661
3661
o4_t lo[DV4/NL];
@@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec(
3684
3684
threadgroup_barrier (mem_flags::mem_threadgroup);
3685
3685
3686
3686
{
3687
- half S = 0 .0f ;
3688
- half M = -__FLT16_MAX__/2 ;
3687
+ float S = 0 .0f ;
3688
+ float M = -__FLT16_MAX__/2 ;
3689
3689
3690
3690
// thread indices inside the simdgroup
3691
3691
const short tx = tiisg%NL;
@@ -3703,13 +3703,13 @@ kernel void kernel_flash_attn_ext_vec(
3703
3703
// pointer to the mask
3704
3704
device const half * pm = (device const half *) (mask + iq1*args.nb31 );
3705
3705
3706
- half slope = 1 .0f ;
3706
+ float slope = 1 .0f ;
3707
3707
3708
3708
// ALiBi
3709
3709
if (args.max_bias > 0 .0f ) {
3710
3710
const short h = iq2;
3711
3711
3712
- const half base = h < args.n_head_log2 ? args.m0 : args.m1 ;
3712
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1 ;
3713
3713
const short exph = h < args.n_head_log2 ? h + 1 : 2 *(h - args.n_head_log2 ) + 1 ;
3714
3714
3715
3715
slope = pow (base, exph);
@@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec(
3799
3799
3800
3800
// online softmax
3801
3801
{
3802
- const half m = M;
3803
- const half s = ss[tiisg];
3802
+ const float m = M;
3803
+ const float s = ss[tiisg];
3804
3804
3805
3805
M = simd_max (max (M, s));
3806
3806
3807
- const half ms = exp (m - M);
3808
- const half vs = exp (s - M);
3807
+ const float ms = exp (m - M);
3808
+ const float vs = exp (s - M);
3809
3809
3810
3810
S = S*ms + simd_sum (vs);
3811
3811
@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
3836
3836
v4_t mv;
3837
3837
deq_v_t4 (pv4 + i/nl_v, i%nl_v, mv);
3838
3838
3839
- lo[ii/NL] += mv*ms ;
3839
+ lo[ii/NL] += o4_t ( float4 (mv)* float4 (ms)) ;
3840
3840
}
3841
3841
}
3842
3842
}
@@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec(
3907
3907
// parallel reduce
3908
3908
for (short r = nsg/2 ; r > 0 ; r >>= 1 ) {
3909
3909
if (sgitg < r) {
3910
- const half S0 = ss[ 0 ];
3911
- const half S1 = ss[r*SH + 0 ];
3910
+ const float S0 = ss[ 0 ];
3911
+ const float S1 = ss[r*(SH/ 2 ) + 0 ];
3912
3912
3913
- const half M0 = ss[ 1 ];
3914
- const half M1 = ss[r*SH + 1 ];
3913
+ const float M0 = ss[ 1 ];
3914
+ const float M1 = ss[r*(SH/ 2 ) + 1 ];
3915
3915
3916
- const half M = max (M0, M1);
3916
+ const float M = max (M0, M1);
3917
3917
3918
- const half ms0 = exp (M0 - M);
3919
- const half ms1 = exp (M1 - M);
3918
+ const float ms0 = exp (M0 - M);
3919
+ const float ms1 = exp (M1 - M);
3920
3920
3921
- const half S = S0*ms0 + S1*ms1;
3921
+ const float S = S0*ms0 + S1*ms1;
3922
3922
3923
3923
if (tiisg == 0 ) {
3924
3924
ss[0 ] = S;
@@ -3950,11 +3950,11 @@ kernel void kernel_flash_attn_ext_vec(
3950
3950
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
3951
3951
//
3952
3952
#define FA_TYPES \
3953
- half4, \
3954
- half4, \
3955
- half4, \
3956
- float , \
3957
- half, half4 , \
3953
+ half4, \
3954
+ half4, \
3955
+ half4, \
3956
+ float , \
3957
+ float , float4 , \
3958
3958
half4
3959
3959
3960
3960
typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >) flash_attn_ext_vec_t;
0 commit comments