Commit ac0cd25 — "Adding SSE instructions to ggml_vec_dot_q4_0_q8_0" (#1413)
Parent commit: 0cd22e1

File tree: 1 file changed — ggml.c (+134 additions, −1 deletion)
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
472472
// quantization
473473
//
474474

475-
#if __AVX__ || __AVX2__ || __AVX512F__
475+
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
476476
// multiply int8_t, add results pairwise twice
477477
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
478478
// Get absolute values of x vectors
@@ -485,6 +485,7 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
485485
return _mm_madd_epi16(ones, dot);
486486
}
487487

488+
#if __AVX__ || __AVX2__ || __AVX512F__
488489
// horizontally add 8 floats
489490
static inline float hsum_float_8(const __m256 x) {
490491
__m128 res = _mm256_extractf128_ps(x, 1);
@@ -596,7 +597,19 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
596597
return _mm_packus_epi16( bytes1, bytes2);
597598
}
598599
#endif
600+
#elif defined(__SSSE3__)
601+
// horizontally add 4x4 floats
602+
static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
603+
__m128 res_0 =_mm_hadd_ps(a, b);
604+
__m128 res_1 =_mm_hadd_ps(c, d);
605+
__m128 res =_mm_hadd_ps(res_0, res_1);
606+
res =_mm_hadd_ps(res, res);
607+
res =_mm_hadd_ps(res, res);
608+
609+
return _mm_cvtss_f32(res);
610+
}
599611
#endif // __AVX__ || __AVX2__ || __AVX512F__
612+
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
600613

601614
#if __ARM_NEON
602615

@@ -2129,6 +2142,126 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
21292142
}
21302143

21312144
*s = hsum_float_8(acc);
2145+
#elif defined(__SSSE3__)
2146+
// set constants
2147+
const __m128i lowMask = _mm_set1_epi8(0xF);
2148+
const __m128i off = _mm_set1_epi8(8);
2149+
2150+
// Initialize accumulator with zeros
2151+
__m128 acc_0 = _mm_setzero_ps();
2152+
__m128 acc_1 = _mm_setzero_ps();
2153+
__m128 acc_2 = _mm_setzero_ps();
2154+
__m128 acc_3 = _mm_setzero_ps();
2155+
2156+
// First round without accumulation
2157+
{
2158+
_mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
2159+
_mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
2160+
2161+
// Compute combined scale for the block 0 and 1
2162+
const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
2163+
2164+
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
2165+
2166+
__m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
2167+
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
2168+
bx_0 = _mm_sub_epi8(bx_0, off);
2169+
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
2170+
2171+
__m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
2172+
__m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
2173+
bx_1 = _mm_sub_epi8(bx_1, off);
2174+
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
2175+
2176+
_mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
2177+
_mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
2178+
2179+
// Compute combined scale for the block 2 and 3
2180+
const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
2181+
2182+
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
2183+
2184+
__m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
2185+
__m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
2186+
bx_2 = _mm_sub_epi8(bx_2, off);
2187+
const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
2188+
2189+
__m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
2190+
__m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
2191+
bx_3 = _mm_sub_epi8(bx_3, off);
2192+
const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
2193+
2194+
// Convert int32_t to float
2195+
__m128 p0 = _mm_cvtepi32_ps(i32_0);
2196+
__m128 p1 = _mm_cvtepi32_ps(i32_1);
2197+
__m128 p2 = _mm_cvtepi32_ps(i32_2);
2198+
__m128 p3 = _mm_cvtepi32_ps(i32_3);
2199+
2200+
// Apply the scale
2201+
acc_0 = _mm_mul_ps( d_0_1, p0 );
2202+
acc_1 = _mm_mul_ps( d_0_1, p1 );
2203+
acc_2 = _mm_mul_ps( d_2_3, p2 );
2204+
acc_3 = _mm_mul_ps( d_2_3, p3 );
2205+
}
2206+
2207+
// Main loop
2208+
for (int i = 2; i < nb; i+=2) {
2209+
_mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
2210+
_mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
2211+
2212+
// Compute combined scale for the block 0 and 1
2213+
const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
2214+
2215+
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
2216+
2217+
__m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
2218+
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
2219+
bx_0 = _mm_sub_epi8(bx_0, off);
2220+
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
2221+
2222+
__m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
2223+
__m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
2224+
bx_1 = _mm_sub_epi8(bx_1, off);
2225+
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
2226+
2227+
_mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
2228+
_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
2229+
2230+
// Compute combined scale for the block 2 and 3
2231+
const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
2232+
2233+
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
2234+
2235+
__m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
2236+
__m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
2237+
bx_2 = _mm_sub_epi8(bx_2, off);
2238+
const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
2239+
2240+
__m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
2241+
__m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
2242+
bx_3 = _mm_sub_epi8(bx_3, off);
2243+
const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
2244+
2245+
// Convert int32_t to float
2246+
__m128 p0 = _mm_cvtepi32_ps(i32_0);
2247+
__m128 p1 = _mm_cvtepi32_ps(i32_1);
2248+
__m128 p2 = _mm_cvtepi32_ps(i32_2);
2249+
__m128 p3 = _mm_cvtepi32_ps(i32_3);
2250+
2251+
// Apply the scale
2252+
__m128 p0_d = _mm_mul_ps( d_0_1, p0 );
2253+
__m128 p1_d = _mm_mul_ps( d_0_1, p1 );
2254+
__m128 p2_d = _mm_mul_ps( d_2_3, p2 );
2255+
__m128 p3_d = _mm_mul_ps( d_2_3, p3 );
2256+
2257+
// Acummulate
2258+
acc_0 = _mm_add_ps(p0_d, acc_0);
2259+
acc_1 = _mm_add_ps(p1_d, acc_1);
2260+
acc_2 = _mm_add_ps(p2_d, acc_2);
2261+
acc_3 = _mm_add_ps(p3_d, acc_3);
2262+
}
2263+
2264+
*s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
21322265
#else
21332266
// scalar
21342267
float sumf = 0.0;

Comments: 0