Skip to content

Commit 2c9f833

Browse files
authored
mat vec double buffer (#12188)
1 parent 2513645 commit 2c9f833

File tree

3 files changed

+43
-42
lines changed

3 files changed

+43
-42
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,33 @@
55

66
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77

8-
shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
9-
shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
8+
shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];
9+
shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];
1010

1111
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
12+
uint csel = 0;
1213

1314
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
1415
const uint y_idx = i * QUANT_K + y_offset;
1516

1617
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
1718
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
19+
csel ^= 1;
1820

19-
barrier();
2021
if (!all_threads) { // when we don't have enough blocks to use all threads
2122
if (i < num_blocks_per_row) {
2223
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
23-
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
24-
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
24+
sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
25+
sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
2526
}
2627
barrier();
2728

2829
if (i >= num_blocks_per_row)
2930
continue;
3031
} else {
3132
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
32-
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
33-
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
33+
sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
34+
sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
3435
barrier();
3536
}
3637

@@ -57,22 +58,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
5758
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
5859
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
5960
[[unroll]] for (int l = 0; l < 2; ++l) {
60-
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ],
61-
fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
62-
fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ],
63-
fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
64-
fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ],
65-
fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
66-
fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ],
67-
fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
68-
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im],
69-
fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im],
70-
fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im],
71-
fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im],
72-
fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im],
73-
fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im],
74-
fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im],
75-
fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
61+
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ],
62+
fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],
63+
fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ],
64+
fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],
65+
fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ],
66+
fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],
67+
fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ],
68+
fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
69+
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im],
70+
fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im],
71+
fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im],
72+
fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im],
73+
fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im],
74+
fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im],
75+
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
76+
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
7677
}
7778
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
7879
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,21 @@
55

66
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77

8-
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8];
8+
shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8];
99

1010
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
11+
uint csel = 0;
1112

1213
void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
1314
const uint y_idx = i * QUANT_K + y_offset;
1415

1516
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
1617
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
18+
csel ^= 1;
1719

1820
if (!all_threads) { // when we don't have enough blocks to use all threads
19-
barrier();
2021
if (i < num_blocks_per_row)
21-
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
22+
sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
2223
barrier();
2324

2425
if (i >= num_blocks_per_row)
@@ -40,8 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
4041
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
4142

4243
if (all_threads) {
43-
barrier();
44-
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
44+
sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
4545
barrier();
4646
}
4747

@@ -59,14 +59,14 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
5959

6060
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
6161
[[unroll]] for (int l = 0; l < 2; ++l) {
62-
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
63-
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
64-
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
65-
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
66-
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
67-
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
68-
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
69-
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
62+
sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
63+
fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
64+
fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
65+
fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
66+
fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
67+
fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
68+
fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
69+
fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
7070
}
7171
temp[j][n] = fma(d, sum, temp[j][n]);
7272
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,21 @@
66

77
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
88

9-
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
9+
shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16];
1010

1111
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
12+
uint csel = 0;
1213

1314
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
1415
const uint y_idx = i * QUANT_K + y_offset;
1516

1617
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
1718
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
19+
csel ^= 1;
1820

1921
if (!all_threads) { // when we don't have enough blocks to use all threads
20-
barrier();
2122
if (i < num_blocks_per_row)
22-
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
23+
sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
2324
barrier();
2425

2526
if (i >= num_blocks_per_row)
@@ -51,8 +52,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
5152
const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
5253

5354
if (all_threads) {
54-
barrier();
55-
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
55+
sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
5656
barrier();
5757
}
5858

@@ -71,7 +71,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
7171
sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
7272
sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
7373
}
74-
temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
74+
temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]);
7575
}
7676
}
7777
}

0 commit comments

Comments
 (0)