File tree Expand file tree Collapse file tree 4 files changed +1404
-1314
lines changed Expand file tree Collapse file tree 4 files changed +1404
-1314
lines changed Original file line number Diff line number Diff line change 20
20
#define N_R0_Q5_1 4
21
21
#define N_SG_Q5_1 2
22
22
23
- #define N_R0_Q8_0 4
24
- #define N_SG_Q8_0 2
23
+ #define N_R0_Q8_0 2
24
+ #define N_SG_Q8_0 4
25
25
26
26
#define N_R0_MXFP4 2
27
27
#define N_SG_MXFP4 2
68
68
#define N_R0_IQ4_XS 2
69
69
#define N_SG_IQ4_XS 2
70
70
71
+ // function constants offsets: one base value per flash-attention kernel family (ext, ext_vec, ext_vec_reduce), spaced 100 apart. NOTE(review): presumably each kernel adds its own small index to its base so the ranges never collide - confirm against the kernel sources
72
+ #define FC_FLASH_ATTN_EXT 100
73
+ #define FC_FLASH_ATTN_EXT_VEC 200
74
+ #define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
75
+
71
76
// kernel argument structs
72
77
//
73
78
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -236,9 +241,11 @@ typedef struct {
236
241
int32_t ne11 ;
237
242
int32_t ne_12_2 ; // assume K and V are same shape
238
243
int32_t ne_12_3 ;
244
+ int32_t ns10 ;
239
245
uint64_t nb11 ;
240
246
uint64_t nb12 ;
241
247
uint64_t nb13 ;
248
+ int32_t ns20 ;
242
249
uint64_t nb21 ;
243
250
uint64_t nb22 ;
244
251
uint64_t nb23 ;
@@ -258,10 +265,43 @@ typedef struct {
258
265
float logit_softcap ;
259
266
} ggml_metal_kargs_flash_attn_ext ;
260
267
268
+ typedef struct { // kernel argument struct for the "vec" flash-attention variant; the fields from ne11 to nb23 and the trailing logit_softcap match the part of ggml_metal_kargs_flash_attn_ext visible in this diff
269
+ int32_t ne01 ; // ne* fields are int32_t; the file's note on kernel argument structs says element counters use int32_t to reduce register usage
270
+ int32_t ne02 ;
271
+ int32_t ne03 ;
272
+ uint64_t nb01 ; // nb* fields are uint64_t. NOTE(review): presumably byte strides for the matching dimension - confirm
273
+ uint64_t nb02 ;
274
+ uint64_t nb03 ;
275
+ int32_t ne11 ;
276
+ int32_t ne_12_2 ; // assume K and V are same shape
277
+ int32_t ne_12_3 ;
278
+ int32_t ns10 ; // NOTE(review): 32-bit field placed ahead of nb11..nb13; its exact meaning is not visible in this chunk - confirm against the kernel
279
+ uint64_t nb11 ;
280
+ uint64_t nb12 ;
281
+ uint64_t nb13 ;
282
+ int32_t ns20 ; // NOTE(review): counterpart of ns10, placed ahead of nb21..nb23 - confirm meaning
283
+ uint64_t nb21 ;
284
+ uint64_t nb22 ;
285
+ uint64_t nb23 ;
286
+ int32_t ne32 ;
287
+ int32_t ne33 ;
288
+ uint64_t nb31 ;
289
+ uint64_t nb32 ;
290
+ uint64_t nb33 ;
291
+ int32_t ne1 ;
292
+ int32_t ne2 ;
293
+ int32_t ne3 ;
294
+ float scale ;
295
+ float max_bias ;
296
+ float m0 ; // NOTE(review): m0, m1 and n_head_log2 sit next to max_bias and presumably parameterize the same bias computation - confirm
297
+ float m1 ;
298
+ int32_t n_head_log2 ;
299
+ float logit_softcap ;
300
+ } ggml_metal_kargs_flash_attn_ext_vec ;
301
+
261
302
typedef struct { // kernel argument struct for the flash-attention reduce step; this change renames it and drops one field
262
303
int32_t nrows ; // the only field kept after this change
263
- int32_t ne20 ; // removed by this change
264
- } ggml_metal_kargs_flash_attn_ext_reduce ; // old typedef name
304
+ } ggml_metal_kargs_flash_attn_ext_vec_reduce ; // new typedef name, following the ..._ext_vec naming used by the struct added in this change
265
305
266
306
typedef struct {
267
307
int32_t ne00 ;
You can’t perform that action at this time.
0 commit comments