Conversation

JohannesGaessler (Collaborator)

This PR uses the new fastdiv code from #15715 for MMVQ + q8_1 quantization, and also adds __launch_bounds__ to the quantization kernel.
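
For context, fastdiv replaces an integer division by a divisor that is only known at kernel-launch time with a multiply-high and a shift, with the magic constants precomputed once on the host. Below is a minimal sketch of the idea; the helper names and exact signatures are illustrative, not necessarily the ggml ones:

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// Host: precompute a magic multiplier mp and a shift L for a fixed divisor d so that
// n / d == (__umulhi(mp, n) + n) >> L (valid here assuming 0 < d and indices n < 2^31).
static uint3 make_fastdiv_values(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        ++L;
    }
    // d == 0 would make this host-side division fault (see the assert discussion further down).
    const uint32_t mp = (uint32_t) (((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    return make_uint3(mp, L, d); // keep d so the remainder can be reconstructed on the device
}

// Device: one multiply-high, one add and one shift instead of an integer division instruction.
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fd) {
    return (__umulhi(fd.x, n) + n) >> fd.y;
}

// Device: the remainder follows from the quotient without a second division.
static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fd) {
    return n - fastdiv(n, fd) * fd.z;
}
```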

Performance
| GPU | Model | FlashAttention | Microbatch size | Test | t/s master | t/s bc7497a | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MI60 / MI50 | llama 1B Q4_0 | No | 1 | pp512 | 324.58 | 331.09 | 1.02 |
| MI60 / MI50 | llama 1B Q4_0 | No | 2 | pp512 | 630.34 | 617.04 | 0.98 |
| MI60 / MI50 | llama 1B Q4_0 | No | 3 | pp512 | 550.57 | 536.87 | 0.98 |
| MI60 / MI50 | llama 1B Q4_0 | No | 4 | pp512 | 676.31 | 658.36 | 0.97 |
| MI60 / MI50 | llama 1B Q4_0 | No | 5 | pp512 | 829.86 | 842.40 | 1.02 |
| MI60 / MI50 | llama 1B Q4_0 | No | 6 | pp512 | 728.63 | 714.23 | 0.98 |
| MI60 / MI50 | llama 1B Q4_0 | No | 7 | pp512 | 791.68 | 804.37 | 1.02 |
| MI60 / MI50 | llama 1B Q4_0 | No | 8 | pp512 | 899.37 | 902.24 | 1.00 |
| MI60 / MI50 | llama 8B Q4_0 | No | 1 | pp512 | 86.12 | 86.92 | 1.01 |
| MI60 / MI50 | llama 8B Q4_0 | No | 2 | pp512 | 159.63 | 165.93 | 1.04 |
| MI60 / MI50 | llama 8B Q4_0 | No | 3 | pp512 | 138.24 | 138.07 | 1.00 |
| MI60 / MI50 | llama 8B Q4_0 | No | 4 | pp512 | 163.47 | 162.29 | 0.99 |
| MI60 / MI50 | llama 8B Q4_0 | No | 5 | pp512 | 192.49 | 198.98 | 1.03 |
| MI60 / MI50 | llama 8B Q4_0 | No | 6 | pp512 | 177.00 | 172.99 | 0.98 |
| MI60 / MI50 | llama 8B Q4_0 | No | 7 | pp512 | 190.66 | 199.64 | 1.05 |
| MI60 / MI50 | llama 8B Q4_0 | No | 8 | pp512 | 216.54 | 216.92 | 1.00 |
| P40 | llama 1B Q4_0 | Yes | 1 | pp512 | 277.78 | 300.09 | 1.08 |
| P40 | llama 1B Q4_0 | Yes | 2 | pp512 | 564.38 | 567.23 | 1.01 |
| P40 | llama 1B Q4_0 | Yes | 3 | pp512 | 749.81 | 741.21 | 0.99 |
| P40 | llama 1B Q4_0 | Yes | 4 | pp512 | 754.37 | 707.49 | 0.94 |
| P40 | llama 1B Q4_0 | Yes | 5 | pp512 | 943.11 | 961.92 | 1.02 |
| P40 | llama 1B Q4_0 | Yes | 6 | pp512 | 1021.61 | 1023.67 | 1.00 |
| P40 | llama 1B Q4_0 | Yes | 7 | pp512 | 1031.67 | 1059.86 | 1.03 |
| P40 | llama 1B Q4_0 | Yes | 8 | pp512 | 1025.45 | 1063.12 | 1.04 |
| P40 | llama 8B Q4_0 | Yes | 1 | pp512 | 56.44 | 58.36 | 1.03 |
| P40 | llama 8B Q4_0 | Yes | 2 | pp512 | 111.77 | 113.15 | 1.01 |
| P40 | llama 8B Q4_0 | Yes | 3 | pp512 | 154.34 | 153.60 | 1.00 |
| P40 | llama 8B Q4_0 | Yes | 4 | pp512 | 159.86 | 156.38 | 0.98 |
| P40 | llama 8B Q4_0 | Yes | 5 | pp512 | 189.98 | 194.64 | 1.02 |
| P40 | llama 8B Q4_0 | Yes | 6 | pp512 | 203.27 | 203.96 | 1.00 |
| P40 | llama 8B Q4_0 | Yes | 7 | pp512 | 194.65 | 201.68 | 1.04 |
| P40 | llama 8B Q4_0 | Yes | 8 | pp512 | 200.50 | 211.52 | 1.05 |
| RTX 3090 | llama 1B Q4_0 | Yes | 1 | pp512 | 665.98 | 679.69 | 1.02 |
| RTX 3090 | llama 1B Q4_0 | Yes | 2 | pp512 | 1102.97 | 1127.48 | 1.02 |
| RTX 3090 | llama 1B Q4_0 | Yes | 3 | pp512 | 1542.74 | 1565.06 | 1.01 |
| RTX 3090 | llama 1B Q4_0 | Yes | 4 | pp512 | 1929.88 | 1902.27 | 0.99 |
| RTX 3090 | llama 1B Q4_0 | Yes | 5 | pp512 | 2092.90 | 2079.23 | 0.99 |
| RTX 3090 | llama 1B Q4_0 | Yes | 6 | pp512 | 2241.82 | 2306.79 | 1.03 |
| RTX 3090 | llama 1B Q4_0 | Yes | 7 | pp512 | 2449.68 | 2448.28 | 1.00 |
| RTX 3090 | llama 1B Q4_0 | Yes | 8 | pp512 | 2556.59 | 2655.01 | 1.04 |
| RTX 3090 | llama 8B Q4_0 | Yes | 1 | pp512 | 155.02 | 156.56 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | Yes | 2 | pp512 | 277.36 | 279.89 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | Yes | 3 | pp512 | 390.23 | 390.96 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | Yes | 4 | pp512 | 468.95 | 470.48 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | Yes | 5 | pp512 | 514.50 | 509.70 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | Yes | 6 | pp512 | 529.37 | 541.57 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | Yes | 7 | pp512 | 559.17 | 559.30 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | Yes | 8 | pp512 | 570.67 | 583.77 | 1.02 |
| RTX 4090 | llama 1B Q4_0 | Yes | 1 | pp512 | 879.19 | 891.23 | 1.01 |
| RTX 4090 | llama 1B Q4_0 | Yes | 2 | pp512 | 1358.20 | 1362.47 | 1.00 |
| RTX 4090 | llama 1B Q4_0 | Yes | 3 | pp512 | 1972.75 | 1974.98 | 1.00 |
| RTX 4090 | llama 1B Q4_0 | Yes | 4 | pp512 | 2590.58 | 2610.67 | 1.01 |
| RTX 4090 | llama 1B Q4_0 | Yes | 5 | pp512 | 2892.37 | 2929.62 | 1.01 |
| RTX 4090 | llama 1B Q4_0 | Yes | 6 | pp512 | 3278.21 | 3323.54 | 1.01 |
| RTX 4090 | llama 1B Q4_0 | Yes | 7 | pp512 | 3612.01 | 3629.45 | 1.00 |
| RTX 4090 | llama 1B Q4_0 | Yes | 8 | pp512 | 3945.15 | 3960.70 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | Yes | 1 | pp512 | 189.43 | 190.74 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | Yes | 2 | pp512 | 334.86 | 336.98 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | Yes | 3 | pp512 | 494.97 | 497.43 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | Yes | 4 | pp512 | 655.17 | 658.22 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | Yes | 5 | pp512 | 768.36 | 776.39 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | Yes | 6 | pp512 | 899.40 | 899.70 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | Yes | 7 | pp512 | 1003.74 | 1004.35 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | Yes | 8 | pp512 | 1086.69 | 1091.22 | 1.00 |
| RX 6800 | llama 1B Q4_0 | No | 1 | pp512 | 236.03 | 236.85 | 1.00 |
| RX 6800 | llama 1B Q4_0 | No | 2 | pp512 | 451.63 | 453.87 | 1.00 |
| RX 6800 | llama 1B Q4_0 | No | 3 | pp512 | 626.35 | 631.96 | 1.01 |
| RX 6800 | llama 1B Q4_0 | No | 4 | pp512 | 758.88 | 758.65 | 1.00 |
| RX 6800 | llama 1B Q4_0 | No | 5 | pp512 | 845.78 | 851.50 | 1.01 |
| RX 6800 | llama 1B Q4_0 | No | 6 | pp512 | 919.45 | 912.94 | 0.99 |
| RX 6800 | llama 1B Q4_0 | No | 7 | pp512 | 962.88 | 960.16 | 1.00 |
| RX 6800 | llama 1B Q4_0 | No | 8 | pp512 | 1043.47 | 1010.94 | 0.97 |
| RX 6800 | llama 8B Q4_0 | No | 1 | pp512 | 64.34 | 64.62 | 1.00 |
| RX 6800 | llama 8B Q4_0 | No | 2 | pp512 | 121.92 | 121.55 | 1.00 |
| RX 6800 | llama 8B Q4_0 | No | 3 | pp512 | 158.20 | 161.59 | 1.02 |
| RX 6800 | llama 8B Q4_0 | No | 4 | pp512 | 177.57 | 175.08 | 0.99 |
| RX 6800 | llama 8B Q4_0 | No | 5 | pp512 | 188.46 | 186.37 | 0.99 |
| RX 6800 | llama 8B Q4_0 | No | 6 | pp512 | 196.97 | 194.63 | 0.99 |
| RX 6800 | llama 8B Q4_0 | No | 7 | pp512 | 200.01 | 197.98 | 0.99 |
| RX 6800 | llama 8B Q4_0 | No | 8 | pp512 | 198.43 | 200.38 | 1.01 |

fastdiv alone is consistently faster; the addition of __launch_bounds__ is faster on average, and always for batch size 1, which is the most important use case.
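
As a rough illustration of the __launch_bounds__ part (the kernel body and the attribute arguments below are placeholders, not the actual q8_1 code): the attribute tells the compiler the maximum block size and a desired minimum number of resident blocks per SM, so it can cap register usage to a level at which that occupancy is actually reachable.

```cpp
#include <cstdint>

// Placeholder kernel: __launch_bounds__(max threads per block, min blocks per multiprocessor)
// lets the compiler limit register usage so the requested occupancy is achievable.
__global__ void __launch_bounds__(256, 2) quantize_q8_sketch(
        const float * __restrict__ x, int8_t * __restrict__ y, const uint32_t nelem) {
    const uint32_t i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= nelem) {
        return;
    }
    // Toy per-element rounding; the real q8_1 kernel also stores per-block scale/sum metadata.
    y[i] = (int8_t) fminf(fmaxf(roundf(x[i]), -127.0f), 127.0f);
}
```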

github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Sep 4, 2025
ORippler (Contributor) left a comment:

Thanks for the initial efforts! I spotted two more instances where we could apply fastdiv/fastmodulo as well, if I'm not mistaken.

const int sample_y = sample_dst;
const uint32_t channel_dst = blockIdx.y;
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
const uint32_t channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
ORippler (Contributor):

Have we tried applying fastmodulo to nchannels_y?
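
(For illustration, the suggested change would look roughly like the sketch below, reusing the fastmodulo helper sketched near the top of this page and assuming a precomputed uint3 for nchannels_y is passed in instead of the raw value; the names here are hypothetical.)

```cpp
// Hypothetical helper showing the shape of the change: nchannels_y_fd would be a uint3 built on
// the host with the same init routine already used for channel_ratio, passed instead of nchannels_y.
static __device__ __forceinline__ uint32_t mmvq_channel_y(
        const uint32_t channel_dst, const bool use_ids, const uint32_t ncols_dst, const uint3 nchannels_y_fd) {
    // fastmodulo as sketched earlier: n - fastdiv(n) * d, i.e. no second hardware division
    return (ncols_dst == 1 && use_ids) ? fastmodulo(channel_dst, nchannels_y_fd) : channel_dst;
}
```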

const uint32_t ncols_x, const uint32_t nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {

constexpr int qk = ggml_cuda_type_traits<type>::qk;
ORippler (Contributor) commented on Sep 5, 2025:

qk, qi, vdr and QK8_1 are all compile-time constants. We should therefore be able to replace the divisions and modulo used to determine kbx, kby and kqs in the first for loop with fastdiv/fastmodulo as well (lines 178-182).

JohannesGaessler (Collaborator, Author):

Those are all powers of 2 though; shouldn't the compiler be able to replace the divisions/modulos with shifts/bitwise ANDs? Is the use of signed integers in this context problematic?

ORippler (Contributor):

> Those are all powers of 2 though; shouldn't the compiler be able to replace the divisions/modulos with shifts/bitwise ANDs? Is the use of signed integers in this context problematic?

If they are indeed powers of 2, the compiler will do the correct optimizations for us at compile time (see godbolt). While it requires 2 more instructions to do signed integer division/modulo compared to unsigned (see this Stack Overflow post if interested in the details), that should be negligible for our use case.
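
For reference, this is the strength reduction being described: with a compile-time power-of-two divisor the compiler emits a shift and a mask, and only the signed variants need a little extra fix-up code (a sketch with a stand-in constant, not code from the PR):

```cpp
#include <cstdint>

constexpr uint32_t qk_example = 32; // stand-in for a power-of-two constant such as qk or QK8_1

// Unsigned: i / 32 lowers to i >> 5, and i % 32 lowers to i & 31.
__device__ uint32_t div_pow2_unsigned(uint32_t i) { return i / qk_example; }
__device__ uint32_t mod_pow2_unsigned(uint32_t i) { return i % qk_example; }

// Signed: still shift/mask, plus a couple of extra instructions to keep
// C++'s round-toward-zero semantics correct for negative i.
__device__ int32_t div_pow2_signed(int32_t i) { return i / (int32_t) qk_example; }
__device__ int32_t mod_pow2_signed(int32_t i) { return i % (int32_t) qk_example; }
```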

JohannesGaessler (Collaborator, Author)

I also enabled fastmodulo in the MoE branch, thanks for pointing it out. I noticed that some extra care must be taken for values which are used in only one of the device code branches: calling init_fastdiv_values with 0 results in an arithmetic exception in host code. I kept it as an error but added an assert.
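
For illustration, the host-side guard amounts to something like this (a sketch reusing the illustrative helper names from the earlier block rather than the real init_fastdiv_values, and a plain assert rather than ggml's own assertion macro):

```cpp
#include <cassert>
#include <cstdint>
#include <cuda_runtime.h>

// d == 0 stays an error (the magic-constant computation divides by d), but an explicit assert
// fails with a readable message instead of a raw arithmetic exception on the host.
static uint3 make_fastdiv_values_checked(uint32_t d) {
    assert(d != 0 && "fastdiv/fastmodulo divisor must be non-zero");
    return make_fastdiv_values(d); // illustrative helper sketched near the top of this thread
}
```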

Performance
| GPU | Model | FlashAttention | Microbatch size | Test | t/s master | t/s 748c6a5 | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MI60 / MI50 | deepseek2 16B Q4_0 | No | 1 | pp512 | 111.86 | 114.86 | 1.03 |
| MI60 / MI50 | deepseek2 16B Q4_0 | No | 2 | pp512 | 154.03 | 153.79 | 1.00 |
| MI60 / MI50 | deepseek2 16B Q4_0 | No | 3 | pp512 | 164.48 | 162.50 | 0.99 |
| MI60 / MI50 | deepseek2 16B Q4_0 | No | 4 | pp512 | 200.42 | 199.34 | 0.99 |
| MI60 / MI50 | deepseek2 16B Q4_0 | No | 5 | pp512 | 235.12 | 236.64 | 1.01 |
| MI60 / MI50 | deepseek2 16B Q4_0 | No | 6 | pp512 | 222.68 | 221.66 | 1.00 |
| MI60 / MI50 | deepseek2 16B Q4_0 | No | 7 | pp512 | 240.91 | 244.66 | 1.02 |
| MI60 / MI50 | deepseek2 16B Q4_0 | No | 8 | pp512 | 274.86 | 275.63 | 1.00 |
| P40 | deepseek2 16B Q4_0 | No | 1 | pp512 | 77.76 | 80.49 | 1.04 |
| P40 | deepseek2 16B Q4_0 | No | 2 | pp512 | 117.58 | 118.81 | 1.01 |
| P40 | deepseek2 16B Q4_0 | No | 3 | pp512 | 149.76 | 150.04 | 1.00 |
| P40 | deepseek2 16B Q4_0 | No | 4 | pp512 | 167.27 | 166.13 | 0.99 |
| P40 | deepseek2 16B Q4_0 | No | 5 | pp512 | 192.36 | 192.58 | 1.00 |
| P40 | deepseek2 16B Q4_0 | No | 6 | pp512 | 208.26 | 208.44 | 1.00 |
| P40 | deepseek2 16B Q4_0 | No | 7 | pp512 | 224.89 | 224.44 | 1.00 |
| P40 | deepseek2 16B Q4_0 | No | 8 | pp512 | 232.20 | 233.23 | 1.00 |
| RTX 3090 | deepseek2 16B Q4_0 | Yes | 1 | pp512 | 193.69 | 198.53 | 1.02 |
| RTX 3090 | deepseek2 16B Q4_0 | Yes | 2 | pp512 | 190.24 | 189.92 | 1.00 |
| RTX 3090 | deepseek2 16B Q4_0 | Yes | 3 | pp512 | 245.66 | 245.62 | 1.00 |
| RTX 3090 | deepseek2 16B Q4_0 | Yes | 4 | pp512 | 301.49 | 297.95 | 0.99 |
| RTX 3090 | deepseek2 16B Q4_0 | Yes | 5 | pp512 | 351.92 | 351.26 | 1.00 |
| RTX 3090 | deepseek2 16B Q4_0 | Yes | 6 | pp512 | 399.18 | 402.66 | 1.01 |
| RTX 3090 | deepseek2 16B Q4_0 | Yes | 7 | pp512 | 449.73 | 447.73 | 1.00 |
| RTX 3090 | deepseek2 16B Q4_0 | Yes | 8 | pp512 | 488.72 | 487.75 | 1.00 |
| RTX 4090 | deepseek2 16B Q4_0 | Yes | 1 | pp512 | 263.50 | 267.06 | 1.01 |
| RTX 4090 | deepseek2 16B Q4_0 | Yes | 2 | pp512 | 283.48 | 285.52 | 1.01 |
| RTX 4090 | deepseek2 16B Q4_0 | Yes | 3 | pp512 | 369.63 | 371.67 | 1.01 |
| RTX 4090 | deepseek2 16B Q4_0 | Yes | 4 | pp512 | 457.25 | 458.50 | 1.00 |
| RTX 4090 | deepseek2 16B Q4_0 | Yes | 5 | pp512 | 529.74 | 532.14 | 1.00 |
| RTX 4090 | deepseek2 16B Q4_0 | Yes | 6 | pp512 | 606.53 | 610.54 | 1.01 |
| RTX 4090 | deepseek2 16B Q4_0 | Yes | 7 | pp512 | 676.57 | 678.63 | 1.00 |
| RTX 4090 | deepseek2 16B Q4_0 | Yes | 8 | pp512 | 746.90 | 748.84 | 1.00 |
| RX 6800 | deepseek2 16B Q4_0 | No | 1 | pp512 | 81.20 | 81.93 | 1.01 |
| RX 6800 | deepseek2 16B Q4_0 | No | 2 | pp512 | 102.97 | 103.31 | 1.00 |
| RX 6800 | deepseek2 16B Q4_0 | No | 3 | pp512 | 136.40 | 136.92 | 1.00 |
| RX 6800 | deepseek2 16B Q4_0 | No | 4 | pp512 | 163.75 | 164.06 | 1.00 |
| RX 6800 | deepseek2 16B Q4_0 | No | 5 | pp512 | 181.78 | 181.82 | 1.00 |
| RX 6800 | deepseek2 16B Q4_0 | No | 6 | pp512 | 202.20 | 203.18 | 1.00 |
| RX 6800 | deepseek2 16B Q4_0 | No | 7 | pp512 | 219.51 | 219.02 | 1.00 |
| RX 6800 | deepseek2 16B Q4_0 | No | 8 | pp512 | 235.59 | 234.33 | 0.99 |

ORippler (Contributor) left a comment:

Thanks for the effort, looks good from my side!

JohannesGaessler merged commit 5143fa8 into ggml-org:master on Sep 5, 2025. 48 checks passed.
gabe-l-hart added two commits to gabe-l-hart/llama.cpp that referenced this pull request on Sep 5, 2025
walidbr pushed a commit to walidbr/llama.cpp that referenced this pull request on Sep 7, 2025