Skip to content

Commit f28d4f4

Browse files
authored
metal : refactor + optimize (#15857)
* metal : refactor ggml-ci * cont : refactor FA-vec kernel * cont : print metal library load time * minor : warn to debug + bettern kernel names ggml-ci * metal : optimize mul_mv q8_0 ggml-ci * metal : simplify FA pipeline creation functions ggml-ci * metal : improve naming consistency * metal : safer function constants offsets ggml-ci * metal : comments ggml-ci
1 parent 9fcb29f commit f28d4f4

File tree

4 files changed

+1404
-1314
lines changed

4 files changed

+1404
-1314
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#define N_R0_Q5_1 4
2121
#define N_SG_Q5_1 2
2222

23-
#define N_R0_Q8_0 4
24-
#define N_SG_Q8_0 2
23+
#define N_R0_Q8_0 2
24+
#define N_SG_Q8_0 4
2525

2626
#define N_R0_MXFP4 2
2727
#define N_SG_MXFP4 2
@@ -68,6 +68,11 @@
6868
#define N_R0_IQ4_XS 2
6969
#define N_SG_IQ4_XS 2
7070

71+
// function constants offsets
72+
#define FC_FLASH_ATTN_EXT 100
73+
#define FC_FLASH_ATTN_EXT_VEC 200
74+
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
75+
7176
// kernel argument structs
7277
//
7378
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -236,9 +241,11 @@ typedef struct {
236241
int32_t ne11;
237242
int32_t ne_12_2; // assume K and V are same shape
238243
int32_t ne_12_3;
244+
int32_t ns10;
239245
uint64_t nb11;
240246
uint64_t nb12;
241247
uint64_t nb13;
248+
int32_t ns20;
242249
uint64_t nb21;
243250
uint64_t nb22;
244251
uint64_t nb23;
@@ -258,10 +265,43 @@ typedef struct {
258265
float logit_softcap;
259266
} ggml_metal_kargs_flash_attn_ext;
260267

268+
typedef struct {
269+
int32_t ne01;
270+
int32_t ne02;
271+
int32_t ne03;
272+
uint64_t nb01;
273+
uint64_t nb02;
274+
uint64_t nb03;
275+
int32_t ne11;
276+
int32_t ne_12_2; // assume K and V are same shape
277+
int32_t ne_12_3;
278+
int32_t ns10;
279+
uint64_t nb11;
280+
uint64_t nb12;
281+
uint64_t nb13;
282+
int32_t ns20;
283+
uint64_t nb21;
284+
uint64_t nb22;
285+
uint64_t nb23;
286+
int32_t ne32;
287+
int32_t ne33;
288+
uint64_t nb31;
289+
uint64_t nb32;
290+
uint64_t nb33;
291+
int32_t ne1;
292+
int32_t ne2;
293+
int32_t ne3;
294+
float scale;
295+
float max_bias;
296+
float m0;
297+
float m1;
298+
int32_t n_head_log2;
299+
float logit_softcap;
300+
} ggml_metal_kargs_flash_attn_ext_vec;
301+
261302
typedef struct {
262303
int32_t nrows;
263-
int32_t ne20;
264-
} ggml_metal_kargs_flash_attn_ext_reduce;
304+
} ggml_metal_kargs_flash_attn_ext_vec_reduce;
265305

266306
typedef struct {
267307
int32_t ne00;

0 commit comments

Comments
 (0)