@@ -175,6 +175,30 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
175
175
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
176
176
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
177
177
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
178
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
179
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
180
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
181
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
182
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
183
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
184
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
185
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
186
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
187
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
188
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
189
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
190
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
191
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
192
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
193
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
194
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
195
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
196
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
197
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
198
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
199
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
200
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
201
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
178
202
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
179
203
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
180
204
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -699,6 +723,30 @@ @implementation GGMLMetalClass
699
723
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
700
724
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
701
725
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
726
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
727
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
728
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
729
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
730
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
731
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
732
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
733
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
734
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
735
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
736
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
737
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
738
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
739
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
740
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
741
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
742
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
743
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
744
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
745
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
746
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
747
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
748
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
749
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
702
750
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
703
751
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
704
752
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@@ -1930,28 +1978,128 @@ static void ggml_metal_encode_node(
1930
1978
// to the matrix-vector kernel
1931
1979
int ne11_mm_min = 4;
1932
1980
1933
- #if 0
1934
- // the numbers below are measured on M2 Ultra for 7B and 13B models
1935
- // these numbers do not translate to other devices or model sizes
1936
- // TODO: need to find a better approach
1937
- if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
1938
- switch (src0t) {
1939
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
1940
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1941
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1942
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1943
- case GGML_TYPE_Q4_0:
1944
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1945
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1946
- case GGML_TYPE_Q5_0: // not tested yet
1947
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1948
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1949
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1950
- default: ne11_mm_min = 1; break;
1951
- }
1952
- }
1953
- #endif
1981
+ if ((src0t == GGML_TYPE_F16 || // TODO: helper function
1982
+ src0t == GGML_TYPE_Q4_0 ||
1983
+ src0t == GGML_TYPE_Q4_1 ||
1984
+ src0t == GGML_TYPE_Q5_0 ||
1985
+ src0t == GGML_TYPE_Q5_1 ||
1986
+ src0t == GGML_TYPE_Q8_0
1987
+ ) &&
1988
+ src1t == GGML_TYPE_F32 &&
1989
+ (ne00%256 == 0) && // TODO: this can be relaxed to 128 for nxpsg == 8
1990
+ (ne11 >= 2 && ne11 <= 8)) {
1991
+
1992
+ // TODO: determine the optimal parameters based on grid utilization
1993
+ const int nsg = 2; // TODO: or 4?
1994
+ const int nxpsg = ne11 < 3 ? 16 : 8;
1995
+ const int nypsg = 32/nxpsg;
1996
+ const int r0ptg = nypsg*nsg;
1997
+ int r1ptg = 4;
1998
+
1999
+ switch (ne11) {
2000
+ case 2:
2001
+ r1ptg = 2; break;
2002
+ case 3:
2003
+ case 6:
2004
+ r1ptg = 3; break;
2005
+ case 4:
2006
+ case 7:
2007
+ case 8:
2008
+ r1ptg = 4; break;
2009
+ case 5:
2010
+ r1ptg = 5; break;
2011
+ };
2012
+
2013
+ assert(nxpsg >= 8);
2014
+ assert(nxpsg%8 == 0);
2015
+
2016
+ id<MTLComputePipelineState> pipeline = nil;
2017
+
2018
+ switch (src0->type) {
2019
+ case GGML_TYPE_F16:
2020
+ switch (r1ptg) {
2021
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
2022
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
2023
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
2024
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
2025
+ default: GGML_ABORT("not implemented");
2026
+ } break;
2027
+ case GGML_TYPE_Q4_0:
2028
+ switch (r1ptg) {
2029
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
2030
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
2031
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
2032
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
2033
+ default: GGML_ABORT("not implemented");
2034
+ } break;
2035
+ case GGML_TYPE_Q4_1:
2036
+ switch (r1ptg) {
2037
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
2038
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
2039
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
2040
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
2041
+ default: GGML_ABORT("not implemented");
2042
+ } break;
2043
+ case GGML_TYPE_Q5_0:
2044
+ switch (r1ptg) {
2045
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
2046
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
2047
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
2048
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
2049
+ default: GGML_ABORT("not implemented");
2050
+ } break;
2051
+ case GGML_TYPE_Q5_1:
2052
+ switch (r1ptg) {
2053
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
2054
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
2055
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
2056
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
2057
+ default: GGML_ABORT("not implemented");
2058
+ } break;
2059
+ case GGML_TYPE_Q8_0:
2060
+ switch (r1ptg) {
2061
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
2062
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
2063
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
2064
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
2065
+ default: GGML_ABORT("not implemented");
2066
+ } break;
2067
+ default: GGML_ABORT("not implemented");
2068
+ }
2069
+
2070
+ ggml_metal_kargs_mul_mv_ext args = {
2071
+ /*.ne00 =*/ ne00,
2072
+ /*.ne01 =*/ ne01,
2073
+ /*.ne02 =*/ ne02,
2074
+ /*.nb00 =*/ nb00,
2075
+ /*.nb01 =*/ nb01,
2076
+ /*.nb02 =*/ nb02,
2077
+ /*.nb03 =*/ nb03,
2078
+ /*.ne10 =*/ ne10,
2079
+ /*.ne11 =*/ ne11,
2080
+ /*.ne12 =*/ ne12,
2081
+ /*.nb10 =*/ nb10,
2082
+ /*.nb11 =*/ nb11,
2083
+ /*.nb12 =*/ nb12,
2084
+ /*.nb13 =*/ nb13,
2085
+ /*.ne0 =*/ ne0,
2086
+ /*.ne1 =*/ ne1,
2087
+ /*.r2 =*/ r2,
2088
+ /*.r3 =*/ r3,
2089
+ /*.nsg =*/ nsg,
2090
+ /*.nxpsg =*/ nxpsg,
2091
+ /*.r1ptg =*/ r1ptg,
2092
+ };
2093
+
2094
+ [encoder setComputePipelineState:pipeline];
2095
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2096
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2097
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2098
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1954
2099
2100
+ //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
2101
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2102
+ } else
1955
2103
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1956
2104
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1957
2105
if ([device supportsFamily:MTLGPUFamilyApple7] &&
0 commit comments