Skip to content

Commit 3fd072a

Browse files
authored
metal : use F32 prec in FA kernels (#12688)
* metal : use F32 prec in FA kernels (ggml-ci)
* cont : fix FA vec kernel (ggml-ci)
1 parent a6f32f0 commit 3fd072a

File tree

2 files changed

+48
-48
lines changed

2 files changed

+48
-48
lines changed

ggml/src/ggml-metal/ggml-metal.m

+1-1
Original file line numberDiff line numberDiff line change
@@ -4179,7 +4179,7 @@ static void ggml_metal_encode_node(
41794179
// ne00*(nsg)
41804180
// each simdgroup has a full f16 head vector in shared mem to accumulate results
41814181
//
4182-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
4182+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
41834183

41844184
int64_t nsgmax = 2;
41854185
while (true) {

ggml/src/ggml-metal/ggml-metal.metal

+47-47
Original file line numberDiff line numberDiff line change
@@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
31843184
threadgroup_barrier(mem_flags::mem_threadgroup);
31853185

31863186
{
3187-
half S[Q] = { [0 ... Q-1] = 0.0f };
3188-
half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
3187+
float S[Q] = { [0 ... Q-1] = 0.0f };
3188+
float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
31893189

31903190
// thread indices inside the simdgroup
31913191
// TODO: see if we can utilize quad-group functions for better performance
@@ -3202,13 +3202,13 @@ kernel void kernel_flash_attn_ext(
32023202

32033203
const bool has_mask = mask != q;
32043204

3205-
half slope = 1.0f;
3205+
float slope = 1.0f;
32063206

32073207
// ALiBi
32083208
if (args.max_bias > 0.0f) {
32093209
const short h = iq2;
32103210

3211-
const half base = h < args.n_head_log2 ? args.m0 : args.m1;
3211+
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
32123212
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
32133213

32143214
slope = pow(base, exph);
@@ -3224,14 +3224,14 @@ kernel void kernel_flash_attn_ext(
32243224

32253225
if (has_mask) {
32263226
// used to detect blocks full of -INF
3227-
half smax = -INFINITY;
3227+
float smax = -INFINITY;
32283228

32293229
// load the mask in shared memory
32303230
#pragma unroll(Q)
32313231
for (short j = 0; j < Q; ++j) {
32323232
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
32333233

3234-
const half m = pm[ic + tiisg];
3234+
const float m = pm[ic + tiisg];
32353235

32363236
ss[j*TS + C + tiisg] = m;
32373237
smax = max(smax, m);
@@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext(
33273327
// online softmax
33283328
{
33293329
for (ushort j = 0; j < Q; ++j) {
3330-
const half m = M[j];
3330+
const float m = M[j];
33313331

33323332
// scale and apply the logitcap / mask
3333-
half s = ss[j*TS + tiisg]*args.scale;
3333+
float s = ss[j*TS + tiisg]*args.scale;
33343334

33353335
if (args.logit_softcap != 0.0f) {
33363336
s = args.logit_softcap*precise::tanh(s);
@@ -3341,8 +3341,8 @@ kernel void kernel_flash_attn_ext(
33413341

33423342
M[j] = simd_max(max(M[j], s));
33433343

3344-
const half ms = exp(m - M[j]);
3345-
const half vs = exp(s - M[j]);
3344+
const float ms = exp(m - M[j]);
3345+
const float vs = exp(s - M[j]);
33463346

33473347
S[j] = S[j]*ms + simd_sum(vs);
33483348

@@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext(
34443444

34453445
// reduce the warps sequentially
34463446
for (ushort sg = 1; sg < nsg; ++sg) {
3447-
half S = { 0.0f };
3448-
half M = { -__FLT16_MAX__/2 };
3447+
float S = { 0.0f };
3448+
float M = { -__FLT16_MAX__/2 };
34493449

34503450
threadgroup_barrier(mem_flags::mem_threadgroup);
34513451

@@ -3461,16 +3461,16 @@ kernel void kernel_flash_attn_ext(
34613461
// the first simdgroup accumulates the results from the other simdgroups
34623462
if (sgitg == 0) {
34633463
for (short j = 0; j < Q; ++j) {
3464-
const half S0 = ss[j*TS + 0];
3465-
const half S1 = ss[j*TS + sg*SH + 0];
3464+
const float S0 = ss[j*TS + 0];
3465+
const float S1 = ss[j*TS + sg*SH + 0];
34663466

3467-
const half M0 = ss[j*TS + 1];
3468-
const half M1 = ss[j*TS + sg*SH + 1];
3467+
const float M0 = ss[j*TS + 1];
3468+
const float M1 = ss[j*TS + sg*SH + 1];
34693469

34703470
M = max(M0, M1);
34713471

3472-
const half ms0 = exp(M0 - M);
3473-
const half ms1 = exp(M1 - M);
3472+
const float ms0 = exp(M0 - M);
3473+
const float ms1 = exp(M1 - M);
34743474

34753475
S = S0*ms0 + S1*ms1;
34763476

@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
36463646
constexpr short DV4 = DV/4;
36473647
constexpr short NW = N_SIMDWIDTH;
36483648
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3649-
constexpr short SH = 2*C; // shared memory per simdgroup
3649+
constexpr short SH = 4*C; // shared memory per simdgroup
36503650

36513651
const short T = DK + nsg*SH; // shared memory size per query in (half)
36523652

3653-
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3654-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3655-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3656-
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3657-
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
3658-
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
3653+
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3654+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3655+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3656+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3657+
threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
3658+
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
36593659

36603660
// store the result for all queries in local memory (the O matrix from the paper)
36613661
o4_t lo[DV4/NL];
@@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec(
36843684
threadgroup_barrier(mem_flags::mem_threadgroup);
36853685

36863686
{
3687-
half S = 0.0f;
3688-
half M = -__FLT16_MAX__/2;
3687+
float S = 0.0f;
3688+
float M = -__FLT16_MAX__/2;
36893689

36903690
// thread indices inside the simdgroup
36913691
const short tx = tiisg%NL;
@@ -3703,13 +3703,13 @@ kernel void kernel_flash_attn_ext_vec(
37033703
// pointer to the mask
37043704
device const half * pm = (device const half *) (mask + iq1*args.nb31);
37053705

3706-
half slope = 1.0f;
3706+
float slope = 1.0f;
37073707

37083708
// ALiBi
37093709
if (args.max_bias > 0.0f) {
37103710
const short h = iq2;
37113711

3712-
const half base = h < args.n_head_log2 ? args.m0 : args.m1;
3712+
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
37133713
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
37143714

37153715
slope = pow(base, exph);
@@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec(
37993799

38003800
// online softmax
38013801
{
3802-
const half m = M;
3803-
const half s = ss[tiisg];
3802+
const float m = M;
3803+
const float s = ss[tiisg];
38043804

38053805
M = simd_max(max(M, s));
38063806

3807-
const half ms = exp(m - M);
3808-
const half vs = exp(s - M);
3807+
const float ms = exp(m - M);
3808+
const float vs = exp(s - M);
38093809

38103810
S = S*ms + simd_sum(vs);
38113811

@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
38363836
v4_t mv;
38373837
deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
38383838

3839-
lo[ii/NL] += mv*ms;
3839+
lo[ii/NL] += o4_t(float4(mv)*float4(ms));
38403840
}
38413841
}
38423842
}
@@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec(
39073907
// parallel reduce
39083908
for (short r = nsg/2; r > 0; r >>= 1) {
39093909
if (sgitg < r) {
3910-
const half S0 = ss[ 0];
3911-
const half S1 = ss[r*SH + 0];
3910+
const float S0 = ss[ 0];
3911+
const float S1 = ss[r*(SH/2) + 0];
39123912

3913-
const half M0 = ss[ 1];
3914-
const half M1 = ss[r*SH + 1];
3913+
const float M0 = ss[ 1];
3914+
const float M1 = ss[r*(SH/2) + 1];
39153915

3916-
const half M = max(M0, M1);
3916+
const float M = max(M0, M1);
39173917

3918-
const half ms0 = exp(M0 - M);
3919-
const half ms1 = exp(M1 - M);
3918+
const float ms0 = exp(M0 - M);
3919+
const float ms1 = exp(M1 - M);
39203920

3921-
const half S = S0*ms0 + S1*ms1;
3921+
const float S = S0*ms0 + S1*ms1;
39223922

39233923
if (tiisg == 0) {
39243924
ss[0] = S;
@@ -3950,11 +3950,11 @@ kernel void kernel_flash_attn_ext_vec(
39503950
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
39513951
//
39523952
#define FA_TYPES \
3953-
half4, \
3954-
half4, \
3955-
half4, \
3956-
float, \
3957-
half, half4, \
3953+
half4, \
3954+
half4, \
3955+
half4, \
3956+
float, \
3957+
float, float4, \
39583958
half4
39593959

39603960
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;

0 commit comments

Comments
 (0)