diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6bec1f97befd9..41958c93a1bc8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -235,7 +235,9 @@ endif()
 add_library(ggml OBJECT
             ggml.c
-            ggml.h)
+            ggml.h
+            ggml_extra.h
+            ggml_extra.cpp)
 
 target_include_directories(ggml PUBLIC .)
 target_compile_features(ggml PUBLIC c_std_11) # don't bump
diff --git a/Makefile b/Makefile
index 3e58a28a751ab..17624656bd158 100644
--- a/Makefile
+++ b/Makefile
@@ -145,32 +145,35 @@ ggml.o: ggml.c ggml.h
 llama.o: llama.cpp llama.h llama_util.h llama_internal.h
 	$(CXX) $(CXXFLAGS) -c llama.cpp -o llama.o
 
+ggml_extra.o: ggml_extra.cpp ggml_extra.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
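+# ggml_extra.o has to be linked into every binary below: ggml.c includes
+# ggml_extra.h, and llama.cpp calls kQuantizeQ4_0H()/kQuantizeQ4_1H().
+# ($< = first prerequisite, $@ = target, $^ = all prerequisites.)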
else if (arg == "-nq" || arg == "--new-quantization") { + checkNewQuantization = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); quantize_stats_print_usage(argc, argv); @@ -302,11 +306,24 @@ int main(int argc, char ** argv) { std::vector output_scratch(SCRATCH_ELEMENTS); // loop throught quantization types - for (int i = 0; i < GGML_TYPE_COUNT; i++) { + //for (int i = 0; i < GGML_TYPE_COUNT; i++) { + for (int i = 0; i < 1; i++) { if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), i) == params.include_types.end()) { continue; } quantize_fns_t qfns = ggml_internal_get_quantize_fn(i); + if (i < 2 && checkNewQuantization) { + //qfns.quantize_row_q = i == 0 ? kQuantizeQ4_0 : kQuantizeQ4_1; + //qfns.quantize_row_q = i == 0 ? kQuantizeQ4_0 : kQuantizeQ5_1; + ////qfns.quantize_row_q = i == 0 ? kQuantizeQ4_0 : kQuantizeQ5_1_Fast; + //if (i == 1) qfns.dequantize_row_q = kDequantizeQ5_1; + //qfns.quantize_row_q = i == 0 ? kQuantizeQ4_0K : kQuantizeQ5_1; + //qfns.dequantize_row_q = i == 0 ? kDequantizeQ4_0K : kDequantizeQ5_1; + //qfns.quantize_row_q = i == 0 ? kQuantizeQ4_0K : kQuantizeQ4_1K; + //qfns.dequantize_row_q = i == 0 ? kDequantizeQ4_0K : kDequantizeQ4_1K; + qfns.quantize_row_q = i == 0 ? kQuantizeQ8Simple: kQuantizeQ4_1K; + qfns.dequantize_row_q = i == 0 ? kDequantizeQ8: kDequantizeQ4_1K; + } if (qfns.quantize_row_q && qfns.dequantize_row_q) { if (params.verbose) { printf("testing %s ...\n", type_strs[i]); diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 680757c6bf356..313b7534f36c5 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -14,6 +14,8 @@ int main(int argc, char ** argv) { fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); fprintf(stderr, " type = 2 - q4_0\n"); fprintf(stderr, " type = 3 - q4_1\n"); + fprintf(stderr, " type = 4 - new q4_0\n"); + fprintf(stderr, " type = 5 - new q4_1\n"); return 1; } diff --git a/ggml.c b/ggml.c index 897b67d930614..04a3a40f767d5 100644 --- a/ggml.c +++ b/ggml.c @@ -2,6 +2,7 @@ #define _GNU_SOURCE #include "ggml.h" +#include "ggml_extra.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -502,6 +503,13 @@ typedef struct { } block_q4_1; static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding"); +inline int nearestInt(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + // reference implementation for deterministic creation of model files static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { assert(k % QK == 0); @@ -526,8 +534,15 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r const float v0 = x[i*QK + l + 0]*id; const float v1 = x[i*QK + l + 1]*id; + // On x86_64 and x86, round is amazingly slow. 
+
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
     assert(k % QK == 0);
@@ -526,8 +534,15 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
             const float v0 = x[i*QK + l + 0]*id;
             const float v1 = x[i*QK + l + 1]*id;
 
+            // On x86_64 and x86, round() is amazingly slow.
+            // Here it is best to just use this:
+            //const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
+            //const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
             const uint8_t vi0 = (int8_t)roundf(v0) + 8;
             const uint8_t vi1 = (int8_t)roundf(v1) + 8;
+            // This is marginally slower (but still much faster than round()):
+            //const uint8_t vi0 = nearestInt(v0) + 8;
+            //const uint8_t vi1 = nearestInt(v1) + 8;
 
             assert(vi0 < 16);
             assert(vi1 < 16);
@@ -818,6 +833,10 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
             const float v0 = (x[i*QK + l + 0] - min)*id;
             const float v1 = (x[i*QK + l + 1] - min)*id;
 
+            // For some reason round() is amazingly slow on x86_64 and x86.
+            // Using this instead reduces the difference between AVX2 and scalar to less than ~15%:
+            //const uint8_t vi0 = nearestInt(v0);
+            //const uint8_t vi1 = nearestInt(v1);
             const uint8_t vi0 = roundf(v0);
             const uint8_t vi1 = roundf(v1);
 
@@ -2569,7 +2588,7 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     1,
 };
 
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 7");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     sizeof(block_q4_0),
@@ -2582,7 +2601,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
 };
 
 // don't forget to update the array above when adding new types
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 7");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
diff --git a/ggml_extra.cpp b/ggml_extra.cpp
new file mode 100644
index 0000000000000..fa2591e68255e
--- /dev/null
+++ b/ggml_extra.cpp
@@ -0,0 +1,673 @@
+#include "ggml_extra.h"
+#include "ggml.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstring>
+#include <vector>
+#include <utility>
+#include <algorithm>
+#include <thread>
+#include <atomic>
+
+namespace {
+
+constexpr int kChunkSize = 32*32*8;
+constexpr int QK = 32;
+constexpr int kBucketSize0 = QK/2 + sizeof(float);
+constexpr int kBucketSize1 = QK/2 + 2*sizeof(float);
+
+inline int toNearestInt(float fval) {
+    assert(fval <= 4194303.f);
+    constexpr float kSnapper = 3<<22;
+    auto val = fval + kSnapper;
+    int i; std::memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+// Adapted from PR #835, function quantize_row_q4_0_rmse()
+//
+// I absolutely cannot reproduce the rmse = 0.00185915 reported in #835.
+// Instead, I get rmse = 0.00197 with the original version and rmse = 0.00192
+// with the modification that determines the scale actually minimizing
+// the rmse.
+//
+// Do I have a bug? I don't see it. The only difference is that I'm using
+// toNearestInt() instead of round(), but what are the odds of getting scaled
+// weights at exactly 2.5, 4.5, or 6.5, where toNearestInt() and round()
+// differ (toNearestInt() rounds to the even integer, as expected, while
+// round() always rounds up)?
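+// Concretely: toNearestInt(2.5f) = 2, toNearestInt(3.5f) = 4 and
+// toNearestInt(6.5f) = 6 (ties go to the even integer, as in IEEE
+// round-to-nearest-even), while round() gives 3, 4 and 7 (half away from
+// zero). A scaled weight has to land exactly on such a midpoint for the
+// two to disagree.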
+float quanizeRmse(int n, const float* X, int8_t* L) {
+#define Q4_0_SCALE_CANDIDATE_COUNT 8
+    static const float candidates[Q4_0_SCALE_CANDIDATE_COUNT] = { -8.7f, -8.5f, -8.3f, -8.1f, -7.9f, -7.7f, -7.2f, +7.0f };
+    float max = 0, amax = 0;
+    for (int i=0; i<n; ++i) {
+        float ax = std::abs(X[i]);
+        if (ax > amax) { amax = ax; max = X[i]; }
+    }
+    if (!amax) { // all zero
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return 1.f;
+    }
+    float best = 0, bestScale = 0;
+    for (int si=0; si<Q4_0_SCALE_CANDIDATE_COUNT; ++si) {
+        float iscale = candidates[si]/max;
+        float sumlx = 0; int suml2 = 0;
+        for (int i=0; i<n; ++i) {
+            int l = std::max(-8, std::min(7, toNearestInt(iscale*X[i])));
+            sumlx += X[i]*l; suml2 += l*l;
+        }
+        if (sumlx*sumlx > best*suml2) {
+            best = sumlx*sumlx/suml2; bestScale = iscale;
+        }
+    }
+    float sumlx = 0; int suml2 = 0;
+    for (int i=0; i<n; ++i) {
+        int l = std::max(-8, std::min(7, toNearestInt(bestScale*X[i])));
+        sumlx += X[i]*l; suml2 += l*l;
+        L[i] = l;
+    }
+    return sumlx/suml2;
+}
+
+// Generalization of the above: probe the given candidate scales, then
+// greedily refine individual quants while that improves the rmse.
+float quanizeRmseK(int n, const float* X, int8_t* L,
+        int nCandidates, const float* candidates, int nmin, int nmax) {
+    float max = 0;
+    for (int i=0; i<n; ++i) max = std::max(max, std::abs(X[i]));
+    if (!max) { // all zero
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return 1.f;
+    }
+    float best = 0, bestScale = 0;
+    for (int si=0; si<nCandidates; ++si) {
+        float iscale = candidates[si]/max;
+        float sumlx = 0; int suml2 = 0;
+        for (int i=0; i<n; ++i) {
+            int l = std::max(nmin, std::min(nmax, toNearestInt(iscale*X[i])));
+            sumlx += X[i]*l; suml2 += l*l;
+        }
+        if (sumlx*sumlx > best*suml2) {
+            best = sumlx*sumlx/suml2; bestScale = iscale;
+        }
+    }
+    float sumlx = 0; int suml2 = 0;
+    for (int i=0; i<n; ++i) {
+        int l = std::max(nmin, std::min(nmax, toNearestInt(bestScale*X[i])));
+        sumlx += X[i]*l; suml2 += l*l;
+        L[i] = l;
+    }
+    float scale = sumlx/suml2;
+    best = scale*sumlx;
+    while (true) {
+        bool haveChanges = false;
+        for (int i=0; i<n; ++i) {
+            float g = X[i] - scale*L[i];
+            if (g > 0 && L[i] < nmax) {
+                auto s1 = sumlx + X[i];
+                auto s2 = suml2 + 2*L[i] + 1;
+                if (s2 > 0 && s1*s1 > best*s2) {
+                    scale = s1/s2; best = scale*s1; ++L[i]; sumlx = s1; suml2 = s2; haveChanges = true;
+                }
+            }
+            else if (g < 0 && L[i] > nmin) {
+                auto s1 = sumlx - X[i];
+                auto s2 = suml2 - 2*L[i] + 1;
+                if (s2 > 0 && s1*s1 > best*s2) {
+                    scale = s1/s2; best = scale*s1; --L[i]; sumlx = s1; suml2 = s2; haveChanges = true;
+                }
+            }
+        }
+        if (!haveChanges) break;
+    }
+    return scale;
+}
+
+// The following improves on the above.
+// It gives RMSE = 0.00185228 for the 7B model.
+float quanizeRmseK7(int n, const float* X, int8_t* L) {
+    constexpr int kCandiateCount = 20;
+    static const float candidates[kCandiateCount] = { -8.7f, -8.5f, -8.3f, -8.1f, -7.9f, -7.7f, -7.2f, -7.0f, -6.3f, -5.7f,
+                                                      +8.7f, +8.5f, +8.3f, +8.1f, +7.9f, +7.7f, +7.2f, +7.0f, +6.3f, +5.7f };
+    return quanizeRmseK(n, X, L, kCandiateCount, candidates, -8, 7);
+}
+
+float quanizeRmseK15(int n, const float* X, int8_t* L) {
+    constexpr int kCandiateCount = 16;
+    static const float candidates[kCandiateCount] = {
+        +17.75f, +17.25f, +16.75f, +16.25f, +15.75f, +15.25f, +14.75f, +14.25f, +13.75f, +13.25f, +12.75f, +12.25f, +11.75f,
+        +11.25f, +10.75f, +10.25f
+    };
+    return quanizeRmseK(n, X, L, kCandiateCount, candidates, 0, 15);
+}
+
+float quanizeRmseK31(int n, const float* X, int8_t* L) {
+    constexpr int kCandiateCount = 24;
+    static const float candidates[kCandiateCount] = {
+        +35.25f, +34.25f, +33.25f, +32.75f, +32.25f, +31.75f, +31.25f, +30.75f, +30.25f, +29.75f, +29.25f, +28.25f, +27.25f, +26.25f,
+        +25.25f, +24.25f, +23.25f, +22.25f, +21.25f, +20.25f, +19.25f, +18.25f, +17.25f, +16.25f
+    };
+    //static const float candidates[kCandiateCount] = {
+    //    +33.25f, +32.25f, +31.75f, +31.25f, +30.75f, +30.25f, +30.25f, +29.25f, +28.75f, +27.25f, +26.25f, +25.25f, +24.25f, +23.25f, +22.25f,
+    //    +21.25f
+    //};
+    return quanizeRmseK(n, X, L, kCandiateCount, candidates, 0, 31);
+}
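+// Why all of the above maximizes sumlx*sumlx/suml2 rather than computing the
+// rmse directly: for fixed quants L, the error E(s) = sum_i (X[i] - s*L[i])^2
+// is minimized at s* = sumlx/suml2 (set dE/ds = 0), where
+//     E(s*) = sum_i X[i]^2 - sumlx*sumlx/suml2.
+// The first term does not depend on the quantization, so the candidate with
+// the largest sumlx*sumlx/suml2 has the smallest rmse. The same identity is
+// behind the acceptance test s1*s1 > best*s2 in quanizeRmseK(): s1 and s2 are
+// sumlx and suml2 after moving one quant by +-1, and best = sumlx*sumlx/suml2
+// of the current solution, so a change is kept only if it lowers the rmse.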
+// Fast (much faster than doing the optimization above), but not very good.
+float quanizeRmseFast(int n, const float* X, int8_t* L) {
+    //constexpr int kCandiateCount = 3;
+    //static const float candidates[kCandiateCount] = { +8.3f, +7.2f, +5.7f};
+    constexpr int kCandiateCount = 4;
+    static const float candidates[kCandiateCount] = { +8.7f, +7.9f, +7.2f, +5.7f};
+    float max = 0;
+    for (int i=0; i<n; ++i) max = std::max(max, std::abs(X[i]));
+    if (!max) { // all zero
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return 1.f;
+    }
+    float best = 0, bestScale = 0;
+    for (int si=0; si<kCandiateCount; ++si) {
+        float iscale = candidates[si]/max;
+        float sumxlp = 0, sumxlm = 0; int suml2p = 0, suml2m = 0;
+        for (int i=0; i<n; ++i) {
+            // try the candidate scale with both signs
+            int lp = std::max(-8, std::min(7, toNearestInt( iscale*X[i])));
+            int lm = std::max(-8, std::min(7, toNearestInt(-iscale*X[i])));
+            sumxlp += X[i]*lp; suml2p += lp*lp;
+            sumxlm += X[i]*lm; suml2m += lm*lm;
+        }
+        if (sumxlp*sumxlp*suml2m >= sumxlm*sumxlm*suml2p) {
+            if (sumxlp*sumxlp > best*suml2p) {
+                best = sumxlp*sumxlp/suml2p; bestScale = iscale;
+            }
+        } else {
+            if (sumxlm*sumxlm > best*suml2m) {
+                best = sumxlm*sumxlm/suml2m; bestScale = -iscale;
+            }
+        }
+    }
+    float sumlx = 0; int suml2 = 0;
+    for (int i=0; i<n; ++i) {
+        int l = std::max(-8, std::min(7, toNearestInt(bestScale*X[i])));
+        sumlx += X[i]*l; suml2 += l*l;
+        L[i] = l;
+    }
+    return sumlx/suml2;
+}
+
+// Exhaustive search: try every scale at which some weight changes its quant.
+float quanizeRmseOpt(int n, const float* X, int8_t* L, std::vector<std::pair<float, int>>& work) {
+    work.clear();
+    work.reserve(n*17);
+    for (int l=-8; l<=8; ++l) {
+        float scale = l - 0.4999f;
+        for (int i=0; i<n; ++i) {
+            if (X[i]) work.push_back({scale/X[i], i});
+        }
+    }
+    float best = 0, bestScale = 0;
+    float sumlx = 0; int suml2 = 0;
+    if (!work.empty()) {
+        std::sort(work.begin(), work.end());
+        for (int k=0; k<int(work.size()); ++k) {
+            float s = work[k].first;
+            sumlx = 0; suml2 = 0;
+            for (int i=0; i<n; ++i) {
+                int l = std::max(-8, std::min(7, toNearestInt(s*X[i])));
+                sumlx += X[i]*l; suml2 += l*l;
+            }
+            if (suml2 > 0 && sumlx*sumlx > best*suml2) {
+                best = sumlx*sumlx/suml2; bestScale = s;
+            }
+        }
+    }
+    sumlx = 0; suml2 = 0;
+    for (int i=0; i<n; ++i) {
+        int l = std::max(-8, std::min(7, toNearestInt(bestScale*X[i])));
+        sumlx += X[i]*l; suml2 += l*l;
+        L[i] = l;
+    }
+    return sumlx/suml2;
+}
+
+std::pair<float, float> kQuantize0(int n, const float* X, int8_t* L, std::vector<std::pair<float, int>>& work, int nmin, int nmax) {
+    work.clear();
+    work.reserve(n*(nmax+2));
+    float max = 0; int imax = -1;
+    for (int i=0; i<n; ++i) {
+        float x = std::abs(X[i]);
+        if (x > max) { max = x; imax = i; }
+    }
+    if (imax < 0) { // all X are zero
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return {0.f, 0.f};
+    }
+    // the breakpoints of the max magnitude weight define the scale range to probe
+    float maxi = 1.f/std::abs(X[imax]);
+    int kmin = nmax/2, kmax = nmax+1;
+    if (nmin < 0) {
+        if (X[imax] > 0) {
+            kmin = nmax-2; kmax = nmax+1;
+        } else {
+            kmin = nmax/2; kmax = nmax+1;
+        }
+    }
+    for (int k=kmin; k<=kmax; ++k) work.push_back({(k + 0.501f)*maxi, imax});
+    float minScale = work.front().first;
+    float maxScale = work.back().first;
+    // add the scales at which any other weight changes its quant
+    for (int i=0; i<n; ++i) {
+        if (i == imax) continue;
+        float ax = std::abs(X[i]);
+        if (!ax) continue;
+        for (int k=0; k<=nmax; ++k) {
+            float s = (k + 0.501f)/ax;
+            if (s > maxScale) break;
+            if (s > minScale) work.push_back({s,i});
+        }
+    }
+    std::sort(work.begin(), work.end());
+    float sumlx = 0; int suml2 = 0;
+    float s = work.front().first;
+    for (int i=0; i<n; ++i) {
+        int l = std::max(nmin, std::min(nmax, toNearestInt(s*X[i])));
+        L[i] = l; sumlx += X[i]*l; suml2 += l*l;
+    }
+    float bestSumlx = sumlx; float bestSumlx2 = sumlx*sumlx; int bestSuml2 = suml2; float bests = s;
+    float lasts = s;
+    for (int k=0; k<int(work.size()); ++k) {
+        s = work[k].first;
+        int i = work[k].second;
+        int l = std::max(nmin, std::min(nmax, toNearestInt(s*X[i])));
+        if (l == L[i]) { lasts = s; continue; }
+        if (l > L[i]) {
+            sumlx += X[i];
+            suml2 += 1 + 2*L[i];
+        }
+        else {
+            sumlx -= X[i];
+            suml2 += 1 - 2*L[i];
+        }
+        L[i] = l;
+        float sumlx2 = sumlx*sumlx;
+        if ((s != lasts || k == int(work.size())-1) && suml2 > 0 && sumlx2*bestSuml2 > bestSumlx2*suml2) {
+            bestSumlx = sumlx; bestSumlx2 = sumlx2; bestSuml2 = suml2; bests = s;
+        }
+        lasts = s;
+    }
+    for (int i=0; i<n; ++i) L[i] = std::max(nmin, std::min(nmax, toNearestInt(bests*X[i])));
+    return {bestSuml2 > 0 ? bestSumlx/bestSuml2 : 0.f, bests};
+}
+
+// Quantize as X[i] ~ a + b*L[i]: alternate between finding the best quants
+// for the current offset and refitting (a, b) by least squares.
+std::pair<float, float> kQuantize1(int n, const float* X, int8_t* L, std::vector<float>& tmpX,
+        std::vector<std::pair<float, int>>& work, int nmax) {
+    float min = X[0], max = X[0];
+    for (int i=1; i<n; ++i) {
+        min = std::min(min, X[i]); max = std::max(max, X[i]);
+    }
+    if (max == min) { // all values are the same
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return {min, 0.f};
+    }
+    if (int(tmpX.size()) < n) tmpX.resize(n);
+    float a = min, b = 0;
+    for (int itry=0; itry<5; ++itry) {
+        float aold = a, bold = b;
+        for (int i=0; i<n; ++i) tmpX[i] = X[i] - a;
+        b = kQuantize0(n, tmpX.data(), L, work, 0, nmax).first;
+        // with the quants L fixed, solve the least squares problem for a and b
+        double sumx = 0, suml = 0, sumxl = 0, suml2 = 0;
+        for (int i=0; i<n; ++i) {
+            sumx += X[i]; suml += L[i]; sumxl += X[i]*L[i]; suml2 += 1.*L[i]*L[i];
+        }
+        double D = n*suml2 - suml*suml;
+        if (!D) break;
+        b = float((n*sumxl - sumx*suml)/D);
+        a = float((sumx - b*suml)/n);
+        if (itry > 0 && std::abs(a - aold) < 1e-6*std::abs(aold) && std::abs(b - bold) < 1e-6*std::abs(bold)) break;
+    }
+    return {a, b};
+}
+
+std::pair<float, float> kQuantize1Fast(int n, const float* X, int8_t* L, int nmax) {
+    float min = X[0], max = X[0];
+    for (int i=1; i<n; ++i) {
+        min = std::min(min, X[i]); max = std::max(max, X[i]);
+    }
+    if (max == min) { // all values are the same
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return {min, 0.f};
+    }
+    float b = (max - min)/nmax, ib = 1.f/b;
+    for (int i=0; i<n; ++i) L[i] = std::max(0, std::min(nmax, toNearestInt(ib*(X[i] - min))));
+    return {min, b};
+}
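+// Packed layouts produced below (per QK = 32 weights), matching the
+// dequantization routines at the end of this file:
+//   type 0 (q4_0):     float scale, 16 bytes of 4-bit quants            -> kBucketSize0 = 20 bytes
+//   type 1 (q4_1):     float scale, float min, 16 bytes of quants       -> kBucketSize1 = 24 bytes
+//   type 2/3 (q5_1):   fp16 min, fp16 scale, uint32_t holding the 5th
+//                      bit of each quant, 16 bytes with the low 4 bits  -> kBucketSize1 = 24 bytes
+//   type 4 (new q4_0): fp16 scale per group of 16, 16 bytes of quants   -> kBucketSize0 = 20 bytes
+//   type 5 (new q4_1): fp16 (min, scale) per group of 16, 16 bytes      -> kBucketSize1 = 24 bytes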
+void kQuantizeQ4(const float* X, void* buffer, int k, int type) {
+    assert(k % QK == 0);
+    auto processOne = [type] (const float* X, int8_t* L, char* y, std::vector<std::pair<float, int>>& work, std::vector<float>& tmpX) {
+        auto q = (uint8_t*)y;
+        if (type == 0) {
+            auto scale = quanizeRmseK7(QK, X, L);
+            //auto scale = quanizeRmseFast(QK, X, L);
+            //auto scale = quanizeRmseOpt(QK, X, L, work);
+            // The following is not quite as good as quanizeRmseK() and it is slower too.
+            //if (int(tmpX.size()) < QK) tmpX.resize(QK);
+            //auto r1 = kQuantize0(QK, X, L, work, -8, 7);
+            //for (int i=0; i<QK; ++i) tmpX[i] = -X[i];
+            //int8_t L2[QK];
+            //auto r2 = kQuantize0(QK, tmpX.data(), L2, work, -8, 7);
+            //float scale = r1.first;
+            //if (r2.first > r1.first) {
+            //    scale = -r2.first;
+            //    std::memcpy(L, L2, QK);
+            //}
+            ////float scale = kQuantize0(QK, X, L, work, -7, 7);
+            std::memcpy(q, &scale, sizeof(scale)); q += sizeof(scale);
+            for (int k=0; k<QK/2; ++k) q[k] = (L[2*k] + 8) | ((L[2*k+1] + 8) << 4);
+        }
+        else if (type == 4) {
+            // new q4_0: an fp16 scale for each group of 16 weights
+            auto scale1 = quanizeRmseK7(QK/2, X, L);
+            auto scale2 = quanizeRmseK7(QK/2, X + QK/2, L + QK/2);
+            auto s = (ggml_fp16_t*)q;
+            s[0] = ggml_fp32_to_fp16(scale1);
+            s[1] = ggml_fp32_to_fp16(scale2);
+            q += 2*sizeof(ggml_fp16_t);
+            for (int k=0; k<QK/2; ++k) q[k] = (L[2*k] + 8) | ((L[2*k+1] + 8) << 4);
+        }
+        else if (type == 5) {
+            // new q4_1: fp16 min and scale for each group of 16 weights
+            auto ab1 = kQuantize1(QK/2, X, L, tmpX, work, 15);
+            auto ab2 = kQuantize1(QK/2, X + QK/2, L + QK/2, tmpX, work, 15);
+            auto s = (ggml_fp16_t*)q;
+            s[0] = ggml_fp32_to_fp16(ab1.first); s[1] = ggml_fp32_to_fp16(ab1.second);
+            s[2] = ggml_fp32_to_fp16(ab2.first); s[3] = ggml_fp32_to_fp16(ab2.second);
+            q += 4*sizeof(ggml_fp16_t);
+            for (int k=0; k<QK/2; ++k) q[k] = L[2*k] | (L[2*k+1] << 4);
+        }
+        else if (type == 1) {
+            // same layout as ggml's block_q4_1: float scale, then float min
+            auto ab = kQuantize1(QK, X, L, tmpX, work, 15);
+            float d = ab.second, m = ab.first;
+            std::memcpy(q, &d, sizeof(d)); q += sizeof(d);
+            std::memcpy(q, &m, sizeof(m)); q += sizeof(m);
+            for (int k=0; k<QK/2; ++k) q[k] = L[2*k] | (L[2*k+1] << 4);
+        }
+        else {
+            // 5-bit quants: fp16 min and scale, then a bitmask with the 5th bit of each quant
+            auto ab = type == 2 ? kQuantize1(QK, X, L, tmpX, work, 31) : kQuantize1Fast(QK, X, L, 31);
+            auto s = (ggml_fp16_t*)q;
+            s[0] = ggml_fp32_to_fp16(ab.first); s[1] = ggml_fp32_to_fp16(ab.second);
+            q += 2*sizeof(ggml_fp16_t);
+            auto u = (uint32_t*)q; *u = 0; q += sizeof(uint32_t);
+            uint32_t m = 1;
+            for (int k=0; k<QK/2; ++k) {
+                uint8_t l1 = L[2*k], l2 = L[2*k+1];
+                if (l1 > 15) { l1 -= 16; *u |= m; }
+                m <<= 1;
+                if (l2 > 15) { l2 -= 16; *u |= m; }
+                m <<= 1;
+                q[k] = l1 | (l2 << 4);
+            }
+        }
+    };
+
+    auto bucketSize = type == 0 || type == 4 ? kBucketSize0 : kBucketSize1;
+    auto y = (char*)buffer;
+    int nchunk = (k + kChunkSize-1)/kChunkSize;
+    if (nchunk < 2) {
+        std::vector<int8_t> L(QK);
+        std::vector<std::pair<float, int>> work;
+        std::vector<float> tmpX;
+        int nb = k / QK;
+        auto x = X;
+        for (int i=0; i<nb; ++i) {
+            processOne(x, L.data(), y, work, tmpX);
+            x += QK; y += bucketSize;
+        }
+        return;
+    }
+    std::atomic<int> counter(0);
+    auto compute = [&counter, X, y, k, bucketSize, &processOne] () {
+        std::vector<int8_t> L(QK);
+        std::vector<std::pair<float, int>> work;
+        std::vector<float> tmpX;
+        while (true) {
+            int first = counter.fetch_add(kChunkSize, std::memory_order_relaxed);
+            if (first >= k) break;
+            int last = first + kChunkSize;
+            if (last > k) last = k;
+            auto xi = X + first;
+            auto yi = y + (first/QK)*bucketSize;
+            int n = (last - first)/QK;
+            for (int i=0; i<n; ++i) {
+                processOne(xi, L.data(), yi, work, tmpX);
+                xi += QK; yi += bucketSize;
+            }
+        }
+    };
+    int nthread = std::max(1, std::min(nchunk, int(std::thread::hardware_concurrency())));
+    std::vector<std::thread> workers(nthread-1);
+    for (auto& w : workers) w = std::thread(compute);
+    compute();
+    for (auto& w : workers) w.join();
+}
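+// One chunk is kChunkSize = 32*32*8 = 8192 values, i.e. 256 QK-blocks, grabbed
+// from the shared atomic counter until the input is exhausted. Inputs shorter
+// than two chunks are quantized on the calling thread alone; this is why
+// quantize-stats bumps SCRATCH_ELEMENTS to 32*32*256, so that the threaded
+// path actually gets exercised.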
+
+void collectHisto(int k, const void* buffer, int64_t* hist, int type) {
+    if (!hist) return;
+    auto y = (const uint8_t*)buffer;
+    int m = type == 0 ? 4 : 8; // size of the scale data per 32 weights
+    int n = k / 32;
+    for (int i=0; i<n; ++i) {
+        y += m;
+        for (int l=0; l<16; ++l) {
+            ++hist[y[l] & 15];
+            ++hist[y[l] >> 4];
+        }
+        y += 16;
+    }
+}
+
+}
+
+extern "C" {
+
+void kQuantizeQ4_0(const float* x, void* buffer, int k) {
+    kQuantizeQ4(x, buffer, k, 0);
+}
+
+void kQuantizeQ4_1(const float* x, void* buffer, int k) {
+    kQuantizeQ4(x, buffer, k, 1);
+}
+
+void kQuantizeQ5_1(const float* x, void* buffer, int k) {
+    kQuantizeQ4(x, buffer, k, 2);
+}
+
+void kQuantizeQ5_1_Fast(const float* x, void* buffer, int k) {
+    kQuantizeQ4(x, buffer, k, 3);
+}
+
+size_t kQuantizeQ4_0H(const float* x, void* buffer, int k, int64_t* hist) {
+    kQuantizeQ4(x, buffer, k, 0);
+    collectHisto(k, buffer, hist, 0);
+    return (k / QK) * kBucketSize0;
+}
+
+size_t kQuantizeQ4_1H(const float* x, void* buffer, int k, int64_t* hist) {
+    kQuantizeQ4(x, buffer, k, 1);
+    collectHisto(k, buffer, hist, 1);
+    return (k / QK) * kBucketSize1;
+}
+
+size_t kQuantizeQ5_1H(const float* x, void* buffer, int k, int64_t* hist) {
+    kQuantizeQ4(x, buffer, k, 2);
+    collectHisto(k, buffer, hist, 1);
+    return (k / QK) * kBucketSize1;
+}
+
+size_t kQuantizeQ5_1H_Fast(const float* x, void* buffer, int k, int64_t* hist) {
+    kQuantizeQ4(x, buffer, k, 3);
+    collectHisto(k, buffer, hist, 1);
+    return (k / QK) * kBucketSize1;
+}
+
+void kDequantizeQ5_1(const void* x, float* y, int k) {
+    assert(k % QK == 0);
+    int n = k / QK;
+    auto data = (const uint8_t*)x;
+    for (int i=0; i<n; ++i) {
+        float a = ggml_fp16_to_fp32(*(const ggml_fp16_t*)(data+0));
+        float b = ggml_fp16_to_fp32(*(const ggml_fp16_t*)(data+2));
+        uint32_t u; std::memcpy(&u, data+4, sizeof(u));
+        data += 8;
+        uint32_t m = 1;
+        for (int k=0; k<16; ++k) {
+            uint8_t l1 = data[k] & 15, l2 = data[k] >> 4;
+            if (u & m) l1 += 16;
+            m <<= 1;
+            if (u & m) l2 += 16;
+            m <<= 1;
+            *y++ = a + b*l1;
+            *y++ = a + b*l2;
+        }
+        data += 16;
+    }
+}
+
+void kQuantizeQ4_0K(const float* x, void* buffer, int k) {
+    kQuantizeQ4(x, buffer, k, 4);
+}
+
+void kDequantizeQ4_0K(const void* x, float* y, int k) {
+    assert(k % QK == 0);
+    int n = k / QK;
+    auto data = (const uint8_t*)x;
+    for (int i=0; i<n; ++i) {
+        float a = ggml_fp16_to_fp32(*(const ggml_fp16_t*)(data+0));
+        float b = ggml_fp16_to_fp32(*(const ggml_fp16_t*)(data+2));
+        data += 4;
+        for (int k=0; k<8; ++k) {
+            int8_t l1 = data[k] & 15, l2 = data[k] >> 4;
+            l1 -= 8; l2 -= 8;
+            *y++ = a*l1; *y++ = a*l2;
+        }
+        data += 8;
+        for (int k=0; k<8; ++k) {
+            int8_t l1 = data[k] & 15, l2 = data[k] >> 4;
+            l1 -= 8; l2 -= 8;
+            *y++ = b*l1; *y++ = b*l2;
+        }
+        data += 8;
+    }
+}
+
+void kQuantizeQ4_1K(const float* x, void* buffer, int k) {
+    kQuantizeQ4(x, buffer, k, 5);
+}
+
+void kDequantizeQ4_1K(const void* x, float* y, int k) {
+    assert(k % QK == 0);
+    int n = k / QK;
+    auto data = (const uint8_t*)x;
+    for (int i=0; i<n; ++i) {
+        float a1 = ggml_fp16_to_fp32(*(const ggml_fp16_t*)(data+0));
+        float b1 = ggml_fp16_to_fp32(*(const ggml_fp16_t*)(data+2));
+        float a2 = ggml_fp16_to_fp32(*(const ggml_fp16_t*)(data+4));
+        float b2 = ggml_fp16_to_fp32(*(const ggml_fp16_t*)(data+6));
+        data += 8;
+        for (int k=0; k<8; ++k) {
+            int8_t l1 = data[k] & 15, l2 = data[k] >> 4;
+            *y++ = a1 + b1*l1; *y++ = a1 + b1*l2;
+        }
+        data += 8;
+        for (int k=0; k<8; ++k) {
+            int8_t l1 = data[k] & 15, l2 = data[k] >> 4;
+            *y++ = a2 + b2*l1; *y++ = a2 + b2*l2;
+        }
+        data += 8;
+    }
+}
+
+void kQuantizeQ8Simple(const float* x, void* y, int k) {
+    assert(k % QK == 0);
+    auto data = (int8_t*)y;
+    int n = k / (QK/2);
+    for (int i=0; i<n; ++i) {
+        float max = 0;
+        for (int l=0; l<QK/2; ++l) max = std::max(max, std::abs(x[l]));
+        if (max > 0) {
+            float iscale = 127.f/max;
+            float scale = max/127.f;
+            std::memcpy(data, &scale, sizeof(scale)); data += sizeof(scale);
+            for (int k=0; k<16; ++k) data[k] = toNearestInt(iscale * *x++);
+            data += 16;
+        } else {
+            float scale = 1;
+            std::memcpy(data, &scale, sizeof(scale)); data += sizeof(scale);
+            auto aux = (uint32_t*)data;
+            aux[0] = aux[1] = aux[2] = aux[3] = 0;
+            data += 16;
+            x += 16; // keep the source pointer in sync (this group is all zeros)
+        }
+    }
+}
+
+void kDequantizeQ8(const void* x, float* y, int k) {
+    assert(k % QK == 0);
+    auto data = (const int8_t*)x;
+    int n = k / (QK/2);
+    for (int i=0; i<n; ++i) {
+        float scale;
+        std::memcpy(&scale, data, sizeof(scale)); data += sizeof(scale);
+        for (int l=0; l<16; ++l) *y++ = scale * data[l];
+        data += 16;
+    }
+}
+
+}
diff --git a/ggml_extra.h b/ggml_extra.h
new file mode 100644
--- /dev/null
+++ b/ggml_extra.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#ifdef __cplusplus
+#include <cstdint>
+#include <cstddef>
+extern "C" {
+#else
+#include <stdint.h>
+#include <stddef.h>
+#endif
+
+#ifdef __cplusplus
+// restrict is not standard in C++
+#define GGML_RESTRICT
+#else
+#define GGML_RESTRICT restrict
+#endif
+
+void   kQuantizeQ4_0(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k);
+size_t kQuantizeQ4_0H(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k, int64_t* hist);
+
+void   kQuantizeQ4_1(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k);
+size_t kQuantizeQ4_1H(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k, int64_t* hist);
+
+void   kQuantizeQ5_1(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k);
+size_t kQuantizeQ5_1H(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k, int64_t* hist);
+void   kQuantizeQ5_1_Fast(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k);
+size_t kQuantizeQ5_1H_Fast(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k, int64_t* hist);
+void   kDequantizeQ5_1(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+
+void   kQuantizeQ4_0K(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k);
+void   kDequantizeQ4_0K(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+
+void   kQuantizeQ4_1K(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k);
+void   kDequantizeQ4_1K(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+
+void   kQuantizeQ8Simple(const float* GGML_RESTRICT x, void* GGML_RESTRICT y, int k);
+void   kDequantizeQ8(const void* GGML_RESTRICT x, float* GGML_RESTRICT y, int k);
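+
+// Usage: every routine processes k values (k must be a multiple of 32) from x
+// into y, and y must be large enough for the packed result. The *H variants
+// additionally fill a histogram of the 16 possible 4-bit quants (hist may be
+// NULL) and return the number of bytes written, e.g.
+//
+//     int64_t hist[16] = {0};
+//     size_t bytes = kQuantizeQ4_0H(src, dst, n, hist); // 20 bytes per 32 values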
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/llama.cpp b/llama.cpp
index 54ba01eefbade..04ba10672cbcc 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8,6 +8,7 @@
 #include "llama_internal.h"
 
 #include "ggml.h"
+#include "ggml_extra.h"
 
 #include <array>
 #include <ctime>
@@ -1546,9 +1547,12 @@ static llama_vocab::id llama_sample_top_p_top_k(
 
 static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
     ggml_type quantized_type;
+    bool useNewQuantization = false;
     switch (itype) {
         case 2: quantized_type = GGML_TYPE_Q4_0; break;
         case 3: quantized_type = GGML_TYPE_Q4_1; break;
+        case 4: quantized_type = GGML_TYPE_Q4_0; useNewQuantization = true; break;
+        case 5: quantized_type = GGML_TYPE_Q4_1; useNewQuantization = true; break;
         default: throw format("invalid quantization type %d\n", itype);
     };
@@ -1616,11 +1620,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
                 switch (new_type) {
                     case GGML_TYPE_Q4_0:
                         {
-                            new_size = ggml_quantize_q4_0(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data());
+                            new_size = useNewQuantization ?
+                                kQuantizeQ4_0H(f32_data, new_data, nelements, hist_cur.data()) :
+                                ggml_quantize_q4_0(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data());
                         } break;
                     case GGML_TYPE_Q4_1:
                         {
-                            new_size = ggml_quantize_q4_1(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data());
+                            new_size = useNewQuantization ?
+                                kQuantizeQ4_1H(f32_data, new_data, nelements, hist_cur.data()) :
+                                ggml_quantize_q4_1(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data());
                         } break;
                     default:
                         LLAMA_ASSERT(false);