Faster q3_0 implementation, using two planes #1

Merged (6 commits) on Apr 17, 2023
214 changes: 102 additions & 112 deletions ggml.c
@@ -606,12 +606,12 @@ typedef struct {
static_assert(sizeof(block_q2_0) == sizeof(ggml_fp16_t) + QK2_0 / 4, "wrong q2_0 size/padding");

#define QK3_0 16
typedef union {
struct {
uint16_t pad[3];
ggml_fp16_t d;
};
uint64_t qs;
typedef struct {
ggml_fp16_t d;
// Instead of representing q3_0 as a packed format "...210210210210",
// represent it as two planes: "...10101010" and "...2222"
uint16_t qhi; // The highest bit of each 3-bit number, packed together
uint32_t qlo; // The low 2 bits of each 3-bit number, packed together
} block_q3_0;
static_assert(sizeof(block_q3_0) == sizeof(ggml_fp16_t) + QK3_0 * 3 / 8, "wrong q3_0 size/padding");
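
As an aside, a minimal standalone sketch of the two-plane packing described above; the helper names are hypothetical and not part of this patch:

// Store a 3-bit value vi (0..7) at position l (0..QK3_0-1) into the two planes.
static inline void q3_plane_set(uint32_t *qlo, uint16_t *qhi, int l, uint8_t vi) {
    *qlo |= (uint32_t)(vi & 3) << (l * 2);    // low two bits -> 32-bit plane
    *qhi |= (uint16_t)(((vi >> 2) & 1) << l); // high bit     -> 16-bit plane
}

// Read the 3-bit value at position l back out of the two planes.
static inline uint8_t q3_plane_get(uint32_t qlo, uint16_t qhi, int l) {
    return (uint8_t)(((qlo >> (l * 2)) & 3) | (((qhi >> l) & 1) << 2));
}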

@@ -691,17 +691,20 @@ static void quantize_row_q3_0(const float * restrict x, block_q3_0 * restrict y,
const float d = max / -4;
const float id = d ? 1.0f/d : 0.0f;

uint64_t qs = 0;
uint32_t lo = 0;
uint16_t hi = 0;

for (int l = 0; l < QK3_0; l++) {
const float v = x[i*QK3_0 + l]*id;
const uint8_t vi = MIN(7, (int8_t)roundf(v) + 4);
assert(vi < 8);
qs |= (uint64_t)vi << (l*3);
lo |= (vi & 3) << (l * 2);
hi |= ((vi >> 2) & 1) << l;
}

y[i].qs = qs;
y[i].d = GGML_FP32_TO_FP16(d); // overwrite unused part of uint64_t qs
y[i].d = GGML_FP32_TO_FP16(d);
y[i].qlo = lo;
y[i].qhi = hi;
}
}
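
A quick round-trip check of these routines might look as follows (hypothetical test code, assuming the quantize/dequantize signatures used in this file):

#include <assert.h>
#include <math.h>

static void test_q3_0_roundtrip(void) {
    float src[QK3_0], out[QK3_0];
    block_q3_0 blk;
    for (int l = 0; l < QK3_0; l++) src[l] = 0.25f*(l - 8); // -2.0 .. 1.75, so |d| = 0.5
    quantize_row_q3_0(src, &blk, QK3_0);
    dequantize_row_q3_0(&blk, out, QK3_0);
    for (int l = 0; l < QK3_0; l++) {
        // each value should come back within one quantization step
        assert(fabsf(out[l] - src[l]) <= 0.5f);
    }
}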

@@ -1335,13 +1338,15 @@ static void dequantize_row_q3_0(const void * restrict vx, float * restrict y, in

for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
uint64_t qs = x[i].qs;
uint_fast32_t lo = x[i].qlo;
uint_fast32_t hi = x[i].qhi << 2;
for (int l = 0; l < QK3_0; l++) {
const int8_t vi = qs & 7;
const int8_t vi = (lo & 3) | (hi & 4);
const float v = (vi - 4)*d;
y[i*QK3_0 + l] = v;
assert(!isnan(y[i*QK3_0 + l]));
qs >>= 3;
lo >>= 2;
hi >>= 1;
}
}
}
@@ -2193,6 +2198,39 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
*s = sumf;
}

#if __AVX2__ || __AVX512F__
// Multiplies signed 8-bit integers packed into 256-bit vectors and sums adjacent groups of four products,
// returning the sums as 32-bit floats packed into a 256-bit vector.
static inline __m256 dotMul(__m256i bx, __m256i by) {
# if __AVXVNNIINT8__
// Perform multiplication and sum to 32-bit values
const __m256i i32 = _mm256_dpbssd_epi32(_mm256_setzero_si256(), bx, by);
# else
// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

// Convert int16_t to int32_t by adding pairwise
const __m256i ones = _mm256_set1_epi16(1);
const __m256i i32 = _mm256_madd_epi16(ones, dot);
# endif
// Convert int32_t to float
return _mm256_cvtepi32_ps(i32);
}

// Return horizontal sum of 32-bit floats packed into a 256-bit vector.
static inline float horizontalSum(__m256 acc) {
__m128 res = _mm256_extractf128_ps(acc, 1);
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
res = _mm_add_ss(res, _mm_movehdup_ps(res));
return _mm_cvtss_f32(res);
}
#endif
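
Taken together, horizontalSum(dotMul(bx, by)) is equivalent to the following plain scalar loop over the 32 byte lanes (an illustrative reference only, not part of the patch):

static inline float dot_i8_reference(const int8_t bx[32], const int8_t by[32]) {
    int32_t sum = 0;
    for (int j = 0; j < 32; j++) {
        sum += (int32_t)bx[j] * (int32_t)by[j]; // widen to 32 bits before summing
    }
    return (float)sum;
}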

static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK2_0 == 0);
const int nb = n / QK2_0;
@@ -2222,30 +2260,15 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
// Load y vector
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

// Convert int16_t to int32_t by adding pairwise
const __m256i ones = _mm256_set1_epi16(1);
__m256i i32 = _mm256_madd_epi16(ones, dot);

// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps(i32);
// Do the product:
__m256 p = dotMul(bx, by);

// Apply the scale, and accumulate
acc = _mm256_fmadd_ps(scale, p, acc);
}

// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps(acc, 1);
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
res = _mm_add_ss(res, _mm_movehdup_ps(res));
sumf = _mm_cvtss_f32(res);
sumf = horizontalSum(acc);
#else
for (int i = 0; i < nb; i++) {
const float d0 = GGML_FP16_TO_FP32(x[i].d);
@@ -2270,6 +2293,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
*s = sumf;
}

// Lookup table used to convert q3_0 to SIMD vectors.
// Expands the bits of an 8-bit value into a 64-bit result, turning each bit into a byte.
// A zero bit turns into 0xFC, while a one bit turns into 0x00.
#define B0(n) 0x ## n
#define B1(n) B0(n ## FC), B0(n ## 00)
#define B2(n) B1(n ## FC), B1(n ## 00)
#define B3(n) B2(n ## FC), B2(n ## 00)
#define B4(n) B3(n ## FC), B3(n ## 00)
#define B5(n) B4(n ## FC), B4(n ## 00)
#define B6(n) B5(n ## FC), B5(n ## 00)
#define B7(n) B6(n ## FC), B6(n ## 00)
#define B8( ) B7( FC), B7( 00)
static const uint64_t ggml_q3_table[256] = { B8() };
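
In other words, bit b of the table index selects byte b of the entry, and OR'ing an entry with the corresponding 2-bit low values yields the signed value vi - 4 directly, since 0xFC | lo equals lo - 4 as an int8_t. An equivalent runtime construction of the same table, purely for illustration:

static uint64_t q3_table_entry(uint8_t bits) {
    uint64_t r = 0;
    for (int b = 0; b < 8; b++) {
        // bit b set   -> byte b becomes 0x00
        // bit b clear -> byte b becomes 0xFC (-4 as a signed byte)
        r |= (uint64_t)((bits >> b) & 1 ? 0x00 : 0xFC) << (8*b);
    }
    return r;
}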

static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK3_0 == 0);
const int nb = n / QK3_0;
@@ -2282,103 +2319,54 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *

#if defined(__AVX2__)
// Initialize accumulator with zeros
__m128 acc = _mm_setzero_ps();
__m256 acc = _mm256_setzero_ps();

for (int i = 0; i < nb/2; i++) {
const __m128 scale_y = _mm_set1_ps(y[i].d);
for (int u = 0; u < 2; u++) { // let the compiler unroll this
// Compute combined scale for the block
const __m128 scale_x = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+u].d));
const __m128 scale = _mm_mul_ps(scale_x, scale_y);

__m256i bxx = _mm256_set1_epi64x(x[i*2+u].qs);

// legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale

// shift the copies to be able to reach all values
// 255 192 128 64 0
// | | | |
// sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
// sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
// _______________________sssssfedcba98765432__________________________________________ shift right
// sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
// e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0);
const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64);
bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r));

// add to itself in masked places to shift some values left one bit
// 127 64 0
// | | | | | | | | | | | | | | | |
// ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
// _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
// .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
//
// 255 192 128
// | | | | | | | | | | | | | | | |
// ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
// _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
// .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
const __m256i doublemask = _mm256_set1_epi64x(0x078000078000);
bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx));

// collect 16 bytes from 256 into 128 bits
const __m256i shufmask = _mm256_set_epi8(
5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1,
-1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0);
bxx = _mm256_shuffle_epi8(bxx, shufmask);

__m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1));

const __m128i mask = _mm_set1_epi8(7);
bx = _mm_and_si128(mask, bx);

const __m128i off = _mm_set1_epi8(4);
bx = _mm_sub_epi8(bx, off);

const __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + u*QK3_0));
__m256i bx = bytesFromCrumbs(x[i*2+1].qlo, x[i*2].qlo);

// Get absolute values of x vectors
const __m128i ax = _mm_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m128i sy = _mm_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m128i dot = _mm_maddubs_epi16(ax, sy);
__m256i const bxhi = _mm256_set_epi64x(
ggml_q3_table[x[i*2+1].qhi >> 8], ggml_q3_table[x[i*2+1].qhi & 0xFF],
ggml_q3_table[x[i*2+0].qhi >> 8], ggml_q3_table[x[i*2+0].qhi & 0xFF]);

// Convert int16_t to int32_t by adding pairwise
const __m128i ones = _mm_set1_epi16(1);
__m128i i32 = _mm_madd_epi16(dot, ones);
// OR the high bits (which also handles the sign):
bx = _mm256_or_si256(bx, bxhi);

// Convert int32_t to float
const __m128 p = _mm_cvtepi32_ps(i32);
// Compute combined scale for the block
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+0].d));
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+1].d));
__m256 scale = _mm256_set_m128(scale_hi, scale_lo);
scale = _mm256_mul_ps(scale, _mm256_broadcast_ss(&y[i].d));

// Apply the scale, and accumulate
acc = _mm_fmadd_ps(scale, p, acc);
}
// Load y vector
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

// Do the product:
__m256 p = dotMul(bx, by);

// Apply the scale, and accumulate
acc = _mm256_fmadd_ps(scale, p, acc);
}

// Return horizontal sum of the acc vector
__m128 res = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));
res = _mm_add_ss(res, _mm_movehdup_ps(res));
sumf = _mm_cvtss_f32(res);
sumf = horizontalSum(acc);
#else
for (int i = 0; i < nb; i++) {
const float d0 = GGML_FP16_TO_FP32(x[i].d);
const float d1 = y[i/2].d;

uint64_t qs0 = x[i].qs;
uint_fast32_t lo0 = x[i].qlo;
uint_fast32_t hi0 = x[i].qhi << 2;
const int8_t * restrict p1 = y[i/2].qs + (i%2)*QK3_0;

int sumi = 0;
for (int j = 0; j < QK3_0; j++) {
const int8_t i0 = (int8_t)(qs0 & 7) - 4;
const int_fast16_t i1 = p1[j];
for (int l = 0; l < QK3_0; l++) {
const int8_t i0 = (int8_t)((lo0 & 3) | ((hi0 & 4) - 4));
const int_fast16_t i1 = p1[l];

sumi += i0 * i1;

qs0 >>= 3;
lo0 >>= 2;
hi0 >>= 1;
}
sumf += d0 * d1 * sumi;
}
@@ -11630,11 +11618,13 @@ size_t ggml_quantize_q3_0(const float * src, void * dst, int n, int k, int64_t h
quantize_row_q3_0(src + j, y, k);

for (int i = 0; i < nb; i++) {
uint64_t qs = y[i].qs;
uint_fast32_t lo = y[i].qlo;
uint_fast32_t hi = y[i].qhi << 2;
for (int l = 0; l < QK3_0; l++) {
const int8_t vi = qs & 7;
int8_t vi = (lo & 3) | (hi & 4);
hist[vi]++;
qs >>= 3;
lo >>= 2;
hi >>= 1;
}
}
}