Skip to content

Commit d232e0b

Browse files
committed
ggml : fix dot product (q2_k)
ggml-ci
1 parent 8fd9794 commit d232e0b

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

ggml-quants.c

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3663,7 +3663,6 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
36633663
float sum = 0;
36643664

36653665
for (int i = 0; i < nb; ++i) {
3666-
36673666
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
36683667
const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
36693668

@@ -3694,13 +3693,17 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
36943693
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
36953694
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
36963695
#else
3697-
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
3698-
{\
3699-
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
3700-
vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
3701-
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
3702-
vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
3703-
isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
3696+
#define MULTIPLY_ACCUM_WITH_SCALE(index) \
3697+
{ \
3698+
const int16x8_t p0_0 = vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])); \
3699+
const int16x8_t p0_1 = vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])); \
3700+
const int16x8_t p1_0 = vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])); \
3701+
const int16x8_t p1_1 = vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])); \
3702+
\
3703+
const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); \
3704+
const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); \
3705+
\
3706+
isum += vaddvq_s32(p0) * aux[is+(index)] + vaddvq_s32(p1) * aux[is+1+(index)]; \
37043707
}
37053708
#endif
37063709

@@ -3710,26 +3713,23 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
37103713
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
37113714
MULTIPLY_ACCUM_WITH_SCALE((index));
37123715

3713-
37143716
for (int j = 0; j < QK_K/128; ++j) {
3715-
37163717
const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
37173718

37183719
ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
37193720
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
37203721
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
3722+
37213723
MULTIPLY_ACCUM_WITH_SCALE(0);
37223724

37233725
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
3724-
37253726
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
3726-
37273727
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
37283728

37293729
is += 8;
37303730
}
3731-
sum += d * isum;
37323731

3732+
sum += d * isum;
37333733
}
37343734

37353735
*s = sum;

0 commit comments

Comments
 (0)