From f45c40e31c7ae043ff73478baf19121400b59426 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 11 Nov 2024 13:16:12 +0200 Subject: [PATCH 1/4] metal : small-batch mat-mul kernels ggml-ci --- ggml/src/ggml-metal/ggml-metal-impl.h | 24 +++ ggml/src/ggml-metal/ggml-metal.m | 190 +++++++++++++++++++++--- ggml/src/ggml-metal/ggml-metal.metal | 203 +++++++++++++++++++++++++- tests/test-backend-ops.cpp | 9 ++ 4 files changed, 402 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 53c13549650c8..e195867192712 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -192,6 +192,30 @@ typedef struct { int16_t r3; } ggml_metal_kargs_mul_mv; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; + int16_t nsg; + int16_t nxpsg; + int16_t r1ptg; +} ggml_metal_kargs_mul_mv_ext; + typedef struct { int32_t nei0; int32_t nei1; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index ae6b25edea95a..d7ee7320cfdce 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -175,6 +175,30 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, @@ -699,6 +723,30 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, 
has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); @@ -1930,28 +1978,128 @@ static void ggml_metal_encode_node( // to the matrix-vector kernel int ne11_mm_min = 4; -#if 0 - // the numbers below are measured on M2 Ultra for 7B and 13B models - // these numbers do not translate to other devices or model sizes - // TODO: need to find a better approach - if ([device.name isEqualToString:@"Apple M2 Ultra"]) { - switch (src0t) { - case GGML_TYPE_F16: ne11_mm_min = 2; break; - case GGML_TYPE_Q8_0: ne11_mm_min = 7; 
break; - case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; - case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: ne11_mm_min = 15; break; - case GGML_TYPE_Q4_K: ne11_mm_min = 11; break; - case GGML_TYPE_Q5_0: // not tested yet - case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet - case GGML_TYPE_Q5_K: ne11_mm_min = 7; break; - case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; - default: ne11_mm_min = 1; break; - } - } -#endif + if ((src0t == GGML_TYPE_F16 || // TODO: helper function + src0t == GGML_TYPE_Q4_0 || + src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || + src0t == GGML_TYPE_Q8_0 + ) && + src1t == GGML_TYPE_F32 && + (ne00%256 == 0) && // TODO: this can be relaxed to 128 for nxpsg == 8 + (ne11 >= 2 && ne11 <= 8)) { + + // TODO: determine the optimal parameters based on grid utilization + const int nsg = 2; // TODO: or 4? + const int nxpsg = ne11 < 3 ? 16 : 8; + const int nypsg = 32/nxpsg; + const int r0ptg = nypsg*nsg; + int r1ptg = 4; + + switch (ne11) { + case 2: + r1ptg = 2; break; + case 3: + case 6: + r1ptg = 3; break; + case 4: + case 7: + case 8: + r1ptg = 4; break; + case 5: + r1ptg = 5; break; + }; + + assert(nxpsg >= 8); + assert(nxpsg%8 == 0); + + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F16: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q4_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q4_1: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q5_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q5_1: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break; + case 4: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q8_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + default: GGML_ABORT("not implemented"); + } + + ggml_metal_kargs_mul_mv_ext args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + /*.nsg =*/ nsg, + /*.nxpsg =*/ nxpsg, + /*.r1ptg =*/ r1ptg, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg); + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([device supportsFamily:MTLGPUFamilyApple7] && diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index eaca38864bd65..50f33b6c6e8a9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -47,6 +47,11 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) reg = (type4x4)(*src); } +template +void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { + reg = (type4)(*(src + il)); +} + #if defined(GGML_METAL_USE_BF16) template void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { @@ -55,7 +60,7 @@ void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & re #endif template -void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { +void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); const float d1 = il ? (xb->d / 16.h) : xb->d; const float d2 = d1 / 256.f; @@ -73,8 +78,23 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg reg = (type4x4) reg_f; } +template +void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = (il/4) ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = (il/4) ? 
0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i = 0; i < 2; i++) { + reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md; + reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md; + } +} + template -void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { +void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 2); const float d1 = il ? (xb->d / 16.h) : xb->d; const float d2 = d1 / 256.f; @@ -92,6 +112,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg reg = (type4x4) reg_f; } +template +void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = (il/4) ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = (il/4) ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i = 0; i < 2; i++) { + reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m; + reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m; + } +} + template void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 3); @@ -124,6 +159,14 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg reg = (type4x4) reg_f; } +template +void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) { + // TODO: implement + float4x4 tmp; + dequantize_q5_0(xb, il/4, tmp); + reg = tmp[il%4]; +} + template void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 4); @@ -156,10 +199,18 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg reg = (type4x4) reg_f; } +template +void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) { + // TODO: implement + float4x4 tmp; + dequantize_q5_1(xb, il/4, tmp); + reg = tmp[il%4]; +} + template void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; + const float d = xb->d; float4x4 reg_f; @@ -170,6 +221,16 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg reg = (type4x4) reg_f; } +template +void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const float d = xb->d; + + for (int i = 0; i < 4; i++) { + reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d); + } +} + template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -1752,6 +1813,142 @@ kernel void kernel_mul_mv_q8_0_f32( kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template +void kernel_mul_mv_ext_q_f32_impl( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short chpt = 4; + + //const short nxpsg = (32); + const short nypsg = (32/nxpsg); + + const short tx = tiisg%nxpsg; + const short ty = tiisg/nxpsg; + + const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i11 = tgpig.y*r1ptg; 
+ const int i1m = tgpig.z; + + const int i12 = i1m%args.ne12; + const int i13 = i1m/args.ne12; + + const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + (tx/8)*bp32 : (device const q_t *) src0; + + device const float4 * y4[r1ptg]; + + for (int ir1 = 0; ir1 < r1ptg; ++ir1) { + y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1; + } + + float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; + + for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) { +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { + float4 lx; + + deq_t4(xq + (4*ch*nxpsg)/(32/bp32), tx%8, lx); + +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + sumf[ir1] += dot(lx, y4[ir1][ch*nxpsg]); + } + } + + xq += ((4*chpt)*nxpsg)/(32/bp32); + + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + y4[ir1] += ((4*chpt)*nxpsg)/4; + } + } + + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + if (nxpsg >= 32) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); + } + if (nxpsg >= 16) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); + } + if (nxpsg >= 8) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); + } + if (nxpsg >= 4) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); + } + if (nxpsg >= 2) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); + } + + //sumf[ir1] = simd_sum(sumf[ir1]); + } + + if (tx == 0) { + for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { + device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; + + if (i01 < args.ne01) { + dst_f32[i01] = sumf[ir1]; + } + } + } +} + +template +kernel void kernel_mul_mv_ext_q_f32_disp( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nxpsg) { + case 8: kernel_mul_mv_ext_q_f32_impl<8, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q_f32_impl<16, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q_f32_impl<32, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + } +} + +typedef decltype(kernel_mul_mv_ext_q_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q_f32_t; + +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, half4, 8, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, half4, 8, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, half4, 8, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, half4, 8, dequantize_f16_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q4_0, 1, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q4_0, 1, 
dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q4_0, 1, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q4_0, 1, dequantize_q4_0_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q4_1, 1, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q4_1, 1, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q4_1, 1, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q4_1, 1, dequantize_q4_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q5_0, 1, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q5_0, 1, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q5_0, 1, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q5_0, 1, dequantize_q5_0_t4>; + +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q5_1, 1, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q5_1, 1, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q5_1, 1, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q5_1, 1, dequantize_q5_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q8_0, 1, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q8_0, 1, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q8_0, 1, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q8_0, 1, dequantize_q8_0_t4>; + #define N_MV_T_T 4 template diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4d9df1a645bbb..5aef6c9d131d8 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3572,6 +3572,15 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4)); test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4)); + for (int i = 1; i < 64; ++i) { + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 
1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + } + #if 1 for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { From 677ee9ff5926335693ca74775d87063da09a40f9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 2 Dec 2024 11:08:16 +0200 Subject: [PATCH 2/4] metal : add rest of types ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 100 +++++++-- ggml/src/ggml-metal/ggml-metal.metal | 318 ++++++++++++++++++++++----- tests/test-backend-ops.cpp | 16 +- 3 files changed, 360 insertions(+), 74 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index d7ee7320cfdce..0aec92a26b5e7 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -199,6 +199,22 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, @@ -747,6 +763,22 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); @@ -1978,17 +2010,28 @@ static void ggml_metal_encode_node( // to the matrix-vector kernel int ne11_mm_min = 4; - if ((src0t == GGML_TYPE_F16 || // TODO: helper function - src0t == GGML_TYPE_Q4_0 || - src0t == GGML_TYPE_Q4_1 || - src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || - src0t == GGML_TYPE_Q8_0 - ) && - src1t == GGML_TYPE_F32 && - (ne00%256 == 0) && // TODO: this can be relaxed to 128 for nxpsg == 8 - (ne11 >= 2 && ne11 <= 8)) { - + if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) && + ( + ( + ( + src0t == GGML_TYPE_F16 || // TODO: helper function + src0t == GGML_TYPE_Q4_0 || + src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || + src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_IQ4_NL || + false) && (ne11 >= 2 && ne11 <= 8) + ) || + ( + ( + src0t == GGML_TYPE_Q4_K || + src0t == GGML_TYPE_Q5_K || + src0t == GGML_TYPE_Q6_K || + false) && (ne11 >= 4 && ne11 <= 12) + ) + ) + ) { // TODO: determine the optimal parameters based on grid utilization const int nsg = 2; // TODO: or 4? const int nxpsg = ne11 < 3 ? 
16 : 8; @@ -2010,9 +2053,6 @@ static void ggml_metal_encode_node( r1ptg = 5; break; }; - assert(nxpsg >= 8); - assert(nxpsg%8 == 0); - id pipeline = nil; switch (src0->type) { @@ -2064,6 +2104,38 @@ static void ggml_metal_encode_node( case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break; default: GGML_ABORT("not implemented"); } break; + case GGML_TYPE_Q4_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q5_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q6_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_IQ4_NL: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; default: GGML_ABORT("not implemented"); } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 50f33b6c6e8a9..3d15897f4ba66 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -128,7 +128,7 @@ void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & r } template -void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { +void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 3); const float d = xb->d; const float md = -16.h * xb->d; @@ -161,14 +161,36 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg template void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) { - // TODO: implement - float4x4 tmp; - dequantize_q5_0(xb, il/4, tmp); - reg = tmp[il%4]; + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = (il/4) ? 
0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il/4) ? 4 : 0; + + const int gh_mv = (il/4) ? 12 : 0; + const int gh_bk = (il/4) ? 0 : 4; + + for (int ii = 0; ii < 2; ii++) { + int i = 2*(il%4) + ii; + + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[2*ii + 0] = d * x0 + md; + reg[2*ii + 1] = d * x1 + md; + } } template -void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { +void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 4); const float d = xb->d; const float m = xb->m; @@ -201,10 +223,32 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg template void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) { - // TODO: implement - float4x4 tmp; - dequantize_q5_1(xb, il/4, tmp); - reg = tmp[il%4]; + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = (il/4) ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il/4) ? 4 : 0; + + const int gh_mv = (il/4) ? 12 : 0; + const int gh_bk = (il/4) ? 0 : 4; + + for (int ii = 0; ii < 2; ii++) { + int i = 2*(il%4) + ii; + + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[2*ii + 0] = d * x0 + m; + reg[2*ii + 1] = d * x1 + m; + } } template @@ -285,7 +329,7 @@ static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q } template -void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { +void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) { device const uchar * q = xb->qs; short is = (il/4) * 2; @@ -297,7 +341,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg const float dl = d * sc[0]; const float ml = min * sc[1]; - const ushort mask = il<2 ? 0x0F : 0xF0; + const ushort mask = il < 2 ? 
0x0F : 0xF0; for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * (q[i] & mask) - ml; } @@ -530,6 +574,19 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 } } +template +void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f; + reg[0] = d * kvalues_iq4nl_f[q8[0]]; + reg[1] = d * kvalues_iq4nl_f[q8[1]]; + reg[2] = d * kvalues_iq4nl_f[q8[2]]; + reg[3] = d * kvalues_iq4nl_f[q8[3]]; +} + template void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 @@ -1813,8 +1870,8 @@ kernel void kernel_mul_mv_q8_0_f32( kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template -void kernel_mul_mv_ext_q_f32_impl( +template +void kernel_mul_mv_ext_q4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, device const char * src1, @@ -1840,7 +1897,7 @@ void kernel_mul_mv_ext_q_f32_impl( const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + (tx/8)*bp32 : (device const q_t *) src0; + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; device const float4 * y4[r1ptg]; @@ -1850,23 +1907,34 @@ void kernel_mul_mv_ext_q_f32_impl( float sumf[r1ptg] = { [ 0 ... 
r1ptg - 1 ] = 0.0f }; - for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) { + short cch = tx%chpb; + + for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) { + float4 lx[chpt]; + #pragma unroll(chpt) for (short ch = 0; ch < chpt; ++ch) { - float4 lx; + deq_t4(xq, cch, lx[ch]); - deq_t4(xq + (4*ch*nxpsg)/(32/bp32), tx%8, lx); + cch += nxpsg; + if (cch >= chpb) { + xq += cch/chpb; + cch %= chpb; + } + } +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { #pragma unroll(r1ptg) for (short ir1 = 0; ir1 < r1ptg; ++ir1) { - sumf[ir1] += dot(lx, y4[ir1][ch*nxpsg]); + sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]); + } } - xq += ((4*chpt)*nxpsg)/(32/bp32); - +#pragma unroll(r1ptg) for (short ir1 = 0; ir1 < r1ptg; ++ir1) { - y4[ir1] += ((4*chpt)*nxpsg)/4; + y4[ir1] += chpt*nxpsg; } } @@ -1901,8 +1969,8 @@ void kernel_mul_mv_ext_q_f32_impl( } } -template -kernel void kernel_mul_mv_ext_q_f32_disp( +template +void kernel_mul_mv_ext_q4x4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, device const char * src1, @@ -1910,44 +1978,186 @@ kernel void kernel_mul_mv_ext_q_f32_disp( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - switch (args.nxpsg) { - case 8: kernel_mul_mv_ext_q_f32_impl<8, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q_f32_impl<16, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q_f32_impl<32, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + const short chpt = 1; + + //const short nxpsg = (32); + const short nypsg = (32/nxpsg); + + const short tx = tiisg%nxpsg; + const short ty = tiisg/nxpsg; + + const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i11 = tgpig.y*r1ptg; + const int i1m = tgpig.z; + + const int i12 = i1m%args.ne12; + const int i13 = i1m/args.ne12; + + const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; + + device const float4x4 * y4x4[r1ptg]; + + for (int ir1 = 0; ir1 < r1ptg; ++ir1) { + y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1; } -} -typedef decltype(kernel_mul_mv_ext_q_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q_f32_t; + float sumf[r1ptg] = { [ 0 ... 
r1ptg - 1 ] = 0.0f }; + + short cch = tx%chpb; + + for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) { + float4x4 lx[chpt]; + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { + deq_t4x4(xq, cch, lx[ch]); + + cch += nxpsg; + if (cch >= chpb) { + xq += cch/chpb; + cch %= chpb; + } + } -template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, half4, 8, dequantize_f16_t4>; -template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, half4, 8, dequantize_f16_t4>; -template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, half4, 8, dequantize_f16_t4>; -template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, half4, 8, dequantize_f16_t4>; +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + sumf[ir1] += + dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) + + dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) + + dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) + + dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]); + + } + } + +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + y4x4[ir1] += chpt*nxpsg; + } + } + + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + if (nxpsg >= 32) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); + } + if (nxpsg >= 16) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); + } + if (nxpsg >= 8) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); + } + if (nxpsg >= 4) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); + } + if (nxpsg >= 2) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); + } -template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q4_0, 1, dequantize_q4_0_t4>; -template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q4_0, 1, dequantize_q4_0_t4>; -template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q4_0, 1, dequantize_q4_0_t4>; -template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q4_0, 1, dequantize_q4_0_t4>; + //sumf[ir1] = simd_sum(sumf[ir1]); + } -template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q4_1, 1, dequantize_q4_1_t4>; -template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q4_1, 1, dequantize_q4_1_t4>; -template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q4_1, 1, dequantize_q4_1_t4>; -template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q4_1, 1, dequantize_q4_1_t4>; + if (tx == 0) { + for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { + device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; -template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q5_0, 1, dequantize_q5_0_t4>; -template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q5_0, 1, dequantize_q5_0_t4>; -template 
[[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q5_0, 1, dequantize_q5_0_t4>; -template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q5_0, 1, dequantize_q5_0_t4>; + if (i01 < args.ne01) { + dst_f32[i01] = sumf[ir1]; + } + } + } +} -template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q5_1, 1, dequantize_q5_1_t4>; -template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q5_1, 1, dequantize_q5_1_t4>; -template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q5_1, 1, dequantize_q5_1_t4>; -template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q5_1, 1, dequantize_q5_1_t4>; +template +kernel void kernel_mul_mv_ext_q4_f32_disp( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + } +} -template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q8_0, 1, dequantize_q8_0_t4>; -template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q8_0, 1, dequantize_q8_0_t4>; -template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q8_0, 1, dequantize_q8_0_t4>; -template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q8_0, 1, dequantize_q8_0_t4>; +template +kernel void kernel_mul_mv_ext_q4x4_f32_disp( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + } +} + +typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t; +typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t; + +template 
[[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>; + +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, 
dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>; + +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>; + +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>; + +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; #define N_MV_T_T 4 diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5aef6c9d131d8..7bd1cb5e0ffa0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3573,12 +3573,16 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4)); for (int i = 1; i < 64; ++i) { - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new 
test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); } #if 1 From 5590160cd6f75b8b248b1bd6d08fc02d4d06892b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 2 Dec 2024 20:10:20 +0200 Subject: [PATCH 3/4] metal : final adjustments ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 2 +- tests/test-backend-ops.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 0aec92a26b5e7..b7434da3e39fe 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2028,7 +2028,7 @@ static void ggml_metal_encode_node( src0t == GGML_TYPE_Q4_K || src0t == GGML_TYPE_Q5_K || src0t == GGML_TYPE_Q6_K || - false) && (ne11 >= 4 && ne11 <= 12) + false) && (ne11 >= 4 && ne11 <= 8) ) ) ) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7bd1cb5e0ffa0..c786da4c35d7c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3572,7 +3572,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4)); test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4)); - for (int i = 1; i < 64; ++i) { + for (int i = 1; i < 9; ++i) { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); From 434fc452c3bea02c53ea2074543ca66f6f3749f7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Dec 2024 11:39:33 +0200 Subject: [PATCH 4/4] metal : add comments ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 22 ++++++++++++++++------ ggml/src/ggml-metal/ggml-metal.metal | 10 ++++++++-- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index b7434da3e39fe..c247b50c9e690 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2008,8 +2008,10 @@ static void ggml_metal_encode_node( // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel - int ne11_mm_min = 4; + const int ne11_mm_min = 4; + // first try to use 
small-batch mat-mv kernels + // these should be efficient for BS [2, ~8] if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) && ( ( @@ -2033,12 +2035,20 @@ static void ggml_metal_encode_node( ) ) { // TODO: determine the optimal parameters based on grid utilization - const int nsg = 2; // TODO: or 4? - const int nxpsg = ne11 < 3 ? 16 : 8; - const int nypsg = 32/nxpsg; - const int r0ptg = nypsg*nsg; - int r1ptg = 4; + // I still don't know why we should not always use the maximum available threads: + // + // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32 + // + // my current hypothesis is that the work grid is not evenly divisible for different nsg + // values and there can be some tail effects when nsg is high. need to confirm this + // + const int nsg = 2; // num simdgroups per threadgroup + const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup + const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time) + const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup + int r1ptg = 4; // num src1 rows per threadgroup + // note: not sure how optimal are those across all different hardware. there might be someting cleverer switch (ne11) { case 2: r1ptg = 2; break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3d15897f4ba66..7567f326200fc 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1870,6 +1870,8 @@ kernel void kernel_mul_mv_q8_0_f32( kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +// mat-vec kernel processing in chunks of float4 +// chpb - chunks per quantization block template void kernel_mul_mv_ext_q4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, @@ -1879,7 +1881,7 @@ void kernel_mul_mv_ext_q4_f32_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short chpt = 4; + const short chpt = 4; // chunks per thread //const short nxpsg = (32); const short nypsg = (32/nxpsg); @@ -1907,7 +1909,7 @@ void kernel_mul_mv_ext_q4_f32_impl( float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; - short cch = tx%chpb; + short cch = tx%chpb; // current chunk index for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) { float4 lx[chpt]; @@ -1938,6 +1940,7 @@ void kernel_mul_mv_ext_q4_f32_impl( } } + // reduce only the threads in each row for (short ir1 = 0; ir1 < r1ptg; ++ir1) { if (nxpsg >= 32) { sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); @@ -1969,6 +1972,7 @@ void kernel_mul_mv_ext_q4_f32_impl( } } +// mat-vec kernel processing in chunks of float4x4 template void kernel_mul_mv_ext_q4x4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, @@ -2072,6 +2076,8 @@ void kernel_mul_mv_ext_q4x4_f32_impl( } } +// dispatchers needed for compile-time nxpsg +// epb - elements per quantization block template kernel void kernel_mul_mv_ext_q4_f32_disp( constant ggml_metal_kargs_mul_mv_ext & args,
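
For reference, a minimal sketch (plain C, not part of the patch) of the host-side threadgroup sizing described in the comments above, evaluated for an assumed example shape ne01 = 4096, ne11 = 4, ne12 = ne13 = 1; the formulas for nsg, nxpsg, nypsg, r0ptg, r1ptg and the dispatch grid are taken directly from ggml_metal_encode_node, the concrete numbers are only illustrative:

#include <stdio.h>

int main(void) {
    // assumed example shape: 4096 src0 rows, 4 src1 rows (batch size 4), no broadcasting
    const int ne01 = 4096, ne11 = 4, ne12 = 1, ne13 = 1;

    const int nsg   = 2;                  // simdgroups per threadgroup
    const int nxpsg = ne11 < 3 ? 16 : 8;  // threads along a row per simdgroup  -> 8
    const int nypsg = 32/nxpsg;           // src0 rows per simdgroup            -> 4
    const int r0ptg = nypsg*nsg;          // src0 rows per threadgroup          -> 8
    const int r1ptg = 4;                  // src1 rows per threadgroup (ne11 == 4 case)

    // dispatch grid used by the patch:
    //   threadgroups        = ((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13)
    //   threadsPerThreadgroup = (32, nsg, 1)
    printf("threadgroups: %d x %d x %d, threads per threadgroup: 32 x %d x 1\n",
           (ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13, nsg);
    // -> threadgroups: 512 x 1 x 1, threads per threadgroup: 32 x 2 x 1
    // i.e. each 64-thread threadgroup produces an 8 x 4 tile of the output
    return 0;
}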