Skip to content

Commit 4a3156d

Browse files
ikawrakow and Kawrakow authored
CUDA: faster dequantize kernels for Q4_0 and Q4_1 (#4938)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent a836c8f commit 4a3156d

File tree

1 file changed

+73
-4
lines changed

1 file changed

+73
-4
lines changed

ggml-cuda.cu

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,61 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
11051105
#endif // GGML_CUDA_F16
11061106
}
11071107

1108+
// Expands Q4_0 quantized data into dst_t values.
//
// Work split: one CUDA block covers up to 8 consecutive block_q4_0 entries
// (256 output values) and is written for a launch with 32 threads per block.
// tid%8 picks the quant block inside that group of 8, tid/8 picks which run
// of 4 packed bytes of that quant block's qs array the thread expands.
//
//   vx   - input, read as an array of block_q4_0
//   yy   - output buffer, one dst_t per dequantized value
//   nb32 - number of block_q4_0 entries that may be read from vx
template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    const int group   = blockIdx.x;
    const int lane    = threadIdx.x;   // the index math below assumes 32 threads
    const int quarter = lane/8;        // 0..3: which 4-byte slice of qs
    const int sub     = lane%8;        // 0..7: which quant block of this group
    const int iblock  = 8*group + sub;

    // Only threads that map onto an existing quant block do any work.
    if (iblock < nb32) {
        const block_q4_0 * blk = (const block_q4_0 *)vx + iblock;

        const float d  = __half2float(blk->d);
        const float dm = -8*d;         // constant term added to every scaled nibble

        const uint8_t * packed = blk->qs + 4*quarter;
        dst_t * out = yy + 256*group + 32*sub + 4*quarter;

        for (int l = 0; l < 4; ++l) {
            const uint8_t byte = packed[l];
            // low nibble -> position l, high nibble -> position l+16
            out[l+ 0] = d * (byte & 0xF) + dm;
            out[l+16] = d * (byte >> 4) + dm;
        }
    }
}
1135+
1136+
// Expands Q4_1 quantized data into dst_t values.
//
// Work split: one CUDA block covers up to 8 consecutive block_q4_1 entries
// (256 output values) and is written for a launch with 32 threads per block.
// tid%8 picks the quant block inside that group of 8, tid/8 picks which run
// of 4 packed bytes of that quant block's qs array the thread expands.
//
//   vx   - input, read as an array of block_q4_1
//   yy   - output buffer, one dst_t per dequantized value
//   nb32 - number of block_q4_1 entries that may be read from vx
template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    const int group   = blockIdx.x;
    const int lane    = threadIdx.x;   // the index math below assumes 32 threads
    const int quarter = lane/8;        // 0..3: which 4-byte slice of qs
    const int sub     = lane%8;        // 0..7: which quant block of this group
    const int iblock  = 8*group + sub;

    // Only threads that map onto an existing quant block do any work.
    if (iblock < nb32) {
        const block_q4_1 * blk = (const block_q4_1 *)vx + iblock;

        // dm is converted to two floats: .x multiplies each nibble, .y is added.
        const float2 d = __half22float2(blk->dm);

        const uint8_t * packed = blk->qs + 4*quarter;
        dst_t * out = yy + 256*group + 32*sub + 4*quarter;

        for (int l = 0; l < 4; ++l) {
            const uint8_t byte = packed[l];
            // low nibble -> position l, high nibble -> position l+16
            out[l+ 0] = d.x * (byte & 0xF) + d.y;
            out[l+16] = d.x * (byte >> 4) + d.y;
        }
    }
}
1162+
11081163
//================================== k-quants
11091164

11101165
template<typename dst_t>
@@ -6253,6 +6308,20 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
62536308
#endif
62546309
}
62556310

6311+
// Launches dequantize_block_q4_0 on `stream`.
//
//   vx     - quantized input, forwarded unchanged to the kernel
//   y      - destination buffer, forwarded unchanged to the kernel
//   k      - number of values to dequantize; only whole groups of 32 values
//            (k / 32 quant blocks) are processed, any remainder is ignored
//   stream - CUDA stream the kernel is enqueued on
//
// Does nothing when k holds no complete group of 32 values.
template<typename dst_t>
static void dequantize_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int nb32 = k / 32;
    if (nb32 <= 0) {
        // No complete quant block: the kernel would not write anything, and
        // for k == 0 (or negative k) the grid size would be an invalid launch
        // configuration, so skip the launch entirely.
        return;
    }

    // Each CUDA block (32 threads) handles 8 quant blocks. Deriving the grid
    // from nb32 instead of (k + 255) / 256 avoids signed overflow for k close
    // to INT_MAX and never launches a trailing block that has no work.
    const int nb = (nb32 + 7) / 8;
    dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
6317+
6318+
// Launches dequantize_block_q4_1 on `stream`.
//
//   vx     - quantized input, forwarded unchanged to the kernel
//   y      - destination buffer, forwarded unchanged to the kernel
//   k      - number of values to dequantize; only whole groups of 32 values
//            (k / 32 quant blocks) are processed, any remainder is ignored
//   stream - CUDA stream the kernel is enqueued on
//
// Does nothing when k holds no complete group of 32 values.
template<typename dst_t>
static void dequantize_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int nb32 = k / 32;
    if (nb32 <= 0) {
        // No complete quant block: the kernel would not write anything, and
        // for k == 0 (or negative k) the grid size would be an invalid launch
        // configuration, so skip the launch entirely.
        return;
    }

    // Each CUDA block (32 threads) handles 8 quant blocks. Deriving the grid
    // from nb32 instead of (k + 255) / 256 avoids signed overflow for k close
    // to INT_MAX and never launches a trailing block that has no work.
    const int nb = (nb32 + 7) / 8;
    dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
6324+
62566325
template<typename dst_t>
62576326
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
62586327
const int nb = k / QK_K;
@@ -6301,9 +6370,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
63016370
int id;
63026371
switch (type) {
63036372
case GGML_TYPE_Q4_0:
6304-
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
6373+
return dequantize_q4_0_cuda;
63056374
case GGML_TYPE_Q4_1:
6306-
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
6375+
return dequantize_q4_1_cuda;
63076376
case GGML_TYPE_Q5_0:
63086377
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
63096378
case GGML_TYPE_Q5_1:
@@ -6338,9 +6407,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
63386407
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
63396408
switch (type) {
63406409
case GGML_TYPE_Q4_0:
6341-
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
6410+
return dequantize_q4_0_cuda;
63426411
case GGML_TYPE_Q4_1:
6343-
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
6412+
return dequantize_q4_1_cuda;
63446413
case GGML_TYPE_Q5_0:
63456414
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
63466415
case GGML_TYPE_Q5_1:

0 commit comments

Comments
 (0)