@@ -472,7 +472,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
// quantization
//

- #if __AVX__ || __AVX2__ || __AVX512F__
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
// multiply int8_t, add results pairwise twice
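// (the byte products are formed as adjacent i16 pairs and then widened to four
// i32 sums per 128-bit vector via _mm_madd_epi16 with a vector of ones, below)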
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
    // Get absolute values of x vectors
@@ -485,6 +485,7 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
    return _mm_madd_epi16(ones, dot);
}

+ #if __AVX__ || __AVX2__ || __AVX512F__
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
    __m128 res = _mm256_extractf128_ps(x, 1);
@@ -596,7 +597,19 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
    return _mm_packus_epi16( bytes1, bytes2 );
}
#endif
+ #elif defined(__SSSE3__)
+ // horizontally add 4x4 floats
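+ // (three _mm_hadd_ps steps reduce a, b, c, d to one vector holding the four
+ // per-vector sums; two more folds collapse those lanes into a single scalar)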
+ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+     __m128 res_0 = _mm_hadd_ps(a, b);
+     __m128 res_1 = _mm_hadd_ps(c, d);
+     __m128 res = _mm_hadd_ps(res_0, res_1);
+     res = _mm_hadd_ps(res, res);
+     res = _mm_hadd_ps(res, res);
+
+     return _mm_cvtss_f32(res);
+ }
#endif // __AVX__ || __AVX2__ || __AVX512F__
+ #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)

#if __ARM_NEON
@@ -2129,6 +2142,126 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
    }

    *s = hsum_float_8(acc);
+ #elif defined(__SSSE3__)
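+     // (SSSE3 path: same q4_0·q8_0 kernel as the AVX branch above, but on
+     // 128-bit vectors, unrolled over two blocks of x and y per iteration)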
+     // set constants
+     const __m128i lowMask = _mm_set1_epi8(0xF);
+     const __m128i off = _mm_set1_epi8(8);
+
+     // Initialize accumulator with zeros
+     __m128 acc_0 = _mm_setzero_ps();
+     __m128 acc_1 = _mm_setzero_ps();
+     __m128 acc_2 = _mm_setzero_ps();
+     __m128 acc_3 = _mm_setzero_ps();
+
+     // First round without accumulation
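+     // (peeling the first pair of blocks out of the loop initializes the four
+     // accumulators with a plain multiply, saving an add-with-zero per lane)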
+     {
+         _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
+         _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
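+         // (note: &x[0] + sizeof(block_q4_0) advances by whole block_q4_0 elements,
+         // not bytes, so this prefetch lands well past the next block; since
+         // prefetch is only a hint, the cost is at most a wasted fetch)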
+
+         // Compute combined scale for blocks 0 and 1
+         const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
+
+         const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+
+         __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+         __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+         bx_0 = _mm_sub_epi8(bx_0, off);
+         const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
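+         // (masking with lowMask keeps the low nibbles and subtracting off maps
+         // the q4_0 codes from [0,15] onto [-8,7] before the signed dot product)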
2170
+
2171
+ __m128i bx_1 = _mm_and_si128 (lowMask , _mm_srli_epi64 (tmp_0_1 , 4 ));
2172
+ __m128i by_1 = _mm_loadu_si128 ((const __m128i * )(y [0 ].qs + 16 ));
2173
+ bx_1 = _mm_sub_epi8 (bx_1 , off );
2174
+ const __m128i i32_1 = mul_sum_i8_pairs (bx_1 , by_1 );
2175
+
2176
+ _mm_prefetch (& x [1 ] + sizeof (block_q4_0 ), _MM_HINT_T0 );
2177
+ _mm_prefetch (& y [1 ] + sizeof (block_q8_0 ), _MM_HINT_T0 );
2178
+
2179
+ // Compute combined scale for the block 2 and 3
2180
+ const __m128 d_2_3 = _mm_mul_ps ( _mm_set1_ps ( x [1 ].d ), _mm_set1_ps ( y [1 ].d ) );
2181
+
2182
+ const __m128i tmp_2_3 = _mm_loadu_si128 ((const __m128i * )x [1 ].qs );
2183
+
2184
+ __m128i bx_2 = _mm_and_si128 (lowMask , tmp_2_3 );
2185
+ __m128i by_2 = _mm_loadu_si128 ((const __m128i * )y [1 ].qs );
2186
+ bx_2 = _mm_sub_epi8 (bx_2 , off );
2187
+ const __m128i i32_2 = mul_sum_i8_pairs (bx_2 , by_2 );
2188
+
2189
+ __m128i bx_3 = _mm_and_si128 (lowMask , _mm_srli_epi64 (tmp_2_3 , 4 ));
2190
+ __m128i by_3 = _mm_loadu_si128 ((const __m128i * )(y [1 ].qs + 16 ));
2191
+ bx_3 = _mm_sub_epi8 (bx_3 , off );
2192
+ const __m128i i32_3 = mul_sum_i8_pairs (bx_3 , by_3 );
2193
+
2194
+ // Convert int32_t to float
2195
+ __m128 p0 = _mm_cvtepi32_ps (i32_0 );
2196
+ __m128 p1 = _mm_cvtepi32_ps (i32_1 );
2197
+ __m128 p2 = _mm_cvtepi32_ps (i32_2 );
2198
+ __m128 p3 = _mm_cvtepi32_ps (i32_3 );
2199
+
2200
+ // Apply the scale
2201
+ acc_0 = _mm_mul_ps ( d_0_1 , p0 );
2202
+ acc_1 = _mm_mul_ps ( d_0_1 , p1 );
2203
+ acc_2 = _mm_mul_ps ( d_2_3 , p2 );
2204
+ acc_3 = _mm_mul_ps ( d_2_3 , p3 );
2205
+ }
+
+     // Main loop
+     for (int i = 2; i < nb; i += 2) {
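+         // (note: the two-block unroll assumes nb is even and at least 2; the
+         // peeled first round above already consumed blocks 0 and 1 of x and y)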
+         _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
+         _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
+
+         // Compute combined scale for blocks 0 and 1
+         const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
+
+         const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+         __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+         __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+         bx_0 = _mm_sub_epi8(bx_0, off);
+         const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+         __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+         __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+         bx_1 = _mm_sub_epi8(bx_1, off);
+         const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+         _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+         _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+         // Compute combined scale for blocks 2 and 3
+         const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
+
+         const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+
+         __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+         __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+         bx_2 = _mm_sub_epi8(bx_2, off);
+         const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+         __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+         __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+         bx_3 = _mm_sub_epi8(bx_3, off);
+         const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+         // Convert int32_t to float
+         __m128 p0 = _mm_cvtepi32_ps(i32_0);
+         __m128 p1 = _mm_cvtepi32_ps(i32_1);
+         __m128 p2 = _mm_cvtepi32_ps(i32_2);
+         __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+         // Apply the scale
+         __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+         __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+         __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+         __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+         // Accumulate
+         acc_0 = _mm_add_ps(p0_d, acc_0);
+         acc_1 = _mm_add_ps(p1_d, acc_1);
+         acc_2 = _mm_add_ps(p2_d, acc_2);
+         acc_3 = _mm_add_ps(p3_d, acc_3);
+     }
+
+     *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
#else
    // scalar
    float sumf = 0.0;