Skip to content

ggml : add AVX support based on AVX2 code #1430

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 132 additions & 3 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,63 @@ static inline __m128i packNibbles( __m256i bytes )
return _mm_packus_epi16( r0, r1 );
#endif
}
#else
#elif defined(__AVX__)
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
memcpy(&x32, x, sizeof(uint32_t));
const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
__m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
__m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
bytesl = _mm_or_si128(bytesl, bit_mask);
bytesh = _mm_or_si128(bytesh, bit_mask);
bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
return _mm256_set_m128i(bytesh, bytesl);
}

// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
// Load 16 bytes from memory
__m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
__m128i tmph = _mm_srli_epi16(tmpl, 4);
const __m128i lowMask = _mm_set1_epi8(0xF);
tmpl = _mm_and_si128(lowMask, tmpl);
tmph = _mm_and_si128(lowMask, tmph);
return _mm256_set_m128i(tmph, tmpl);
}

// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
const __m128i ones = _mm_set1_epi16(1);
const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
return _mm256_cvtepi32_ps(summed_pairs);
}

// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
const __m128i xl = _mm256_castsi256_si128(x);
const __m128i xh = _mm256_extractf128_si256(x, 1);
const __m128i yl = _mm256_castsi256_si128(y);
const __m128i yh = _mm256_extractf128_si256(y, 1);
// Get absolute values of x vectors
const __m128i axl = _mm_sign_epi8(xl, xl);
const __m128i axh = _mm_sign_epi8(xh, xh);
// Sign the values of the y vectors
const __m128i syl = _mm_sign_epi8(yl, xl);
const __m128i syh = _mm_sign_epi8(yh, xh);
// Perform multiplication and create 16-bit values
const __m128i dotl = _mm_maddubs_epi16(axl, syl);
const __m128i doth = _mm_maddubs_epi16(axh, syh);
return sum_i16_pairs_float(doth, dotl);
}

static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
Expand Down Expand Up @@ -2221,7 +2277,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
}

*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
#elif defined(__AVX2__)
#elif defined(__AVX2__) || defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();

Expand All @@ -2247,7 +2303,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const __m256 xy = mul_sum_i8_pairs_float(bx, by);

// Accumulate d0*d1*x*y
#if defined(__AVX2__)
acc = _mm256_fmadd_ps( d0d1, xy, acc );
#else
acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
#endif
}

*s = hsum_float_8(acc) + summs;
Expand Down Expand Up @@ -2458,6 +2518,37 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
acc = _mm256_fmadd_ps(d, q, acc);
}

*s = hsum_float_8(acc);
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
__m128i mask = _mm_set1_epi8((char)0xF0);

// Main loop
for (int i = 0; i < nb; i++) {
/* Compute combined scale for the block */
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));

__m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_andnot_si128(bxhil, mask);
bxhih = _mm_andnot_si128(bxhih, mask);
__m128i bxl = _mm256_castsi256_si128(bx);
__m128i bxh = _mm256_extractf128_si256(bx, 1);
bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih);
bx = _mm256_set_m128i(bxh, bxl);

const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

const __m256 q = mul_sum_i8_pairs_float(bx, by);

/* Multiply q with scale and accumulate */
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
}

*s = hsum_float_8(acc);
#else
// scalar
Expand Down Expand Up @@ -2686,6 +2777,40 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
}

*s = hsum_float_8(acc) + summs;
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
__m128i mask = _mm_set1_epi8(0x10);

float summs = 0.0f;

// Main loop
for (int i = 0; i < nb; i++) {
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));

summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;

__m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_and_si128(bxhil, mask);
bxhih = _mm_and_si128(bxhih, mask);
__m128i bxl = _mm256_castsi256_si128(bx);
__m128i bxh = _mm256_extractf128_si256(bx, 1);
bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih);
bx = _mm256_set_m128i(bxh, bxl);

const __m256 dy = _mm256_broadcast_ss(&y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

const __m256 q = mul_sum_i8_pairs_float(bx, by);

acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
}

*s = hsum_float_8(acc) + summs;
#else
// scalar
Expand Down Expand Up @@ -2776,7 +2901,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
}

*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__)
#elif defined(__AVX2__) || defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();

Expand All @@ -2790,7 +2915,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
const __m256 q = mul_sum_i8_pairs_float(bx, by);

// Multiply q with scale and accumulate
#if defined(__AVX2__)
acc = _mm256_fmadd_ps( d, q, acc );
#else
acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
#endif
}

*s = hsum_float_8(acc);
Expand Down