Skip to content

Commit 8074ca8

Browse files
committed
metal : small-batch mat-mul kernels
ggml-ci
1 parent f0678c5 commit 8074ca8

File tree

4 files changed

+402
-24
lines changed

4 files changed

+402
-24
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,30 @@ typedef struct {
192192
int16_t r3;
193193
} ggml_metal_kargs_mul_mv;
194194

195+
// Kernel arguments for the small-batch matrix-multiplication ("mul_mv_ext")
// Metal kernels. Field naming follows the ggml convention used by the sibling
// kargs structs: neXY = element counts, nbXY = byte strides (src0 = 0x fields,
// src1 = 1x fields, dst = ne0/ne1).
typedef struct {
    int32_t  ne00;   // src0: elements per row
    int32_t  ne01;   // src0: number of rows
    int32_t  ne02;   // src0: number of channels
    uint64_t nb00;   // src0: byte strides per dimension
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;   // src1: elements per row
    int32_t  ne11;   // src1: number of rows (the small batch size)
    int32_t  ne12;   // src1: number of channels
    uint64_t nb10;   // src1: byte strides per dimension
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;    // dst: elements per row
    int32_t  ne1;    // dst: number of rows
    int16_t  r2;     // broadcast ratios for dims 2/3 (src1/src0)
    int16_t  r3;
    int16_t  nsg;    // simdgroups per threadgroup
    int16_t  nxpsg;  // threads along x within a simdgroup (8 or 16)
    int16_t  r1ptg;  // src1 rows processed per threadgroup (2..5)
} ggml_metal_kargs_mul_mv_ext;
218+
195219
typedef struct {
196220
int32_t nei0;
197221
int32_t nei1;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 169 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,30 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
175175
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
176176
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
177177
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
178+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
179+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
180+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
181+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
182+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
183+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
184+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
185+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
186+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
187+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
188+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
189+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
190+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
191+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
192+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
193+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
194+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
195+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
196+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
197+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
198+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
199+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
200+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
201+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
178202
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
179203
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
180204
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -699,6 +723,30 @@ @implementation GGMLMetalClass
699723
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
700724
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
701725
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
726+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
727+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
728+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
729+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
730+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
731+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
732+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
733+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
734+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
735+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
736+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
737+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
738+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
739+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
740+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
741+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
742+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
743+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
744+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
745+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
746+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
747+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
748+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
749+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
702750
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
703751
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
704752
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@@ -1930,28 +1978,128 @@ static void ggml_metal_encode_node(
19301978
// to the matrix-vector kernel
19311979
int ne11_mm_min = 4;
19321980

1933-
#if 0
1934-
// the numbers below are measured on M2 Ultra for 7B and 13B models
1935-
// these numbers do not translate to other devices or model sizes
1936-
// TODO: need to find a better approach
1937-
if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
1938-
switch (src0t) {
1939-
case GGML_TYPE_F16: ne11_mm_min = 2; break;
1940-
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1941-
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1942-
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1943-
case GGML_TYPE_Q4_0:
1944-
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1945-
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1946-
case GGML_TYPE_Q5_0: // not tested yet
1947-
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1948-
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1949-
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1950-
default: ne11_mm_min = 1; break;
1951-
}
1952-
}
1953-
#endif
1981+
if ((src0t == GGML_TYPE_F16 || // TODO: helper function
1982+
src0t == GGML_TYPE_Q4_0 ||
1983+
src0t == GGML_TYPE_Q4_1 ||
1984+
src0t == GGML_TYPE_Q5_0 ||
1985+
src0t == GGML_TYPE_Q5_1 ||
1986+
src0t == GGML_TYPE_Q8_0
1987+
) &&
1988+
src1t == GGML_TYPE_F32 &&
1989+
(ne00%256 == 0) && // TODO: this can be relaxed to 128 for nxpsg == 8
1990+
(ne11 >= 2 && ne11 <= 8)) {
1991+
1992+
// TODO: determine the optimal parameters based on grid utilization
1993+
const int nsg = 2; // TODO: or 4?
1994+
const int nxpsg = ne11 < 3 ? 16 : 8;
1995+
const int nypsg = 32/nxpsg;
1996+
const int r0ptg = nypsg*nsg;
1997+
int r1ptg = 4;
1998+
1999+
switch (ne11) {
2000+
case 2:
2001+
r1ptg = 2; break;
2002+
case 3:
2003+
case 6:
2004+
r1ptg = 3; break;
2005+
case 4:
2006+
case 7:
2007+
case 8:
2008+
r1ptg = 4; break;
2009+
case 5:
2010+
r1ptg = 5; break;
2011+
};
2012+
2013+
assert(nxpsg >= 8);
2014+
assert(nxpsg%8 == 0);
2015+
2016+
id<MTLComputePipelineState> pipeline = nil;
2017+
2018+
switch (src0->type) {
2019+
case GGML_TYPE_F16:
2020+
switch (r1ptg) {
2021+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
2022+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
2023+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
2024+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
2025+
default: GGML_ABORT("not implemented");
2026+
} break;
2027+
case GGML_TYPE_Q4_0:
2028+
switch (r1ptg) {
2029+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
2030+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
2031+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
2032+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
2033+
default: GGML_ABORT("not implemented");
2034+
} break;
2035+
case GGML_TYPE_Q4_1:
2036+
switch (r1ptg) {
2037+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
2038+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
2039+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
2040+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
2041+
default: GGML_ABORT("not implemented");
2042+
} break;
2043+
case GGML_TYPE_Q5_0:
2044+
switch (r1ptg) {
2045+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
2046+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
2047+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
2048+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
2049+
default: GGML_ABORT("not implemented");
2050+
} break;
2051+
case GGML_TYPE_Q5_1:
2052+
switch (r1ptg) {
2053+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
2054+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
2055+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
2056+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
2057+
default: GGML_ABORT("not implemented");
2058+
} break;
2059+
case GGML_TYPE_Q8_0:
2060+
switch (r1ptg) {
2061+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
2062+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
2063+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
2064+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
2065+
default: GGML_ABORT("not implemented");
2066+
} break;
2067+
default: GGML_ABORT("not implemented");
2068+
}
2069+
2070+
ggml_metal_kargs_mul_mv_ext args = {
2071+
/*.ne00 =*/ ne00,
2072+
/*.ne01 =*/ ne01,
2073+
/*.ne02 =*/ ne02,
2074+
/*.nb00 =*/ nb00,
2075+
/*.nb01 =*/ nb01,
2076+
/*.nb02 =*/ nb02,
2077+
/*.nb03 =*/ nb03,
2078+
/*.ne10 =*/ ne10,
2079+
/*.ne11 =*/ ne11,
2080+
/*.ne12 =*/ ne12,
2081+
/*.nb10 =*/ nb10,
2082+
/*.nb11 =*/ nb11,
2083+
/*.nb12 =*/ nb12,
2084+
/*.nb13 =*/ nb13,
2085+
/*.ne0 =*/ ne0,
2086+
/*.ne1 =*/ ne1,
2087+
/*.r2 =*/ r2,
2088+
/*.r3 =*/ r3,
2089+
/*.nsg =*/ nsg,
2090+
/*.nxpsg =*/ nxpsg,
2091+
/*.r1ptg =*/ r1ptg,
2092+
};
2093+
2094+
[encoder setComputePipelineState:pipeline];
2095+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
2096+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2097+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2098+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
19542099

2100+
//printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
2101+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2102+
} else
19552103
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
19562104
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
19572105
if ([device supportsFamily:MTLGPUFamilyApple7] &&

0 commit comments

Comments (0)