@@ -49,13 +49,6 @@ typedef struct {
49
49
} block_q4_1;
50
50
static_assert (sizeof (block_q4_1) == sizeof(float ) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
51
51
52
- #define QK4_2 16
53
- typedef struct {
54
- half d; // delta
55
- uint8_t qs[QK4_2 / 2 ]; // nibbles / quants
56
- } block_q4_2;
57
- static_assert (sizeof (block_q4_2) == sizeof(ggml_fp16_t ) + QK4_2 / 2, "wrong q4_2 block size/padding");
58
-
59
52
#define QK5_0 32
60
53
typedef struct {
61
54
half d; // delta
@@ -81,147 +74,102 @@ typedef struct {
81
74
static_assert (sizeof (block_q8_0) == sizeof(float ) + QK8_0, "wrong q8_0 block size/padding");
82
75
83
76
// Expands one block_q4_0 into QK4_0 floats.
//
// vx : array of block_q4_0, passed as an untyped pointer
// y  : output buffer; this kernel writes y[b*QK4_0 .. b*QK4_0 + QK4_0 - 1]
//      where b = blockIdx.x
//
// The quantized block to decode is selected by blockIdx.x only; threadIdx is
// not consulted, so every thread of a CUDA block performs the same writes.
// Each byte of qs packs two 4-bit values: the low nibble is written to the
// first half of the output block and the high nibble to the second half.
// Both are re-centred by subtracting 8 and then multiplied by the block's d.
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
    const int n_pairs = QK4_0 / 2;

    const int ib = blockIdx.x;

    const block_q4_0 * blk = (const block_q4_0 *) vx + ib;
    float * out = y + ib*QK4_0;

    const float scale = blk->d;

    for (int k = 0; k < n_pairs; ++k) {
        const uint8_t packed = blk->qs[k];

        const int lo = (packed & 0xf) - 8;
        const int hi = (packed >>  4) - 8;

        out[k]           = lo*scale;
        out[k + n_pairs] = hi*scale;
    }
}
105
93
106
94
// Expands one block_q4_1 into QK4_1 floats.
//
// vx : array of block_q4_1, passed as an untyped pointer
// y  : output buffer; this kernel writes y[b*QK4_1 .. b*QK4_1 + QK4_1 - 1]
//      where b = blockIdx.x
//
// The quantized block to decode is selected by blockIdx.x only; threadIdx is
// not consulted. Each byte of qs packs two 4-bit values: the low nibble goes
// to the first half of the output block, the high nibble to the second half.
// Every value is reconstructed as q*d + m using the block's d and m fields.
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
    const int n_pairs = QK4_1 / 2;

    const int ib = blockIdx.x;

    const block_q4_1 * blk = (const block_q4_1 *) vx + ib;
    float * out = y + ib*QK4_1;

    const float scale  = blk->d;
    const float offset = blk->m;

    for (int k = 0; k < n_pairs; ++k) {
        const uint8_t packed = blk->qs[k];

        const int lo = packed & 0xf;
        const int hi = packed >>  4;

        out[k]           = lo*scale + offset;
        out[k + n_pairs] = hi*scale + offset;
    }
}
152
112
153
113
// Expands one block_q5_0 into QK5_0 floats.
//
// vx : array of block_q5_0, passed as an untyped pointer
// y  : output buffer; this kernel writes y[b*QK5_0 .. b*QK5_0 + QK5_0 - 1]
//      where b = blockIdx.x
//
// The quantized block to decode is selected by blockIdx.x only; threadIdx is
// not consulted. Each value is 5 bits wide: its low four bits come from a
// nibble of qs, and its fifth bit comes from the 32-bit mask copied out of
// the block's qh field. Low nibbles (with mask bit k) fill the first half of
// the output block, high nibbles (with mask bit k + 16) fill the second half.
// Values are re-centred by subtracting 16 and then multiplied by d.
static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
    const int n_pairs = QK5_0 / 2;

    const int ib = blockIdx.x;

    const block_q5_0 * blk = (const block_q5_0 *) vx + ib;
    float * out = y + ib*QK5_0;

    const float scale = blk->d;

    // Copy the high-bit mask byte-wise rather than reading qh through a
    // uint32_t pointer.
    uint32_t high_bits;
    memcpy(&high_bits, blk->qh, sizeof(high_bits));

    for (int k = 0; k < n_pairs; ++k) {
        // Move mask bit k (resp. k + 16) into bit position 4.
        const uint8_t fifth_lo = ((high_bits >> k)        & 1) << 4;
        const uint8_t fifth_hi = ((high_bits >> (k + 16)) & 1) << 4;

        const uint8_t packed = blk->qs[k];

        const int32_t lo = ((packed & 0xf) | fifth_lo) - 16;
        const int32_t hi = ((packed >>  4) | fifth_hi) - 16;

        out[k]           = lo*scale;
        out[k + n_pairs] = hi*scale;
    }
}
181
136
182
137
// Expands one block_q5_1 into QK5_1 floats.
//
// vx : array of block_q5_1, passed as an untyped pointer
// y  : output buffer; this kernel writes y[b*QK5_1 .. b*QK5_1 + QK5_1 - 1]
//      where b = blockIdx.x
//
// The quantized block to decode is selected by blockIdx.x only; threadIdx is
// not consulted. Each value is 5 bits wide: its low four bits come from a
// nibble of qs, and its fifth bit comes from the 32-bit mask copied out of
// the block's qh field. Low nibbles (with mask bit k) fill the first half of
// the output block, high nibbles (with mask bit k + 16) fill the second half.
// Every value is reconstructed as q*d + m using the block's d and m fields.
static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
    const int n_pairs = QK5_1 / 2;

    const int ib = blockIdx.x;

    const block_q5_1 * blk = (const block_q5_1 *) vx + ib;
    float * out = y + ib*QK5_1;

    const float scale  = blk->d;
    const float offset = blk->m;

    // Copy the high-bit mask byte-wise rather than reading qh through a
    // uint32_t pointer.
    uint32_t high_bits;
    memcpy(&high_bits, blk->qh, sizeof(high_bits));

    for (int k = 0; k < n_pairs; ++k) {
        // Move mask bit k (resp. k + 16) into bit position 4.
        const uint8_t fifth_lo = ((high_bits >> k)        & 1) << 4;
        const uint8_t fifth_hi = ((high_bits >> (k + 16)) & 1) << 4;

        const uint8_t packed = blk->qs[k];

        const int lo = (packed & 0xf) | fifth_lo;
        const int hi = (packed >>  4) | fifth_hi;

        out[k]           = lo*scale + offset;
        out[k + n_pairs] = hi*scale + offset;
    }
}
211
161
212
162
// Expands one block_q8_0 into QK8_0 floats.
//
// vx : array of block_q8_0, passed as an untyped pointer
// y  : output buffer; this kernel writes y[b*QK8_0 .. b*QK8_0 + QK8_0 - 1]
//      where b = blockIdx.x
//
// The quantized block to decode is selected by blockIdx.x only; threadIdx is
// not consulted. Each entry of qs is multiplied by the block's d and stored
// at the same position within the output block.
static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
    const int ib = blockIdx.x;

    const block_q8_0 * blk = (const block_q8_0 *) vx + ib;
    float * out = y + ib*QK8_0;

    const float scale = blk->d;

    for (int k = 0; k < QK8_0; ++k) {
        out[k] = blk->qs[k]*scale;
    }
}
227
175
@@ -235,11 +183,6 @@ static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStre
235
183
dequantize_block_q4_1<<<nb, 1 , 0 , stream>>> (vx, y);
236
184
}
237
185
238
- static void dequantize_row_q4_2_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
239
- const int nb = k / QK4_2;
240
- dequantize_block_q4_2<<<nb, 1 , 0 , stream>>> (vx, y);
241
- }
242
-
243
186
static void dequantize_row_q5_0_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
244
187
const int nb = k / QK5_0;
245
188
dequantize_block_q5_0<<<nb, 1 , 0 , stream>>> (vx, y);
@@ -274,8 +217,6 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
274
217
return dequantize_row_q4_0_cuda;
275
218
case GGML_TYPE_Q4_1:
276
219
return dequantize_row_q4_1_cuda;
277
- case GGML_TYPE_Q4_2:
278
- return dequantize_row_q4_2_cuda;
279
220
case GGML_TYPE_Q5_0:
280
221
return dequantize_row_q5_0_cuda;
281
222
case GGML_TYPE_Q5_1:
0 commit comments