
Conversation

@GodEmperor785
Contributor

This is for issue #7309, which I opened.

This adds a new slider for the --ubatch-size parameter of llama-server in the llama.cpp loader, right below batch_size.
Setting it to a higher value along with batch_size gives much better prompt-processing performance for mixed CPU+GPU inference (I've seen around a 3x speed gain, as mentioned in the issue).

The default value is set to the llama-server default, which can be seen here.

I think it would also be worth considering whether the default for the batch_size slider (the --batch-size parameter in llama-server) should stay at only 256 in the webui, when llama-server itself defaults to 2048 (also visible in the link above).
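
Both sliders just translate into llama-server command-line flags; a minimal sketch of the idea (the function and argument names below are illustrative, not the actual loader code):

```python
# Illustrative sketch only: the names here are placeholders, not the real
# loader internals.  The point is that each slider maps 1:1 onto a llama-server flag.
def build_server_args(model_path: str, batch_size: int = 256, ubatch_size: int = 512) -> list:
    """Build a llama-server command line with the two batch-related flags."""
    return [
        "llama-server",
        "--model", model_path,
        "--batch-size", str(batch_size),    # logical batch size (-b); webui default so far is 256
        "--ubatch-size", str(ubatch_size),  # physical batch size (-ub); llama.cpp default is 512
    ]
```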

Checklist:

Added new slider for --ubatch-size parameter of llama-server in llama.cpp loader
@oobabooga
Owner

Thoughts on this? #6870

The default batch size had been lowered to 256 because 2048 seemed like a value aimed at batching/servers/multiple users that led to bottlenecking and worse performance with both full offloading and GPU + CPU. Maybe things behave differently for MoE models?

@GodEmperor785
Contributor Author

I can say that for mixed CPU+GPU inference of a MoE model, increasing batch_size gave me at least a 2x performance gain, and raising ubatch_size improved it roughly 3x on top of that. With the defaults (batch_size 256, ubatch_size unset, so 512) I got less than 100 T/s on GLM-4.5-Air; now I get over 500 T/s.
In my case it is a single user, one prompt, a consumer CPU (i7-13700K), and consumer DDR5 RAM.

As for the threads setting, the general recommendation I've heard is to set it to fewer cores than the CPU has, or at most to the number of cores, so I didn't change those values in my tests. Some big server CPUs might work better with more threads.

Later today I will quickly test some batch sizes (both batch_size and ubatch_size) with dense models and post the results.

@GodEmperor785
Contributor Author

GodEmperor785 commented Nov 19, 2025

I ran some tests using Magistral Small 24B, Q6 quant.
I used llama-bench because it is an easier and more reliable way to print stats (llama-bench is part of the llama.cpp build installed by the webui).
Tests were made on consumer hardware, single user (as mentioned in my previous comment).
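
The sweep itself boils down to calling llama-bench with different -b/-ub values; a rough sketch of the kind of invocation (the model path, prompt length, and token counts here are placeholders, not the literal commands for every run):

```python
# Rough sketch of the sweep: run llama-bench for each batch/ubatch combination
# and let it report PP/TG throughput.  Model path, prompt length, and token
# counts are placeholders; a few combinations (marked "x" in the tables below)
# were simply not run.
import itertools
import subprocess

MODEL = "Magistral-Small-24B-Q6_K.gguf"  # placeholder path
SIZES = [256, 512, 2048]

for b, ub in itertools.product(SIZES, SIZES):
    cmd = [
        "llama-bench",
        "-m", MODEL,
        "-p", "8192",    # prompt length (8k for the full-offload table below)
        "-n", "128",     # generated tokens for the TG number
        "-b", str(b),    # batch_size
        "-ub", str(ub),  # ubatch_size
    ]
    print(">>>", " ".join(cmd))
    subprocess.run(cmd, check=True)
```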

Here are the results for prompt processing (PP); values were rounded, so treat them as ±10:

  • b X is batch_size
  • ub X is ubatch_size
  • x means I didn't check
  • values in tokens/s
  • thread counts were left at defaults

For full GPU offload (prompt length 8k):

|         | b 256 | b 512 | b 2048 |
|---------|-------|-------|--------|
| ub 256  | 3850  | x     | x      |
| ub 512  | 3720  | 3990  | 3970   |
| ub 2048 | 3680  | 3970  | 4030   |

Token generation (TG) consistently stayed at 72 T/s; prompt processing showed some noticeable variance, so I'd put the margin of error at 100 T/s.
Observations: small differences for full GPU inference, but a batch_size of 512 or greater consistently gave at least 200 T/s more.

For partial offload to CPU (35/40 layers on GPU, prompt length 2k because it is slower):

|         | b 256 | b 512 | b 2048 |
|---------|-------|-------|--------|
| ub 256  | 1070  | x     | x      |
| ub 512  | 1070  | 1750  | 1750   |
| ub 2048 | 1070  | 1760  | 3070   |

Token generation (TG) consistently stayed at 20 T/s; PP margin of error around 10 T/s.
Observations: noticeable improvement from increasing both batch sizes, especially when both are at 2048, but 512 also gives a nice speedup.

And finally CPU only (prompt length 512, the default, generating 32 tokens so it doesn't take forever):

|         | b 256 | b 512 | b 2048 |
|---------|-------|-------|--------|
| ub 256  | 200   | x     | x      |
| ub 512  | 190   | 390   | 390    |
| ub 2048 | 200   | 400   | 380    |

Token generation (TG) consistently stayed at 4 T/s; PP margin of error around 10 T/s to be safe.
Observations: noticeable improvement at batch size 512, no further improvement for bigger values, but no drop in performance either.

Based on this short test I'd say that batch_size could be safely increased to 512. What do you think?
The gains I saw at batch 2048 may be due to a powerful GPU, but even the CPU-only tests show a big gain at batch_size 512.

I think the "lockup" on consumer hardware mentioned in #6870 was due to the number of threads; in my tests I left the threads settings at their defaults.
Also a question: what LLM model and hardware were the tests in #6870 run on?

@oobabooga
Owner

I have run some new measurements with MoE models, in Q8_0 precision on an RTX 6000 Ada GPU:

| Model       | Batch | Layers | PP (t/s) | TG (t/s) |
|-------------|-------|--------|----------|----------|
| Qwen3-30B   | 256   | 25     | 288.64   | 11.83    |
| Qwen3-30B   | 512   | 25     | 509.82   | 11.87    |
| Qwen3-30B   | 1024  | 25     | 906.80   | 11.86    |
| Qwen3-30B   | 2048  | 25     | 1252.61  | 12.00    |
| Qwen3-30B   | 256   | 49     | 3917.27  | 130.12   |
| Qwen3-30B   | 512   | 49     | 5357.61  | 130.55   |
| Qwen3-30B   | 1024  | 49     | 6803.12  | 131.03   |
| Qwen3-30B   | 2048  | 49     | 7098.52  | 135.00   |
| gpt-oss-20b | 256   | 12     | 606.67   | 79.73    |
| gpt-oss-20b | 512   | 12     | 982.12   | 78.73    |
| gpt-oss-20b | 1024  | 12     | 1563.27  | 89.30    |
| gpt-oss-20b | 2048  | 12     | 1945.40  | 15.38    |
| gpt-oss-20b | 256   | 25     | 6004.69  | 979.50   |
| gpt-oss-20b | 512   | 25     | 7737.23  | 501.36   |
| gpt-oss-20b | 1024  | 25     | 8771.32  | 777.94   |
| gpt-oss-20b | 2048  | 25     | 8545.93  | 1082.82  |

If I apply the same harmonic mean score as in #6870, I get

| Rank | Batch Size | Harmonic Mean |
|------|------------|---------------|
| 1    | 1024       | 0.879         |
| 2    | 512        | 0.654         |
| 3    | 2048       | 0.623         |
| 4    | 256        | 0.530         |

Since the SOTA models that run on consumer hardware are all MoE models nowadays, I think it's fine to use 1024 as the default, even though 256 or 512 are possibly better for dense models.
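
For reference, the score can be reproduced like this (the normalization step is my reading of the #6870 method, inferred from the numbers rather than quoted from it): per model/layer group, normalize PP and TG to the best value across batch sizes, then take the harmonic mean of all normalized values for a given batch size.

```python
# Reproduces the ranking above under the assumed scoring: per model/layer group,
# normalize PP and TG to the best value across batch sizes, then take the
# harmonic mean of all normalized values for each batch size.
from statistics import harmonic_mean

# (model, layers) -> {batch: (PP t/s, TG t/s)}, copied from the table above
results = {
    ("Qwen3-30B", 25):   {256: (288.64, 11.83),    512: (509.82, 11.87),
                          1024: (906.80, 11.86),   2048: (1252.61, 12.00)},
    ("Qwen3-30B", 49):   {256: (3917.27, 130.12),  512: (5357.61, 130.55),
                          1024: (6803.12, 131.03), 2048: (7098.52, 135.00)},
    ("gpt-oss-20b", 12): {256: (606.67, 79.73),    512: (982.12, 78.73),
                          1024: (1563.27, 89.30),  2048: (1945.40, 15.38)},
    ("gpt-oss-20b", 25): {256: (6004.69, 979.50),  512: (7737.23, 501.36),
                          1024: (8771.32, 777.94), 2048: (8545.93, 1082.82)},
}

def score(batch: int) -> float:
    normalized = []
    for runs in results.values():
        best_pp = max(pp for pp, _ in runs.values())
        best_tg = max(tg for _, tg in runs.values())
        pp, tg = runs[batch]
        normalized += [pp / best_pp, tg / best_tg]
    return harmonic_mean(normalized)

for b in (256, 512, 1024, 2048):
    print(b, round(score(b), 3))  # 1024 -> 0.879, 512 -> 0.654, 2048 -> 0.623, 256 -> 0.530
```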

@oobabooga
Owner

oobabooga commented Nov 21, 2025

I added some new tests for 2048/512 and 256/512 (the llama.cpp default and this project's default so far) to make sure, and 1024/1024 still wins:

| Model       | Layers | Config (Batch / Ubatch) | PP (t/s) | TG (t/s) |
|-------------|--------|-------------------------|----------|----------|
| gpt-oss-20b | 12     | 256 / 256               | 606.67   | 79.73    |
| gpt-oss-20b | 12     | 256 / 512               | 604.43   | 79.85    |
| gpt-oss-20b | 12     | 512 / 512               | 982.12   | 78.73    |
| gpt-oss-20b | 12     | 1024 / 1024             | 1563.27  | 89.30    |
| gpt-oss-20b | 12     | 2048 / 512              | 991.71   | 78.83    |
| gpt-oss-20b | 12     | 2048 / 2048             | 1945.40  | 15.38    |
| gpt-oss-20b | 25     | 256 / 256               | 6004.69  | 979.50   |
| gpt-oss-20b | 25     | 256 / 512               | 5909.47  | 983.06   |
| gpt-oss-20b | 25     | 512 / 512               | 7737.23  | 501.36   |
| gpt-oss-20b | 25     | 1024 / 1024             | 8771.32  | 777.94   |
| gpt-oss-20b | 25     | 2048 / 512              | 7808.12  | 505.23   |
| gpt-oss-20b | 25     | 2048 / 2048             | 8545.93  | 1082.82  |
| Qwen3-30B   | 25     | 256 / 256               | 288.64   | 11.83    |
| Qwen3-30B   | 25     | 256 / 512               | 290.32   | 11.87    |
| Qwen3-30B   | 25     | 512 / 512               | 509.82   | 11.87    |
| Qwen3-30B   | 25     | 1024 / 1024             | 906.80   | 11.86    |
| Qwen3-30B   | 25     | 2048 / 512              | 512.40   | 11.91    |
| Qwen3-30B   | 25     | 2048 / 2048             | 1252.61  | 12.00    |
| Qwen3-30B   | 49     | 256 / 256               | 3917.27  | 130.12   |
| Qwen3-30B   | 49     | 256 / 512               | 3886.14  | 129.52   |
| Qwen3-30B   | 49     | 512 / 512               | 5357.61  | 130.55   |
| Qwen3-30B   | 49     | 1024 / 1024             | 6803.12  | 131.03   |
| Qwen3-30B   | 49     | 2048 / 512              | 5406.50  | 133.89   |
| Qwen3-30B   | 49     | 2048 / 2048             | 7098.52  | 135.00   |

| Rank | Config (Batch / Ubatch) | Harmonic Mean Score |
|------|-------------------------|---------------------|
| 1    | 1024 / 1024             | 0.879               |
| 2    | 2048 / 512              | 0.659               |
| 3    | 512 / 512               | 0.654               |
| 4    | 2048 / 2048             | 0.623               |
| 5    | 256 / 256               | 0.530               |
| 6    | 256 / 512               | 0.529               |

Thanks for the PR -- this is a nice performance improvement on MoE models.

@oobabooga merged commit 400bb06 into oobabooga:dev on Nov 21, 2025