@@ -472,7 +472,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-#if __AVX__ || __AVX2__ || __AVX512F__
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     // Get absolute values of x vectors
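Note: the widened guard makes the 128-bit helper mul_sum_i8_pairs available to SSSE3-only builds, not just AVX-class ones. For orientation, here is a hypothetical scalar model of what this helper computes (its _mm_sign_epi8/_mm_maddubs_epi16 body is elided by this hunk): each of the four 32-bit output lanes holds the dot product of the corresponding four signed-byte pairs. The _ref name and stdint.h types are illustrative, not part of the commit.

    #include <stdint.h>

    // Hypothetical reference model, not the committed code: lane k of the
    // result is the sum of x[4k+j]*y[4k+j] for j = 0..3.
    static inline void mul_sum_i8_pairs_ref(const int8_t x[16], const int8_t y[16], int32_t out[4]) {
        for (int k = 0; k < 4; ++k) {
            out[k] = 0;
            for (int j = 0; j < 4; ++j) {
                out[k] += (int32_t) x[4*k + j] * (int32_t) y[4*k + j];
            }
        }
    }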
@@ -485,6 +485,7 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     return _mm_madd_epi16(ones, dot);
 }
 
+#if __AVX__ || __AVX2__ || __AVX512F__
 // horizontally add 8 floats
 static inline float hsum_float_8(const __m256 x) {
     __m128 res = _mm256_extractf128_ps(x, 1);
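Note: hsum_float_8 operates on a 256-bit __m256, which plain SSSE3 builds do not have; that is why the new nested #if keeps it (and the other AVX helpers) out of the shared region while mul_sum_i8_pairs above stays visible. Its effect is simply the sum of all eight lanes; a hypothetical scalar model for illustration:

    // Hypothetical reference model: total of the 8 float lanes.
    static inline float hsum_float_8_ref(const float v[8]) {
        float s = 0.0f;
        for (int i = 0; i < 8; ++i) s += v[i];
        return s;
    }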
@@ -596,7 +597,19 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16(bytes1, bytes2);
 }
 #endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 = _mm_hadd_ps(a, b);
+    __m128 res_1 = _mm_hadd_ps(c, d);
+    __m128 res = _mm_hadd_ps(res_0, res_1);
+    res = _mm_hadd_ps(res, res);
+    res = _mm_hadd_ps(res, res);
+
+    return _mm_cvtss_f32(res);
+}
 #endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
 #if __ARM_NEON
 
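Note: hsum_float_4x4 reduces four __m128 accumulators to one float with a cascade of _mm_hadd_ps. The first two hadds pair-sum a/b and c/d, the third leaves each register's total in one lane apiece, and the last two collapse those four totals into every lane before _mm_cvtss_f32 extracts lane 0. A hypothetical scalar model (the _ref name is illustrative):

    // Hypothetical reference model: sum of all 16 floats across the four vectors.
    static inline float hsum_float_4x4_ref(const float a[4], const float b[4],
                                           const float c[4], const float d[4]) {
        float s = 0.0f;
        for (int i = 0; i < 4; ++i) s += a[i] + b[i] + c[i] + d[i];
        return s;
    }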
@@ -2129,6 +2142,126 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = hsum_float_8(acc);
+#elif defined(__SSSE3__)
+    // set constants
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    const __m128i off = _mm_set1_epi8(8);
+
+    // Initialize accumulators with zeros
+    __m128 acc_0 = _mm_setzero_ps();
+    __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
+
+    // First round without accumulation
+    {
+        _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for blocks 0 and 1
+        const __m128 d_0_1 = _mm_mul_ps(_mm_set1_ps(x[0].d), _mm_set1_ps(y[0].d));
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for blocks 2 and 3
+        const __m128 d_2_3 = _mm_mul_ps(_mm_set1_ps(x[1].d), _mm_set1_ps(y[1].d));
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        acc_0 = _mm_mul_ps(d_0_1, p0);
+        acc_1 = _mm_mul_ps(d_0_1, p1);
+        acc_2 = _mm_mul_ps(d_2_3, p2);
+        acc_3 = _mm_mul_ps(d_2_3, p3);
+    }
+
+    // Main loop
+    for (int i = 2; i < nb; i += 2) {
+        _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for blocks 0 and 1
+        const __m128 d_0_1 = _mm_mul_ps(_mm_set1_ps(x[i].d), _mm_set1_ps(y[i].d));
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for blocks 2 and 3
+        const __m128 d_2_3 = _mm_mul_ps(_mm_set1_ps(x[i + 1].d), _mm_set1_ps(y[i + 1].d));
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = _mm_mul_ps(d_0_1, p0);
+        __m128 p1_d = _mm_mul_ps(d_0_1, p1);
+        __m128 p2_d = _mm_mul_ps(d_2_3, p2);
+        __m128 p3_d = _mm_mul_ps(d_2_3, p3);
+
+        // Accumulate
+        acc_0 = _mm_add_ps(p0_d, acc_0);
+        acc_1 = _mm_add_ps(p1_d, acc_1);
+        acc_2 = _mm_add_ps(p2_d, acc_2);
+        acc_3 = _mm_add_ps(p3_d, acc_3);
+    }
+
+    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
 #else
     // scalar
     float sumf = 0.0;
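Note: the SSSE3 path computes the same result as the scalar fallback that follows it. Each block_q4_0 packs 32 4-bit quants into 16 bytes; the low nibbles pair with y[i].qs[0..15] and the high nibbles with y[i].qs[16..31], each nibble is re-centered by subtracting 8, and the per-block integer dot product is scaled by x[i].d * y[i].d. A hypothetical scalar sketch of that contract, assuming ggml.c's block_q4_0/block_q8_0 definitions at this point (float d; 16 packed bytes vs. 32 int8 quants) and a _ref name that is not in the commit:

    // Hypothetical reference model of the per-block math in the SIMD path above.
    static float vec_dot_q4_0_q8_0_ref(int nb, const block_q4_0 * x, const block_q8_0 * y) {
        float sumf = 0.0f;
        for (int i = 0; i < nb; ++i) {
            int sumi = 0;
            for (int j = 0; j < 16; ++j) {
                const int v0 = (x[i].qs[j] & 0x0F) - 8; // low nibble  -> y bytes 0..15
                const int v1 = (x[i].qs[j] >> 4)   - 8; // high nibble -> y bytes 16..31
                sumi += v0 * y[i].qs[j] + v1 * y[i].qs[j + 16];
            }
            sumf += x[i].d * y[i].d * (float) sumi;
        }
        return sumf;
    }

Also worth noting: because the first round handles blocks 0 and 1 and the main loop then advances two blocks per iteration, this path assumes nb is even and at least 2.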