Skip to content

Commit a972fae

Browse files
authored
CUDA: Add mul_mat_id support for the mmf kernel (#15767)
* CUDA: Add mul_mat_id support the mmf Add support for mul_mat_id for bs < 16 * Review: use warp_size, fix should_use_mmf condition * Launch one block per expert, stride along n_expert_used * templatize mul_mat_id * Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids * Reduce compile times by dividing mmf into f16, bf16 and f32 variants * Divide mmf by ncols_dst * Add missing files * Fix MUSA/HIP builds
1 parent 550cf72 commit a972fae

23 files changed

+603
-350
lines changed

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
4444
list(APPEND GGML_SOURCES_CUDA ${SRCS})
4545
file(GLOB SRCS "template-instances/mmq*.cu")
4646
list(APPEND GGML_SOURCES_CUDA ${SRCS})
47+
file(GLOB SRCS "template-instances/mmf*.cu")
48+
list(APPEND GGML_SOURCES_CUDA ${SRCS})
4749

4850
if (GGML_CUDA_FA_ALL_QUANTS)
4951
file(GLOB SRCS "template-instances/fattn-vec*.cu")

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2109,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
21092109
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
21102110
return;
21112111
}
2112+
2113+
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
2114+
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
2115+
return;
2116+
}
21122117
}
21132118

21142119
cudaStream_t stream = ctx.stream();

ggml/src/ggml-cuda/mma.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#pragma once
12
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
23
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
34
// The documentation for the PTX instructions can be found under:

ggml/src/ggml-cuda/mmf.cu

Lines changed: 34 additions & 348 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 469 additions & 1 deletion
Large diffs are not rendered by default.

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@
3434
DECL_MMQ_CASE({type});
3535
"""
3636

37+
SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
38+
39+
#include "../mmf.cuh"
40+
41+
DECL_MMF_CASE({type});
42+
"""
43+
3744

3845
def get_short_name(long_quant_name):
3946
return long_quant_name.replace("GGML_TYPE_", "").lower()
@@ -76,3 +83,7 @@ def get_head_sizes(type_k, type_v):
7683
for type in TYPES_MMQ:
7784
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
7885
f.write(SOURCE_MMQ.format(type=type))
86+
87+
for type in range(1, 17):
88+
with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
89+
f.write(SOURCE_MMF.format(type=type))
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmf.cuh"
4+
5+
DECL_MMF_CASE(1);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmf.cuh"
4+
5+
DECL_MMF_CASE(10);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmf.cuh"
4+
5+
DECL_MMF_CASE(11);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmf.cuh"
4+
5+
DECL_MMF_CASE(12);

0 commit comments

Comments
 (0)