TL;DR: There's some prefill performance to be eked out for llama.cpp running on AMD Strix Halo. The changes are relatively straightforward, thankfully, but there isn't a clear "best" approach to implementing the fixes (this is just what I did), so I'll just report this as an issue.
Context: The following aims to identify inefficiencies without being prescriptive about how to solve them. I'll share what I did, but I'm not claiming it's the best solution. Happy to send PRs if the changes look good to maintainers, but I'm equally happy to defer to others in addressing what looks like a large opportunity for this hardware.
AI disclosure: Yes, for formatting, obviously (in fact, using llama.cpp). The writing is my own, but like most people, I don't remember the markdown format for generating pretty text ;) But really, please read the issue as it is written by me, regardless of how you feel about the formatting. At least it's not a wall of text!
Guided by rocprof profiles, I set out to optimize prefill performance on Strix Halo. I focused on Qwen3.5 MoE models at Q4_K quantization because they're popular and amenable to performance optimizations. The resulting changes have been tested by a few different individuals in the Strix Halo community to verify the results.
Overall, llama-bench runs with matching parameters show a ~20% uplift on the 122B model, which really isn't subtle :) (the exact figure depends on the test parameters; see the BEFORE and AFTER tables below).
I'll summarize the areas of inefficiency below.
1. VGPR pressure from settings in MMQ
There are three main parameters impacting this:
- mmq_x
- mmq_y
- nwarps
All three are at suboptimal values for gfx1151. We can trace the code paths to infer how the values get set, but the more interesting question is "what should the values be?" Static settings won't cover all hardware options, but for gfx1151, x=48, y=64, nwarps=4 provides a meaningful improvement. Further gains may be available, for example via __launch_bounds__ or other values (although it's not evident that there's a better default than this proposal). The main point, however, is that the current HEAD likely exceeds the 256-VGPR limit (not all of which are usable anyway) and spills.
2. Leveraging intrinsics
Note: (a) and (b) are not specific to gfx1151, but the optimizations make a noticeable difference on this hardware.
- (a) expf() → __expf(): __expf is known to be faster but less accurate. In MoE routing (various lines in gated_delta_net.cu) and in SiLU activation (ggml_cuda_op_silu_single), the loss in precision does not appear to make any meaningful difference. Larger losses in precision exist elsewhere, and the speed trade-off from the intrinsic seems to be virtually free for the Qwen 3.5 35B and 122B models.
- (b) roundf() → __float2int_rn: quantize.cu executes rounding often enough that the intrinsic adds a noticeable benefit.
- (c) ggml_cuda_dp4a in common.cuh: Use __builtin_amdgcn_sudot4 for RDNA 3.5.
3. Loop-invariant clean-up
Note: This is also not specific to gfx1151. It's likely a worthwhile code-cleanliness change that also happens to improve performance.
At the risk of being prescriptive: consider declaring loop-invariant constants (base0, base1, baseD below) outside of the for-loop in concat.cu:__launch_bounds__. Given that this is the slow path, it appears to help clean up the code. What I don't know is why this non-contiguous kernel path is taken in the first place, so there might be underlying issues that I haven't figured out.
const uint64_t base0 = (uint64_t)(i3)*nb03 + (uint64_t)(i2)*nb02 + (uint64_t)(i1)*nb01;
const uint64_t base1 = (uint64_t)(i3)*nb13 + (uint64_t)(i2)*nb12 + (uint64_t)(i1)*nb11;
const uint64_t baseD = (uint64_t)(i3)*nb3 + (uint64_t)(i2)*nb2 + (uint64_t)(i1)*nb1;
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
    x = (const float *)(src0 + base0 + (uint64_t)(i0)*nb00);
    y = (float *)(dst + baseD + (uint64_t)(i0)*nb0);
}
Changes described as a git patch: https://gist.github.com/pedapudi/183f337e687630a43eacb293e157c9bd
$ ./bin/llama-cli --version
ggml_cuda_init: found 1 ROCm devices (Total VRAM: 126976 MiB):
Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32, VRAM: 126976 MiB
version: 8627 (c30e01225)
built with GNU 15.2.0 for Linux x86_64
**BEFORE**
$ ./bin/llama-bench --model ~/models/unsloth/qwen35-35b/Qwen3.5-35B-A3B-Q4_K_M.gguf -p 128,256,512,1024,2048 -n 0 --n-gpu-layers 99 --flash-attn 1 --mmap 0 --direct-io 1 --ubatch-size 2048 --batch-size 2048
ggml_cuda_init: found 1 ROCm devices (Total VRAM: 126976 MiB):
Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32, VRAM: 126976 MiB
| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: |
| qwen35moe 35B.A3B Q4_K - Medium | 20.49 GiB | 34.66 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp128 | 584.73 ± 14.61 |
| qwen35moe 35B.A3B Q4_K - Medium | 20.49 GiB | 34.66 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp256 | 875.62 ± 7.62 |
| qwen35moe 35B.A3B Q4_K - Medium | 20.49 GiB | 34.66 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp512 | 1151.66 ± 8.74 |
| qwen35moe 35B.A3B Q4_K - Medium | 20.49 GiB | 34.66 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp1024 | 1329.47 ± 9.85 |
| qwen35moe 35B.A3B Q4_K - Medium | 20.49 GiB | 34.66 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp2048 | 1425.86 ± 10.62 |
build: c30e01225 (8627)
$ ./bin/llama-bench --model ~/models/unsloth/qwen35-122b-q4/unsloth_Qwen3.5-122B-A10B-GGUF_Q4_K_M_Qwen3.5-122B-A10B-Q4_K_M-00001-of-00003.gguf -p 128,256,512,1024,2048,4096 -n 0 --n-gpu-layers 99 --flash-attn 1 --mmap 0 --direct-io 1 --ubatch-size 2048 --batch-size 2048
ggml_cuda_init: found 1 ROCm devices (Total VRAM: 126976 MiB):
Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32, VRAM: 126976 MiB
| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp128 | 181.35 ± 4.84 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp256 | 271.93 ± 3.97 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp512 | 362.56 ± 1.89 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp1024 | 422.23 ± 4.94 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp2048 | 444.79 ± 3.85 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp4096 | 472.20 ± 1.18 |
build: c30e01225 (8627)
**AFTER**
$ ./bin/llama-bench --model ~/models/unsloth/qwen35-35b/Qwen3.5-35B-A3B-Q4_K_M.gguf -p 128,256,512,1024,2048 -n 0 --n-gpu-layers 99 --flash-attn 1 --mmap 0 --direct-io 1 --ubatch-size 2048 --batch-size 2048
ggml_cuda_init: found 1 ROCm devices (Total VRAM: 126976 MiB):
Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32, VRAM: 126976 MiB
| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: |
| qwen35moe 35B.A3B Q4_K - Medium | 20.49 GiB | 34.66 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp128 | 943.85 ± 14.24 |
| qwen35moe 35B.A3B Q4_K - Medium | 20.49 GiB | 34.66 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp256 | 1256.79 ± 5.48 |
| qwen35moe 35B.A3B Q4_K - Medium | 20.49 GiB | 34.66 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp512 | 1432.97 ± 5.34 |
| qwen35moe 35B.A3B Q4_K - Medium | 20.49 GiB | 34.66 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp1024 | 1530.05 ± 4.45 |
| qwen35moe 35B.A3B Q4_K - Medium | 20.49 GiB | 34.66 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp2048 | 1465.81 ± 6.33 |
build: c30e01225 (8627)
$ ./bin/llama-bench --model ~/models/unsloth/qwen35-122b-q4/unsloth_Qwen3.5-122B-A10B-GGUF_Q4_K_M_Qwen3.5-122B-A10B-Q4_K_M-00001-of-00003.gguf -p 128,256,512,1024,2048,4096 -n 0 --n-gpu-layers 99 --flash-attn 1 --mmap 0 --direct-io 1 --ubatch-size 2048 --batch-size 2048
ggml_cuda_init: found 1 ROCm devices (Total VRAM: 126976 MiB):
Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32, VRAM: 126976 MiB
| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp128 | 315.48 ± 5.29 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp256 | 412.26 ± 3.34 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp512 | 497.50 ± 2.04 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp1024 | 500.45 ± 4.01 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp2048 | 566.55 ± 5.95 |
| qwen35moe 122B.A10B Q4_K - Medium | 71.27 GiB | 122.11 B | ROCm | 99 | 2048 | 1 | 0 | 1 | pp4096 | 501.66 ± 3.91 |
build: c30e01225 (8627)