@@ -657,9 +657,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
657657#define QK8_0 32
658658typedef struct { // one Q8_0 block: QK8_0 signed 8-bit quants plus two per-block floats
659659 float d ; // delta
660+ float s ; // d * sum(qs[i]), stored when the row is quantized
660661 int8_t qs [QK8_0 ]; // quants
661662} block_q8_0 ;
662- static_assert (sizeof (block_q8_0 ) == sizeof (float ) + QK8_0 , "wrong q8_0 block size/padding" );
663+ static_assert (sizeof (block_q8_0 ) == 2 * sizeof (float ) + QK8_0 , "wrong q8_0 block size/padding" );
663664
664665
665666// reference implementation for deterministic creation of model files
@@ -1299,12 +1300,38 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
12991300
13001301 y [i ].d = d ;
13011302
1303+ int sum = 0 ;
13021304 for (int l = 0 ; l < QK8_0 ; ++ l ) {
13031305 const float v = x [i * QK8_0 + l ]* id ;
13041306 y [i ].qs [l ] = roundf (v );
1305- }
1306- }
1307+ sum += y [i ].qs [l ];
1308+ }
1309+ y [i ].s = d * sum ;
1310+ }
1311+ }
1312+
1313+ #ifdef __AVX2__
1314+ // There is no better way of doing this?
1315+ // I guess not, AVX is not very good at horizontal sums.
1316+ // The commented-out solution for a horizontal sum was suggested by @pubby as being slightly
1317+ // faster than the solution below. As I don't have an AVX2 system handy right now to test,
1318+ // I am keeping the original.
1319+ // TODO: Please try it and, if it does make a difference, uncomment it and remove the implementation below.
1320+ //static inline float horizontal_sum(__m256i a) {
1321+ // __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
1322+ // __m256i sum = _mm256_add_epi32(a, b);
1323+ // __m256i hi = _mm256_unpackhi_epi64(sum, sum);
1324+ // sum = _mm256_add_epi32(sum, hi);
1325+ // return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
1326+ //}
1327+ static inline float horizontal_sum (__m256i a ) { // sums the eight 32-bit integer lanes of a; the int total is converted to float by the return
1328+ __m128i sum128 = _mm_add_epi32 (_mm256_castsi256_si128 (a ), _mm256_extracti128_si256 (a , 1 )); // low 128-bit half + high 128-bit half: 4 partial sums
1329+ __m128i hi64 = _mm_unpackhi_epi64 (sum128 , sum128 ); // copy the upper 64 bits (lanes 2,3) into the lower 64 bits
1330+ __m128i sum64 = _mm_add_epi32 (hi64 , sum128 ); // lanes 0 and 1 now hold the two remaining partial sums
1331+ __m128i hi32 = _mm_shuffle_epi32 (sum64 , _MM_SHUFFLE (2 , 3 , 0 , 1 )); // swap adjacent 32-bit lanes so lane 1 moves into lane 0
1332+ return _mm_cvtsi128_si32 (_mm_add_epi32 (sum64 , hi32 )); // lane 0 = total of all eight input lanes
13071333}
1334+ #endif
13081335
13091336static void quantize_row_q8_0 (const float * restrict x , void * restrict vy , int k ) {
13101337 assert (k % QK8_0 == 0 );
@@ -1332,6 +1359,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
13321359
13331360 y [i ].d = d ;
13341361
1362+ int32x4_t accv = vdupq_n_s32 (0 );
1363+
13351364 for (int l = 0 ; l < 8 ; l ++ ) {
13361365 const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
13371366 const int32x4_t vi = vcvtnq_s32_f32 (v );
@@ -1340,7 +1369,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
13401369 y [i ].qs [4 * l + 1 ] = vgetq_lane_s32 (vi , 1 );
13411370 y [i ].qs [4 * l + 2 ] = vgetq_lane_s32 (vi , 2 );
13421371 y [i ].qs [4 * l + 3 ] = vgetq_lane_s32 (vi , 3 );
1372+
1373+ accv = vaddq_s32 (accv , vi );
13431374 }
1375+ int32_t sum = vaddvq_s32 (accv );
1376+ y [i ].s = d * sum ;
13441377 }
13451378#elif defined(__AVX2__ ) || defined(__AVX__ )
13461379 for (int i = 0 ; i < nb ; i ++ ) {
@@ -1388,6 +1421,10 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
13881421 __m256i i3 = _mm256_cvtps_epi32 ( v3 );
13891422
13901423#if defined(__AVX2__ )
1424+
1425+ // Compute the sum of the quants and set y[i].s
1426+ y [i ].s = d * horizontal_sum (_mm256_add_epi32 (_mm256_add_epi32 (i0 , i1 ), _mm256_add_epi32 (i2 , i3 )));
1427+
13911428 // Convert int32 to int16
13921429 i0 = _mm256_packs_epi32 ( i0 , i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
13931430 i2 = _mm256_packs_epi32 ( i2 , i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
@@ -1430,6 +1467,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
14301467 // scalar
14311468 quantize_row_q8_0_reference (x , y , k );
14321469#endif
1470+ #if defined __AVX__
1471+ // TODO: vectorize this
1472+ for (int i = 0 ; i < nb ; ++ i ) {
1473+ int sum = 0 ;
1474+ for (int l = 0 ; l < QK8_0 ; ++ l ) sum += y [i ].qs [l ];
1475+ y [i ].s = y [i ].d * sum ;
1476+ }
1477+ #endif
14331478}
14341479
14351480static void dequantize_row_q4_0 (const void * restrict vx , float * restrict y , int k ) {
@@ -2372,14 +2417,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23722417 float32x4_t sumv0 = vdupq_n_f32 (0.0f );
23732418 float32x4_t sumv1 = vdupq_n_f32 (0.0f );
23742419
2420+ float sum8 = 0 ;
2421+
23752422 for (int i = 0 ; i < nb ; i += 2 ) {
23762423 const block_q4_0 * restrict x0 = & x [i + 0 ];
23772424 const block_q4_0 * restrict x1 = & x [i + 1 ];
23782425 const block_q8_0 * restrict y0 = & y [i + 0 ];
23792426 const block_q8_0 * restrict y1 = & y [i + 1 ];
23802427
2428+ sum8 += x0 -> d * y0 -> s + x1 -> d * y1 -> s ;
2429+
23812430 const uint8x16_t m4b = vdupq_n_u8 (0xf );
2382- const int8x16_t s8b = vdupq_n_s8 (0x8 );
23832431
23842432 const uint8x16_t v0_0 = vld1q_u8 (x0 -> qs );
23852433 const uint8x16_t v0_1 = vld1q_u8 (x1 -> qs );
@@ -2390,12 +2438,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23902438 const int8x16_t v0_1l = vreinterpretq_s8_u8 (vandq_u8 (v0_1 , m4b ));
23912439 const int8x16_t v0_1h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_1 , 4 ));
23922440
2393- // sub 8
2394- const int8x16_t v0_0ls = vsubq_s8 (v0_0l , s8b );
2395- const int8x16_t v0_0hs = vsubq_s8 (v0_0h , s8b );
2396- const int8x16_t v0_1ls = vsubq_s8 (v0_1l , s8b );
2397- const int8x16_t v0_1hs = vsubq_s8 (v0_1h , s8b );
2398-
23992441 // load y
24002442 const int8x16_t v1_0l = vld1q_s8 (y0 -> qs );
24012443 const int8x16_t v1_0h = vld1q_s8 (y0 -> qs + 16 );
@@ -2410,21 +2452,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24102452
24112453#if defined(__ARM_FEATURE_DOTPROD )
24122454 // dot product into int32x4_t
2413- const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0ls , v1_0ls ), v0_0hs , v1_0hs );
2414- const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1ls , v1_1ls ), v0_1hs , v1_1hs );
2455+ const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0l , v1_0ls ), v0_0h , v1_0hs );
2456+ const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1l , v1_1ls ), v0_1h , v1_1hs );
24152457
24162458 sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (p_0 ), x0 -> d * y0 -> d );
24172459 sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (p_1 ), x1 -> d * y1 -> d );
24182460#else
2419- const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0ls ), vget_low_s8 (v1_0ls ));
2420- const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0ls ), vget_high_s8 (v1_0ls ));
2421- const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0hs ));
2422- const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0hs ));
2461+ const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0l ), vget_low_s8 (v1_0ls ));
2462+ const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0l ), vget_high_s8 (v1_0ls ));
2463+ const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0h ), vget_low_s8 (v1_0hs ));
2464+ const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0h ), vget_high_s8 (v1_0hs ));
24232465
2424- const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1ls ));
2425- const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1ls ));
2426- const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1hs ));
2427- const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1hs ));
2466+ const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1l ), vget_low_s8 (v1_1ls ));
2467+ const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1l ), vget_high_s8 (v1_1ls ));
2468+ const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1h ), vget_low_s8 (v1_1hs ));
2469+ const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1h ), vget_high_s8 (v1_1hs ));
24282470
24292471 const int32x4_t pl0 = vaddq_s32 (vpaddlq_s16 (pl0l ), vpaddlq_s16 (pl0h ));
24302472 const int32x4_t ph0 = vaddq_s32 (vpaddlq_s16 (ph0l ), vpaddlq_s16 (ph0h ));
@@ -2436,7 +2478,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24362478#endif
24372479 }
24382480
2439- sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2481+ sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 ) - 8 * sum8 ;
24402482#elif defined(__AVX2__ )
24412483 // Initialize accumulator with zeros
24422484 __m256 acc = _mm256_setzero_ps ();
@@ -2569,12 +2611,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25692611 float32x4_t sumv0 = vdupq_n_f32 (0.0f );
25702612 float32x4_t sumv1 = vdupq_n_f32 (0.0f );
25712613
2614+ float summs = 0 ;
2615+
25722616 for (int i = 0 ; i < nb ; i += 2 ) {
25732617 const block_q4_1 * restrict x0 = & x [i + 0 ];
25742618 const block_q4_1 * restrict x1 = & x [i + 1 ];
25752619 const block_q8_0 * restrict y0 = & y [i + 0 ];
25762620 const block_q8_0 * restrict y1 = & y [i + 1 ];
25772621
2622+ summs += x0 -> m * y0 -> s + x1 -> m * y1 -> s ;
2623+
25782624 const uint8x16_t m4b = vdupq_n_u8 (0xf );
25792625
25802626 const uint8x16_t v0_0 = vld1q_u8 (x0 -> qs );
@@ -2598,17 +2644,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25982644 const int8x16_t v1_1ls = vuzp1q_s8 (v1_1l , v1_1h );
25992645 const int8x16_t v1_1hs = vuzp2q_s8 (v1_1l , v1_1h );
26002646
2601- const int16x8_t s0i = vaddq_s16 (
2602- vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_0ls )), vmovl_s8 (vget_high_s8 (v1_0ls ))),
2603- vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_0hs )), vmovl_s8 (vget_high_s8 (v1_0hs ))));
2604-
2605- const int16x8_t s1i = vaddq_s16 (
2606- vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_1ls )), vmovl_s8 (vget_high_s8 (v1_1ls ))),
2607- vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_1hs )), vmovl_s8 (vget_high_s8 (v1_1hs ))));
2608-
2609- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddl_s16 (vget_low_s16 (s0i ), vget_high_s16 (s0i ))), x0 -> m * y0 -> d );
2610- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddl_s16 (vget_low_s16 (s1i ), vget_high_s16 (s1i ))), x1 -> m * y1 -> d );
2611-
26122647#if defined(__ARM_FEATURE_DOTPROD )
26132648 // dot product into int32x4_t
26142649 const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0l , v1_0ls ), v0_0h , v1_0hs );
@@ -2637,24 +2672,26 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
26372672#endif
26382673 }
26392674
2640- sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2675+ sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 ) + summs ;
26412676#elif defined(__AVX2__ )
26422677 // Initialize accumulator with zeros
26432678 __m256 acc = _mm256_setzero_ps ();
26442679
2680+ float summs = 0 ;
2681+
26452682 // Main loop
26462683 for (int i = 0 ; i < nb ; ++ i ) {
26472684 const float * d0 = & x [i ].d ;
26482685 const float * d1 = & y [i ].d ;
2649- const float * m0 = & x [i ].m ;
2686+ //const float * m0 = &x[i].m;
2687+
2688+ summs += x [i ].m * y [i ].s ;
26502689
26512690 const __m256 d0v = _mm256_broadcast_ss ( d0 );
26522691 const __m256 d1v = _mm256_broadcast_ss ( d1 );
2653- const __m256 m0v = _mm256_broadcast_ss ( m0 );
26542692
26552693 // Compute combined scales
26562694 const __m256 d0d1 = _mm256_mul_ps ( d0v , d1v );
2657- const __m256 d1m0 = _mm256_mul_ps ( d1v , m0v );
26582695
26592696 // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
26602697 const __m256i bx = bytes_from_nibbles_32 (x [i ].qs );
@@ -2676,15 +2713,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
26762713
26772714 // Accumulate d0*d1*x*y
26782715 acc = _mm256_fmadd_ps ( d0d1 , xy , acc );
2679-
2680- // Compute sum of y values
2681- const __m256i y16_l = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( by ) );
2682- const __m256i y16_h = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( by , 1 ) );
2683- const __m256i ysumi = _mm256_madd_epi16 ( _mm256_add_epi16 (y16_l , y16_h ), ones );
2684- const __m256 ysum = _mm256_cvtepi32_ps ( ysumi );
2685-
2686- // Accumulate d1*m0*y
2687- acc = _mm256_fmadd_ps ( d1m0 , ysum , acc );
26882716 }
26892717
26902718 // Return horizontal sum of the acc vector
@@ -2693,7 +2721,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
26932721 res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
26942722 res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
26952723
2696- sumf = _mm_cvtss_f32 ( res );
2724+ sumf = _mm_cvtss_f32 ( res ) + summs ;
26972725#else
26982726 // scalar
26992727 for (int i = 0 ; i < nb ; i ++ ) {
0 commit comments