@@ -72,6 +72,36 @@ layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
 layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
 
+#if defined(A_TYPE_PACKED16)
+#define BINDING_IDX_K 0
+#define BINDING_IDX_V 1
+layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
+#endif
+
+#if defined(DATA_A_Q4_0)
+#define BLOCK_BYTE_SIZE 18
+
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+    uint shift = (iqs & 0x10) >> 2;
+    vui_lo >>= shift;
+    vui_hi >>= shift;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+#define BLOCK_BYTE_SIZE 34
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
 #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
 // Store the output when doing grouped query attention.
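Context for the new helpers: Q4_0 packs 32 weights into an 18-byte block (an fp16 scale `d` followed by 16 bytes of nibbles with an implicit -8 bias), and Q8_0 packs 32 signed bytes after its fp16 scale (34 bytes), which is where the two `BLOCK_BYTE_SIZE` values come from. Below is a host-side C sketch of the same dequantization math, for reference only; the struct layouts mirror ggml's `block_q4_0`/`block_q8_0`, and `ggml_fp16_to_fp32` is assumed available as the fp16 decoder.

```c
#include <stdint.h>

/* Mirrors ggml's block_q4_0: 32 weights in 18 bytes (= BLOCK_BYTE_SIZE). */
typedef struct {
    uint16_t d;      /* fp16 scale bits */
    uint8_t  qs[16]; /* low nibbles hold weights 0..15, high nibbles 16..31 */
} block_q4_0;

/* Mirrors ggml's block_q8_0: 32 weights in 34 bytes (= BLOCK_BYTE_SIZE). */
typedef struct {
    uint16_t d;      /* fp16 scale bits */
    int8_t   qs[32]; /* signed 8-bit quants */
} block_q8_0;

extern float ggml_fp16_to_fp32(uint16_t h); /* assumed fp16 decoder */

/* Four consecutive weights starting at iqs (a multiple of 4), as in the shader. */
static void dequantize4_q4_0(const block_q4_0 *b, unsigned iqs, float out[4]) {
    const float d = ggml_fp16_to_fp32(b->d);
    for (unsigned i = 0; i < 4; ++i) {
        unsigned idx   = iqs + i;
        unsigned shift = (idx & 0x10) >> 2; /* 0 for low half, 4 for high half */
        uint8_t  byte  = b->qs[idx & 0xF];  /* same byte serves both halves    */
        out[i] = d * (float)((int)((byte >> shift) & 0xF) - 8);
    }
}

static void dequantize4_q8_0(const block_q8_0 *b, unsigned iqs, float out[4]) {
    const float d = ggml_fp16_to_fp32(b->d);
    for (unsigned i = 0; i < 4; ++i) {
        out[i] = d * (float)b->qs[iqs + i];
    }
}
```

The shader reads the same nibbles two bytes at a time through the `A_TYPE_PACKED16` (uint16) view, which is why its `qs` index is `(iqs & 0xF) / 2`, but the value math is identical.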
@@ -208,6 +238,14 @@ void main() {
         }
     }
 
+#if BLOCK_SIZE > 1
+    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
+    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
+#else
+    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
+    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
+#endif
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
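These offsets were previously computed inside the j loop; the two later hunks delete the old definitions. For a quantized cache the byte strides `nb12`/`nb13` are converted to block indices by dividing by `BLOCK_BYTE_SIZE`, while in the f16 path the divisor 2 is simply the size of a `float16_t` element. A toy example with made-up strides (not values from this PR):

```c
#include <stdio.h>

int main(void) {
    /* Hypothetical head indices and byte strides; the shader gets the real
       ones from p.nb12/p.nb13 (K) and p.nb22/p.nb23 (V). */
    unsigned ik2 = 3, nb12 = 9216, ik3 = 0, nb13 = 0;
    unsigned byte_off = ik2 * nb12 + ik3 * nb13;         /* 27648 bytes */
    printf("f16  K cache: element %u\n", byte_off / 2);  /* / sizeof(f16)     -> 13824 */
    printf("Q4_0 K cache: block   %u\n", byte_off / 18); /* / BLOCK_BYTE_SIZE ->  1536 */
    return 0;
}
```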
@@ -218,11 +256,17 @@ void main() {
             }
         }
 
-        uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                uint ib = coord / BLOCK_SIZE;
+                uint iqs = (coord % BLOCK_SIZE);
+                vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
                 vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+#endif
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
                 }
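The quantized K load works in two steps: `coord` is the flat element index of the first of the four values this thread wants (with `k_stride` counted in blocks per row here, hence the multiply by `BLOCK_SIZE`), and the divide/modulo by `BLOCK_SIZE` splits it into a block index `ib` and an in-block position `iqs`. Because the per-thread term is `4 * (d * D_split + d_tid)` and `BLOCK_SIZE` is a multiple of 4 for these formats, the four values never straddle a block boundary, so a single `dequantize4` call suffices. A small C illustration with hypothetical loop values:

```c
#include <stdio.h>

int main(void) {
    const unsigned BLOCK_SIZE = 32;  /* Q4_0 and Q8_0 both hold 32 weights */
    unsigned row = 5, k_stride = 4;  /* hypothetical; stride in blocks/row */
    unsigned elem = 4 * 6;           /* 4 * (d * D_split + d_tid), say 24  */
    unsigned coord = row * k_stride * BLOCK_SIZE + elem;
    unsigned ib  = coord / BLOCK_SIZE;  /* which block:       664 / 32 = 20 */
    unsigned iqs = coord % BLOCK_SIZE;  /* position in block: 664 % 32 = 24 */
    printf("coord=%u -> ib=%u iqs=%u\n", coord, ib, iqs);
    return 0;
}
```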
@@ -297,11 +341,16 @@ void main() {
             }
         }
 
-        uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
-
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                uint ib = coord / BLOCK_SIZE;
+                uint iqs = (coord % BLOCK_SIZE);
+                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
                 vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+#endif
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     Of[r][d] += Pf[r][c] * Vf;
                 }
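The V accumulation mirrors the K pass exactly: the same coord split, but with `v_stride`, the hoisted `v_offset`, and `BINDING_IDX_V` selecting the second element of the `kv_packed` binding array, while the old in-loop `v_offset` computation is deleted in favor of the definition hoisted before the j loop.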