model : add grok-2 support #15539
Conversation
@nicoboss Conversion should be working now, mind testing? |
I'm glad to test this model. I'm quite excited about it.
|
Yeah, since they for some reason chose to use a really weird naming scheme and didn't include an index. BTW, just added the chat template, and it's likely RoPE is wrong (and the vocab may need more fixing, we'll see). |
Hmmm, when adding the chat template I noticed sgl-project/sglang#9532, and it looks as if there are more substantial changes (including another magic number). I wonder what the reference for this implementation was? |
As instructed, I renamed all files to model-*.safetensors. It complains about:
Inside:
@CISC How did you solve this on your side? |
Didn't expect new norm tensor names, I'll add those later today...
I didn't; I'm working blind on this. :) The initial hope was that it was just a case of configuration changes since they apparently used the same architecture, but I guess that was not entirely true. |
I just tried your latest commit. What is strange is that if I generate the GGUF anyway, it is only 178 GB, while I expected it to be close to 500 GB. When I try to load it I'm getting:
|
Oh, yeah, sorry, I know why, my bad.
That's weird, this should be mapped from
That's even weirder, like it only added 1 expert. Maybe it's just because the first error stopped it from adding the rest? Edit: LOL, I know why, it's because of the safetensor naming of |
On the latest commit, when loading the GGUF I'm getting the following:
|
That's not a magic number, that's just sqrt(2) lazily inlined :) |
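For the record, the value in question is simply the square root of two written out as a float (assuming the inlined literal is the usual 1.4142135623730951):

```python
import math

# sqrt(2), as it typically appears when lazily inlined as a literal
print(math.sqrt(2))  # 1.4142135623730951
```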
That makes no sense, duplicated from what? Edit: Oh, ffs, I see, looks like
This finally makes sense though: it looks like each of these is split across 8 safetensors (instead of having 1 expert in 1 safetensor, each expert is split into 8). This requires extra logic to handle. |
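A minimal sketch of the kind of merging logic this implies (the shard count, tensor names and concatenation axis here are assumptions, not the actual grok-2 checkpoint layout):

```python
import torch

def merge_expert_shards(pieces: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """Rebuild one full expert weight from the slices stored in separate safetensors.

    `pieces` is assumed to hold the partial tensors for a single expert in
    shard order; the split axis (`dim`) is an assumption for illustration.
    """
    return torch.cat(pieces, dim=dim)

# e.g. pieces[i] loaded from shard i of 8, then:
# full_weight = merge_expert_shards(pieces, dim=0)
```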
@CISC Re the softcapping, i.e.:

```python
g = x.float() @ self.router_linear.weight.T.float()
g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
return fused_topk(x, g, self.topk, False)
```

So the softcapping is done after the router linear matmul. |
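For readers without the sglang/vLLM fused kernel at hand, a rough plain-PyTorch equivalent of that routing path could look like the sketch below (whether softmax precedes the top-k, and what the final False argument of fused_topk controls, are assumptions here):

```python
import torch

def softcapped_router(x: torch.Tensor, router_weight: torch.Tensor,
                      softcap: float, topk: int):
    # Router logits, with tanh softcapping applied after the matmul,
    # mirroring the snippet above.
    g = x.float() @ router_weight.T.float()
    g = torch.tanh(g / softcap) * softcap
    # Rough stand-in for fused_topk: softmax over experts, then take the
    # top-k expert weights and ids (no renormalization).
    probs = torch.softmax(g, dim=-1)
    weights, ids = torch.topk(probs, topk, dim=-1)
    return weights, ids
```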
I just retested the latest commit and can confirm that |
I'll try to address this tonight, but I might not have the time. |
Don't think we have a good method. I usually input some large source files to fill the context and ask questions about the code, but I can imagine it will be difficult for such a large model. |
You could try:
|
Sigh, I get
halfway through filling up the (64k) context. Maybe I'll have more luck if I manage to fully fill each GPU instead of evenly dividing the tensors between them (they only end up half full with the experts on CPU), though I'm not sure how best to achieve that... |
My experiences with llama.cpp and multiple GPUs show that splitting isn't always handled correctly, and I often need to rebalance -ts in various ways to avoid CUDA out-of-memory errors.
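As a purely illustrative example of that kind of rebalancing (the file name, context size, layer count and split ratios below are placeholders, and flag spellings can vary between llama.cpp versions):

```
./llama-cli -m grok-2-Q8_0.gguf -c 65536 --n-gpu-layers 999 \
    --n-cpu-moe 40 --tensor-split 30/25/25/25/25/25
```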
Thanks, tried playing around with it, but it seems to make things worse (couldn't fit the context). Got it running with a 48k context (and 40k of tokens) now; it is completely coherent and seems to be finding needles just fine, though it's a little confused about exactly where. Tried again with
(funnily enough it was the model adding the
I think we can conclude that
|
Maybe try with a quantized KV cache? Q8_0 shouldn't really affect the quality, but even Q4 might work if you simply want to test coherence past a certain point. It's not optimal, sure, but it surely beats not being able to do any tests at all :) |
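In llama-cli terms that suggestion corresponds to something like the following (values are illustrative, and a quantized V cache generally requires flash attention to be enabled):

```
./llama-cli -m grok-2-Q8_0.gguf -c 131072 --flash-attn \
    --cache-type-k q8_0 --cache-type-v q8_0
```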
I will, but I wanted to avoid that for an initial test because I know it will affect the results negatively. Either way I think showing that the model works perfectly at 6 times the original context is pretty conclusive. |
No luck with |
What is the size of grok-2? 6×24 GB + 512 GB is a lot of memory. Or do you want to keep the full model in VRAM? |
539G BF16, I'm running it with Q8_0 and |
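(For a rough sense of scale: Q8_0 stores about 8.5 bits per weight versus 16 for BF16, so a 539 GB BF16 checkpoint should come out around 539 × 8.5 / 16 ≈ 286 GB at Q8_0, still well beyond 6×24 GB of VRAM on its own, hence offloading the experts to system RAM.)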
I think it's ready, not overly concerned about |
I was able to run it on 3×3090 (as you can see, I needed to use the tensor-split option to balance the model across the 3 GPUs). It's quite slow, but it works!

Speed of grok-2 Q2: 4.8 t/s

```
$ ~/git/llama.cpp/build_2025.09.04_grok/bin/llama-cli -m grok-2.Q2_K-00001-of-00003.gguf -ts 20/9/10 --n-cpu-moe 25

Human: Hello<|separator|> Assistant: Hi there<|separator|> Human: How are you?<|separator|> Assistant:

system_info: n_threads = 12 (n_threads_batch = 12) / 24 | CUDA : ARCHS = 860 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 |

main: interactive mode on.
== Running in interactive mode. ==

llama_perf_sampler_print: sampling time = 4.82 ms / 70 runs ( 0.07 ms per token, 14528.85 tokens per second)
```
|
Hey, thanks to all of you! We uploaded preliminary GGUFs for grok-2: https://huggingface.co/unsloth/grok-2-GGUF
If there are any issues please let us know! 🙏 |
@slaren gentle ping |
Add support for grok-2.
Tip
Download the files from alvarobartt/grok-2-tokenizer before conversion for the BPE tokenizer and chat template.

Tip
Rename pytorch_model-*.safetensors to model-*.safetensors so that convert_hf_to_gguf.py finds them (a sketch of this rename follows below).

Note
router_logit_softcapping and attn_temperature_len are currently not used.

cc/ @nicoboss
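A minimal sketch of the rename step from the Tip above, assuming the shards sit in a single local checkpoint directory (the "grok-2" path is a placeholder):

```python
from pathlib import Path

# Swap the "pytorch_model-" prefix for "model-" so that
# convert_hf_to_gguf.py picks up the safetensors shards.
model_dir = Path("grok-2")  # placeholder: local checkpoint directory
for f in sorted(model_dir.glob("pytorch_model-*.safetensors")):
    f.rename(f.with_name(f.name.replace("pytorch_model-", "model-", 1)))
```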