New IQ2_KT #511
Conversation
The new trellis generates int8_t values via sum_as_uint8_t[(ka * idx + kb) & 0x3f3f3f3f] - 126. CUDA dequantize works. The AVX2 case Ny > 32 works, and we get 273 t/s for L3-8B. PPL is on par with, or even slightly lower than, the original QTIP trellis.
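For reference, a minimal scalar sketch of that generator (it mirrors the `trellis_next_int` helper that appears in the CUDA code quoted further down; the CUDA path does the byte sum with a single `dp4a`):

```cpp
#include <cstdint>

// One step of the new integer trellis: LCG update, keep 6 bits of each byte,
// then sum the four bytes and recenter, giving a value in [-126, 126].
inline int8_t trellis_next_int(uint32_t& val) {
    constexpr uint32_t ka = 89226354;   // multiplier used in this PR
    constexpr uint32_t kb = 64248484;   // increment used in this PR
    val = ka * val + kb;
    const uint32_t v = val & 0x3f3f3f3f;
    const int sum = (v & 0xff) + ((v >> 8) & 0xff) + ((v >> 16) & 0xff) + (v >> 24);
    return int8_t(sum - 126);           // CUDA: ggml_cuda_dp4a(v, 0x01010101, -126)
}
```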
We get 13.6 t/s vs 8.4 t/s with the f16 trellis and f32 arithmetic. Still somewhat slower than other quants, but no longer pathetic.
We get a very respectable PP-512 = 120 t/s. TG-128 is pathetic at 5.3 t/s, 20+% slower than the f16 variant.
We are now at 9.4 t/s, up from 6.6 t/s for the f16 trellis.
It seems Apple Silicon cannot quickly add four 8-bit ints. Or maybe I just don't know how to do it; I didn't find anything in the Metal Shading Language Specification. So performance is quite a bit worse than with the original trellis.
Just kicked the tires on this PR and looks good so far!
There is not a lot of info about this model, and honestly it doesn't behave like a 4 bpw QAT; they don't have many details (I asked on their HF repo). Their chat template stuff seems wonky too (but that is unrelated to this PR). Anyway, the important thing is the new IQ2_KT. I'll run some sweep benches too for speed comparisons.
Speed benchmarks on a single CUDA RTX A6000 48GB VRAM, fully offloaded.

👈 Logs

git checkout ik/new_iq2kt
cmake -B ./build -DGGML_CUDA=ON -DGGML_BLAS=OFF -DGGML_SCHED_MAX_COPIES=1 -DGGML_CUDA_F16=ON
cmake --build ./build --config Release -j $(nproc)
#model=/mnt/raid/models/ubergarm/OpenBuddy-R1-0528-Distill-Qwen3-32B-Preview0-QAT-GGUF/DeepSeek-R1-0528-Distill-Qwen3-32B-Preview0-QAT-BF16-00001-of-00002.gguf
#model=/mnt/raid/models/ubergarm/OpenBuddy-R1-0528-Distill-Qwen3-32B-Preview0-QAT-GGUF/DeepSeek-R1-0528-Distill-Qwen3-32B-Preview0-QAT-Q4_0.gguf
#model=/mnt/raid/models/ubergarm/OpenBuddy-R1-0528-Distill-Qwen3-32B-Preview0-QAT-GGUF/DeepSeek-R1-0528-Distill-Qwen3-32B-Preview0-QAT-Q4_K.gguf
#model=/mnt/raid/models/ubergarm/OpenBuddy-R1-0528-Distill-Qwen3-32B-Preview0-QAT-GGUF/DeepSeek-R1-0528-Distill-Qwen3-32B-Preview0-QAT-IQ4_K.gguf
#model=/mnt/raid/models/ubergarm/OpenBuddy-R1-0528-Distill-Qwen3-32B-Preview0-QAT-GGUF/DeepSeek-R1-0528-Distill-Qwen3-32B-Preview0-QAT-IQ4_KS.gguf
#model=/mnt/raid/models/ubergarm/OpenBuddy-R1-0528-Distill-Qwen3-32B-Preview0-QAT-GGUF/DeepSeek-R1-0528-Distill-Qwen3-32B-Preview0-QAT-IQ4_KT.gguf
model=/mnt/raid/models/ubergarm/OpenBuddy-R1-0528-Distill-Qwen3-32B-Preview0-QAT-GGUF/DeepSeek-R1-0528-Distill-Qwen3-32B-Preview0-QAT-IQ2_KT.gguf
CUDA_VISIBLE_DEVICES="0" \
./build/bin/llama-sweep-bench \
--model "$model" \
--ctx-size 17408 \
-ctk f16 -ctv f16 \
-fa \
-ngl 99 \
--warmup-batch \
--threads 1

Q4_0
Q4_K
IQ4_K
IQ4_KS
IQ4_KT
IQ2_KT
Nice job! Somewhat related: I saw further discussions on optimizing QTIP-style quants by using pre-computed Hessians for each layer/tensor. Zero pressure to look or get distracted, just interesting that folks are already uploading Hessians for some models.
This is the sort of thing we do not want to do here. It leads to overfitting and needs a huge amount of compute, which makes it inaccessible to the average enthusiast, so it is basically only good for pushing out yet another paper to arXiv.
Great work! Love seeing improved performance on the trellis quants, ik. Some alternate MCG multipliers (with no addition) have lower PPL than the QTIP 3INST defaults:

Meta-Llama-3.1-8B-Instruct
Just chiming in because it might be a great time to take the 0.5% higher fidelity of ditching the default QTIP multiplier+addition params if you're already introducing a breaking change to the IQx_KT quants anyway. For IQ2_KT, this gains back a good chunk of what was lost by switching to your new decoder scheme, while also making IQ3_KT and IQ4_KT both better than #511 and in some cases even better than prior versions. Also, ka = 0xCBAC1FED (3417055213).
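For context, a side-by-side sketch of the two generator steps being compared (constants taken from the patch quoted later in this thread); the masking and byte-summing that follow are unchanged, and dropping `kb` is where the saved integer addition mentioned in the next comment comes from:

```cpp
#include <cstdint>

// Original QTIP 3INST-style LCG step (main branch and this PR as written):
inline uint32_t lcg_step(uint32_t& v) { v = 89226354u * v + 64248484u; return v; }

// Proposed MCG step: ka = 0xCBAC1FED (3417055213), kb = 0, so the
// per-value addition disappears.
inline uint32_t mcg_step(uint32_t& v) { v = 3417055213u * v; return v; }
```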
@louiehelm Thank you for the comment; this looks very promising. It should also improve performance slightly by saving one integer addition. Do I understand correctly that you applied the new multiplier to PR #511 rather than to the original implementation on the main branch? Did you also try models other than LLaMA-3.1-8B-Instruct?
Yes, the initial tests above were on #511. Needs more testing... Qwen3-1.7B IQ2_KT = 2.5% lower PPL... Magistral 24B IQ2_KT = 50% lower PPL [default model bugged, perhaps?]
On Gemma 3 27B QAT unquantized (iq2_kt for ffn_up, ffn_gate, attn_q, attn_k and attn_o; iq4_ks for ffn_down; q4_0 for attn_v; and q6 for embed/output), I obtained an almost equivalent wikitext-512 perplexity between the original ka/kb pair and louiehelm's. But on a Llama 3.3 70B type model (iq2_kt for the FFNs, attn_q and attn_o; q6 for embedding; iq5_ks_r4 for output and attn_v; and iq4_ks_r4 for attn_k), the final wikitext-512 perplexity is 1% lower with ka = 3417055213 and kb = 0 compared to the original pair. With an IQ3_KT using a CUDA MMQ kernel, and ffn_down/attn_o in iq3_kt, a Llama 3 70B on a single 24GB GPU will become really viable.
1% of what? Can you give the specific PPL values?
Here they are. For the Llama 3.3 70B type model (a merge, not the original 3.3 70B; iq2_kt for the FFNs, attn_q and attn_o; q6 for embedding; iq5_ks_r4 for output and attn_v; and iq4_ks_r4 for attn_k).
For the Gemma 3 27B QAT unquantized (iq2_kt for ffn_up, ffn_gate, attn_q, attn_k and attn_o; iq4_ks for ffn_down; q4_0 for attn_v; and q6 for embed/output).
Did you also try IQ4_KT? I tried LLaMA-3.1-8B-Instruct and PPL goes up by ~0.5%, which is a lot for 4 bit. I only changed the CUDA implementation so I can run PPL. When I make the change in the CPU code, I'll push it to a new branch. Probably tomorrow.
Just got home and tried louiehelm's 0xCBAC1FED patch on this PR #511.

👈 `0xCBAC1FED` Patch

diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index a602e47d..45de337e 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -341,15 +341,15 @@ inline __device__ int nearest_int(float fval) {
}
int __device__ __forceinline__ trellis_next_int(uint32_t& val) {
- constexpr uint32_t ka = 89226354;
- constexpr uint32_t kb = 64248484;
+ constexpr uint32_t ka = 3417055213;
+ constexpr uint32_t kb = 0;
val = ka*val + kb;
return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126);
}
float __device__ __forceinline__ trellis_next(uint32_t& val) {
- constexpr uint32_t ka = 89226354;
- constexpr uint32_t kb = 64248484;
+ constexpr uint32_t ka = 3417055213;
+ constexpr uint32_t kb = 0;
constexpr uint32_t kmask = 0x8fff8fff;
constexpr uint32_t km32 = 0x3b603b60;
uint32_t s;
diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
index 50e6458d..5e0226ed 100644
--- a/ggml/src/ggml-cuda/dmmv.cu
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -16,8 +16,8 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
#endif
static __device__ __forceinline__ uint32_t trellis_next(uint32_t& val) {
- constexpr uint32_t ka = 89226354;
- constexpr uint32_t kb = 64248484;
+ constexpr uint32_t ka = 3417055213;
+ constexpr uint32_t kb = 0;
constexpr uint32_t kmask = 0x8fff8fff;
constexpr uint32_t km32 = 0x3b603b60;
val = ka*val + kb;
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index df1cea89..34402358 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -398,8 +398,8 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1(
__device__ __forceinline__ void vec_dot_iq4_kt_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
- constexpr uint32_t ka = 89226354;
- constexpr uint32_t kb = 64248484;
+ constexpr uint32_t ka = 3417055213;
+ constexpr uint32_t kb = 0;
constexpr uint32_t km = 0x3f3f3f3f;
float scale = *(const float *)vbq;
@@ -436,8 +436,8 @@ __device__ __forceinline__ void vec_dot_iq4_kt_q8_1(
__device__ __forceinline__ void vec_dot_iq2_kt_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
- constexpr uint32_t ka = 89226354;
- constexpr uint32_t kb = 64248484;
+ constexpr uint32_t ka = 3417055213;
+ constexpr uint32_t kb = 0;
constexpr uint32_t km = 0x3f3f3f3f;
float scale = *(const float *)vbq;
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index e2c76a85..2b5a6df5 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -2799,8 +2799,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_kt(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
- constexpr uint32_t ka = 89226354;
- constexpr uint32_t kb = 64248484;
+ constexpr uint32_t ka = 3417055213;
+ constexpr uint32_t kb = 0;
constexpr uint32_t km = 0x3f3f3f3f;
#ifdef INT8_MMA_AVAILABLE
@@ -2872,8 +2872,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_kt(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
- constexpr uint32_t ka = 89226354;
- constexpr uint32_t kb = 64248484;
+ constexpr uint32_t ka = 3417055213;
+ constexpr uint32_t kb = 0;
constexpr uint32_t km = 0x3f3f3f3f;
#ifdef INT8_MMA_AVAILABLE
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
index 8b8cae14..41b9b2d6 100644
--- a/ggml/src/iqk/iqk_gemm_ktquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -14,8 +14,8 @@
namespace {
inline uint32_t trellis_next(uint32_t& val) {
- constexpr uint32_t ka = 89226354;
- constexpr uint32_t kb = 64248484;
+ constexpr uint32_t ka = 3417055213;
+ constexpr uint32_t kb = 0;
constexpr uint32_t kmask = 0x8fff8fff;
constexpr uint32_t km32 = 0x3b603b60;
val = val*ka + kb;
@@ -31,8 +31,8 @@ inline float trellis_gen(uint32_t& val, uint32_t* s) {
struct Trellis1 {
constexpr static uint32_t kmask = 0x8fff8fff;
constexpr static uint32_t km32 = 0x3b603b60;
- constexpr static uint32_t ka = 89226354;
- constexpr static uint32_t kb = 64248484;
+ constexpr static uint32_t ka = 3417055213;
+ constexpr static uint32_t kb = 0;
constexpr static uint32_t ka1 = ka*ka;
constexpr static uint32_t kb1 = kb*ka+kb;
constexpr static uint32_t ka2 = ka1*ka;
@@ -76,8 +76,8 @@ inline __m256 trellis_gen8(__m256i i8) {
struct Trellis2 {
constexpr static uint32_t kmask = 0x8fff8fff;
constexpr static uint32_t km32 = 0x3b603b60;
- constexpr static uint32_t ka = 89226354;
- constexpr static uint32_t kb = 64248484;
+ constexpr static uint32_t ka = 3417055213;
+ constexpr static uint32_t kb = 0;
constexpr static uint32_t ka1 = ka*ka;
constexpr static uint32_t kb1 = kb*ka+kb;
constexpr static uint32_t ka2 = ka1*ka;
@@ -100,8 +100,8 @@ struct Trellis2 {
template <bool is_8 = false>
struct Trellis3 {
- constexpr static uint32_t ka = 89226354;
- constexpr static uint32_t kb = 64248484;
+ constexpr static uint32_t ka = 3417055213;
+ constexpr static uint32_t kb = 0;
constexpr static uint32_t ka1 = ka*ka;
constexpr static uint32_t kb1 = kb*ka+kb;
constexpr static uint32_t ka2 = ka1*ka;
@@ -913,8 +913,8 @@ namespace {
struct Trellis1 {
constexpr static uint32_t kmask = 0x8fff8fff;
constexpr static uint32_t km32 = 0x3b603b60;
- constexpr static uint32_t ka = 89226354;
- constexpr static uint32_t kb = 64248484;
+ constexpr static uint32_t ka = 3417055213;
+ constexpr static uint32_t kb = 0;
constexpr static uint32_t ka1 = ka*ka;
constexpr static uint32_t kb1 = kb*ka+kb;
constexpr static uint32_t ka2 = ka1*ka;
@@ -1419,8 +1419,8 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
}
struct Trellis3 {
- constexpr static uint32_t ka = 89226354;
- constexpr static uint32_t kb = 64248484;
+ constexpr static uint32_t ka = 3417055213;
+ constexpr static uint32_t kb = 0;
constexpr static uint32_t ka1 = ka*ka;
constexpr static uint32_t kb1 = kb*ka+kb;
constexpr static uint32_t ka2 = ka1*ka;
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index b6bff0a1..7c052989 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -7454,8 +7454,8 @@ public:
inline float find_best_inverse_scale(const float * xb, const float * weight, const int * best_idx) const;
static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) {
- constexpr uint32_t ka = 89226354;
- constexpr uint32_t kb = 64248484;
+ constexpr uint32_t ka = 3417055213;
+ constexpr uint32_t kb = 0;
uint32_t x = i + offset;
if constexpr (is_int) {
uint32_t s;

Data

Here is the comparison of the same OpenBuddy-R1-0528-Distill-Qwen3-32B-Preview0-QAT used above between regular PR #511 and the patched version.

PR511 (from above)
0xCBAC1FED Patch
Comparison
Conclusion

Well, it's hard to say from a single run given the deltas seem within the margin of error. I'm not sure if it is possible/worthwhile to save the ka/kb...
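As an aside on the patched `Trellis` structs above: the derived constants `ka1 = ka*ka`, `kb1 = kb*ka + kb`, `ka2 = ka1*ka`, ... are just composed LCG steps, which is what lets the SIMD/CUDA paths jump several positions ahead from a single seed (and with `kb = 0` all the additive terms vanish). A minimal check of that identity, assuming ordinary wrap-around `uint32_t` arithmetic as in the kernels:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    constexpr uint32_t ka = 89226354u, kb = 64248484u;
    constexpr uint32_t ka1 = ka * ka, kb1 = kb * ka + kb;  // as in Trellis1/Trellis3

    uint32_t x = 0xdeadbeefu;           // arbitrary seed
    const uint32_t one = ka * x + kb;   // one LCG step
    const uint32_t two = ka * one + kb; // two LCG steps
    assert(two == ka1 * x + kb1);       // equals a single fused two-position step
    return 0;
}
```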
Closing in favor of #529
This PR uses the new trellis introduced in #505 and applies it to IQ2_KT. This leads to a slightly higher PPL for the models where the IQ2_KT on the main branch works, but it is more stable and there are no longer NaNs for the models where the existing IQ2_KT was failing (Qwen3-30B-A3B and DeepSeek-Lite). Performance is also great, except on the Apple GPU, where it is slower than the original IQ2_KT implementation. But on CUDA and on the CPU there are massive performance gains. Here is an example of LLaMA-3.1-8B on RTX-4080 and Ryzen-7950X.