Skip to content

Commit 616bca3

Browse files
committed
vulkan: Use larger loads in scalar/coopmat1 matmul
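I think glslang translates an access like `x[i][1].z` into `OpAccessChain ... x, i, 1, 2` followed by `OpLoad float16_t ...` (a separate scalar load for each component), rather than loading all of `x[i]` in a single `OpLoad`. Change the code to explicitly load the whole vector/matrix into a local variable first, and then read the individual components from that local.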
1 parent d4d8dbe commit 616bca3

File tree

3 files changed

+57
-34
lines changed

3 files changed

+57
-34
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -315,21 +315,23 @@ void main() {
315315
#if LOAD_VEC_A == 8
316316
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
317317
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
318-
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
319-
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
320-
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
321-
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
322-
buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
323-
buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
324-
buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
325-
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
318+
A_TYPE32 aa = A_TYPE32(data_a[idx]);
319+
buf_a[buf_idx ] = FLOAT_TYPE(aa[0].x);
320+
buf_a[buf_idx + 1] = FLOAT_TYPE(aa[0].y);
321+
buf_a[buf_idx + 2] = FLOAT_TYPE(aa[0].z);
322+
buf_a[buf_idx + 3] = FLOAT_TYPE(aa[0].w);
323+
buf_a[buf_idx + 4] = FLOAT_TYPE(aa[1].x);
324+
buf_a[buf_idx + 5] = FLOAT_TYPE(aa[1].y);
325+
buf_a[buf_idx + 6] = FLOAT_TYPE(aa[1].z);
326+
buf_a[buf_idx + 7] = FLOAT_TYPE(aa[1].w);
326327
#elif LOAD_VEC_A == 4
327328
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
328329
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
329-
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
330-
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
331-
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
332-
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
330+
A_TYPE32 aa = A_TYPE32(data_a[idx]);
331+
buf_a[buf_idx ] = FLOAT_TYPE(aa.x);
332+
buf_a[buf_idx + 1] = FLOAT_TYPE(aa.y);
333+
buf_a[buf_idx + 2] = FLOAT_TYPE(aa.z);
334+
buf_a[buf_idx + 3] = FLOAT_TYPE(aa.w);
333335
#else
334336
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
335337
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
@@ -808,14 +810,19 @@ void main() {
808810
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
809811
#endif
810812
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
811-
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
812-
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
813-
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
814-
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
815-
buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
816-
buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
817-
buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
818-
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
813+
#if defined(DATA_B_BF16)
814+
B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
815+
#else
816+
B_TYPE32 bb = B_TYPE32(data_b[idx]);
817+
#endif
818+
buf_b[buf_idx + 0] = FLOAT_TYPE(bb[0].x);
819+
buf_b[buf_idx + 1] = FLOAT_TYPE(bb[0].y);
820+
buf_b[buf_idx + 2] = FLOAT_TYPE(bb[0].z);
821+
buf_b[buf_idx + 3] = FLOAT_TYPE(bb[0].w);
822+
buf_b[buf_idx + 4] = FLOAT_TYPE(bb[1].x);
823+
buf_b[buf_idx + 5] = FLOAT_TYPE(bb[1].y);
824+
buf_b[buf_idx + 6] = FLOAT_TYPE(bb[1].z);
825+
buf_b[buf_idx + 7] = FLOAT_TYPE(bb[1].w);
819826
#elif LOAD_VEC_B == 4
820827
#ifdef MUL_MAT_ID
821828
const u16vec2 row_idx = row_ids[loadc_b + l];
@@ -824,10 +831,15 @@ void main() {
824831
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
825832
#endif
826833
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
827-
buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
828-
buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
829-
buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
830-
buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
834+
#if defined(DATA_B_BF16)
835+
B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
836+
#else
837+
B_TYPE32 bb = B_TYPE32(data_b[idx]);
838+
#endif
839+
buf_b[buf_idx + 0] = FLOAT_TYPE(bb.x);
840+
buf_b[buf_idx + 1] = FLOAT_TYPE(bb.y);
841+
buf_b[buf_idx + 2] = FLOAT_TYPE(bb.z);
842+
buf_b[buf_idx + 3] = FLOAT_TYPE(bb.w);
831843
#elif !MUL_MAT_ID
832844
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
833845
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313

1414
#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
1515
#define A_TYPE float
16+
#define A_TYPE32 float
1617
#elif LOAD_VEC_A == 4
1718
#define A_TYPE vec4
19+
#define A_TYPE32 vec4
1820
#elif LOAD_VEC_A == 8
1921
#define A_TYPE mat2x4
22+
#define A_TYPE32 mat2x4
2023
#endif
2124
#endif
2225

@@ -26,10 +29,13 @@
2629

2730
#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
2831
#define A_TYPE float16_t
32+
#define A_TYPE32 float
2933
#elif LOAD_VEC_A == 4
3034
#define A_TYPE f16vec4
35+
#define A_TYPE32 vec4
3136
#elif LOAD_VEC_A == 8
3237
#define A_TYPE f16mat2x4
38+
#define A_TYPE32 mat2x4
3339
#endif
3440
#endif
3541

@@ -1424,6 +1430,11 @@ float bf16_to_fp32(uint32_t u)
14241430
return uintBitsToFloat(u << 16);
14251431
}
14261432

1433+
vec4 bf16_to_fp32(uvec4 u)
1434+
{
1435+
return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w));
1436+
}
1437+
14271438
float e8m0_to_fp32(uint8_t x) {
14281439
uint32_t bits;
14291440

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -364,11 +364,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
364364
};
365365

366366
// Shaders with f16 B_TYPE
367-
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
368-
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
367+
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
368+
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
369369

370-
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
371-
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
370+
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
371+
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
372372

373373
// bf16
374374
{
@@ -384,8 +384,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
384384
if (!(coopmat || coopmat2))
385385
#endif
386386
{
387-
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
388-
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
387+
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE32", "vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
388+
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
389389
}
390390
}
391391

@@ -408,13 +408,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
408408

409409
// don't generate f32 variants for coopmat2
410410
if (!coopmat2) {
411-
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
412-
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
411+
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
412+
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
413413
}
414414

415415
if (tname != "f16" && tname != "f32") {
416-
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
417-
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
416+
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
417+
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
418418
}
419419

420420
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)

0 commit comments

Comments
 (0)