@@ -1105,6 +1105,61 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
1105
1105
#endif // GGML_CUDA_F16
1106
1106
}
1107
1107
1108
template <typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    // Dequantize q4_0 blocks into yy.
    // Launch contract: blockDim.x == 32 (one warp); each CUDA block expands up to
    // 8 q4_0 quant blocks (8 * 32 = 256 output values). nb32 is the total number
    // of q4_0 blocks, used to guard the (possibly partial) last CUDA block.
    const int block_id = blockIdx.x;
    const int lane     = threadIdx.x;

    const int quarter   = lane / 8;   // which 4-byte slice of the 16-byte qs array (0..3)
    const int sub_block = lane % 8;   // which quant block within this CUDA block (0..7)

    const int qblock_idx = 8*block_id + sub_block;
    if (qblock_idx >= nb32) {
        return;
    }

    dst_t * dst = yy + 256*block_id + 32*sub_block + 4*quarter;

    const block_q4_0 * src = (const block_q4_0 *) vx + qblock_idx;
    const float scale  = __half2float(src->d);
    const float offset = -8*scale;    // q4_0: value = d * (nibble - 8), folded into an additive term

    const uint8_t * quants = src->qs + 4*quarter;

    // Each byte packs two 4-bit values: low nibble -> dst[l], high nibble -> dst[l+16].
    for (int l = 0; l < 4; ++l) {
        dst[l +  0] = scale * (quants[l] & 0xF) + offset;
        dst[l + 16] = scale * (quants[l] >>  4) + offset;
    }
}
1135
+
1136
template <typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    // Dequantize q4_1 blocks into yy.
    // Launch contract: blockDim.x == 32 (one warp); each CUDA block expands up to
    // 8 q4_1 quant blocks (8 * 32 = 256 output values). nb32 is the total number
    // of q4_1 blocks, used to guard the (possibly partial) last CUDA block.
    const int block_id = blockIdx.x;
    const int lane     = threadIdx.x;

    const int quarter   = lane / 8;   // which 4-byte slice of the 16-byte qs array (0..3)
    const int sub_block = lane % 8;   // which quant block within this CUDA block (0..7)

    const int qblock_idx = 8*block_id + sub_block;
    if (qblock_idx >= nb32) {
        return;
    }

    dst_t * dst = yy + 256*block_id + 32*sub_block + 4*quarter;

    const block_q4_1 * src = (const block_q4_1 *) vx + qblock_idx;
    // q4_1 packs scale and min as a half2: .x = scale, .y = min (value = d.x*nibble + d.y).
    const float2 dm = __half22float2(src->dm);

    const uint8_t * quants = src->qs + 4*quarter;

    // Each byte packs two 4-bit values: low nibble -> dst[l], high nibble -> dst[l+16].
    for (int l = 0; l < 4; ++l) {
        dst[l +  0] = dm.x * (quants[l] & 0xF) + dm.y;
        dst[l + 16] = dm.x * (quants[l] >>  4) + dm.y;
    }
}
1162
+
1108
1163
// ================================== k-quants
1109
1164
1110
1165
template <typename dst_t >
@@ -6253,6 +6308,20 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
6253
6308
#endif
6254
6309
}
6255
6310
6311
template <typename dst_t>
static void dequantize_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    // Dequantize k values from q4_0 data in vx into y on the given stream.
    // 32 values per q4_0 block; each 32-thread CUDA block covers 8 quant blocks,
    // so the grid is a ceil-divide of k by 256.
    const int num_qblocks = k / 32;
    const int grid_size   = (k + 255) / 256;
    dequantize_block_q4_0<<<grid_size, 32, 0, stream>>>(vx, y, num_qblocks);
}
6317
+
6318
template <typename dst_t>
static void dequantize_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    // Dequantize k values from q4_1 data in vx into y on the given stream.
    // 32 values per q4_1 block; each 32-thread CUDA block covers 8 quant blocks,
    // so the grid is a ceil-divide of k by 256.
    const int num_qblocks = k / 32;
    const int grid_size   = (k + 255) / 256;
    dequantize_block_q4_1<<<grid_size, 32, 0, stream>>>(vx, y, num_qblocks);
}
6324
+
6256
6325
template <typename dst_t >
6257
6326
static void dequantize_row_q4_K_cuda (const void * vx, dst_t * y, const int k, cudaStream_t stream) {
6258
6327
const int nb = k / QK_K;
@@ -6301,9 +6370,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
6301
6370
int id;
6302
6371
switch (type) {
6303
6372
case GGML_TYPE_Q4_0:
6304
- return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0> ;
6373
+ return dequantize_q4_0_cuda ;
6305
6374
case GGML_TYPE_Q4_1:
6306
- return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1> ;
6375
+ return dequantize_q4_1_cuda ;
6307
6376
case GGML_TYPE_Q5_0:
6308
6377
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
6309
6378
case GGML_TYPE_Q5_1:
@@ -6338,9 +6407,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
6338
6407
static to_fp32_cuda_t ggml_get_to_fp32_cuda (ggml_type type) {
6339
6408
switch (type) {
6340
6409
case GGML_TYPE_Q4_0:
6341
- return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0> ;
6410
+ return dequantize_q4_0_cuda ;
6342
6411
case GGML_TYPE_Q4_1:
6343
- return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1> ;
6412
+ return dequantize_q4_1_cuda ;
6344
6413
case GGML_TYPE_Q5_0:
6345
6414
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
6346
6415
case GGML_TYPE_Q5_1:
0 commit comments