@@ -3663,7 +3663,6 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
3663
3663
float sum = 0 ;
3664
3664
3665
3665
for (int i = 0 ; i < nb ; ++ i ) {
3666
-
3667
3666
const float d = y [i ].d * GGML_FP16_TO_FP32 (x [i ].d );
3668
3667
const float dmin = - y [i ].d * GGML_FP16_TO_FP32 (x [i ].dmin );
3669
3668
@@ -3694,13 +3693,17 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
3694
3693
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
3695
3694
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
3696
3695
#else
3697
- #define MULTIPLY_ACCUM_WITH_SCALE (index )\
3698
- {\
3699
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
3700
- vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
3701
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
3702
- vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
3703
- isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
3696
+ #define MULTIPLY_ACCUM_WITH_SCALE (index ) \
3697
+ { \
3698
+ const int16x8_t p0_0 = vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])); \
3699
+ const int16x8_t p0_1 = vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])); \
3700
+ const int16x8_t p1_0 = vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])); \
3701
+ const int16x8_t p1_1 = vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])); \
3702
+ \
3703
+ const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); \
3704
+ const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); \
3705
+ \
3706
+ isum += vaddvq_s32(p0) * aux[is+(index)] + vaddvq_s32(p1) * aux[is+1+(index)]; \
3704
3707
}
3705
3708
#endif
3706
3709
@@ -3710,26 +3713,23 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
3710
3713
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
3711
3714
MULTIPLY_ACCUM_WITH_SCALE((index));
3712
3715
3713
-
3714
3716
for (int j = 0 ; j < QK_K /128 ; ++ j ) {
3715
-
3716
3717
const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2 (q2 ); q2 += 32 ;
3717
3718
3718
3719
ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2 (q8 ); q8 += 32 ;
3719
3720
q2bytes .val [0 ] = vreinterpretq_s8_u8 (vandq_u8 (q2bits .val [0 ], m3 ));
3720
3721
q2bytes .val [1 ] = vreinterpretq_s8_u8 (vandq_u8 (q2bits .val [1 ], m3 ));
3722
+
3721
3723
MULTIPLY_ACCUM_WITH_SCALE (0 );
3722
3724
3723
3725
SHIFT_MULTIPLY_ACCUM_WITH_SCALE (2 , 2 );
3724
-
3725
3726
SHIFT_MULTIPLY_ACCUM_WITH_SCALE (4 , 4 );
3726
-
3727
3727
SHIFT_MULTIPLY_ACCUM_WITH_SCALE (6 , 6 );
3728
3728
3729
3729
is += 8 ;
3730
3730
}
3731
- sum += d * isum ;
3732
3731
3732
+ sum += d * isum ;
3733
3733
}
3734
3734
3735
3735
* s = sum ;
0 commit comments