Hip fattn expf approx #23441
Open
a-huk wants to merge 2 commits into
Replace expf() with __expf() in the softmax rescaling loops of fattn-mma-f16.cuh, fattn-tile.cuh, and fattn-vec.cuh. __expf is the hardware fast-path approximation (~4x faster than the IEEE-754 expf on both CUDA SFUs and AMD hardware). The accuracy difference (~1-2 ULP) is irrelevant in the softmax context: the subtracted max already bounds the input range, and LLM logit distributions are not sensitive to this precision level.
Affected call sites (all in the per-tile softmax rescaling loop):
- fattn-mma-f16.cuh: 6 sites in flash_attn_ext_f16_iter and flash_attn_ext_f16_process_tile
- fattn-tile.cuh: 4 sites
- fattn-vec.cuh: 5 sites
Three improvements to the RDNA3 WMMA flash attention path on gfx1100/gfx1101/gfx1151:
1. Fix register spill in RDNA3 configs (fattn-mma-f16.cuh):
- DKQ=64, ncols=32,64: Q_in_reg true→false (eliminates 8-44 byte scratch spill)
- DKQ=80-256: Q_in_reg true→false (reduces 52-720 byte spill to <36 bytes)
2. Restrict RDNA3 WMMA dispatch to DKQ=64 (fattn.cu):
- DKQ=80-128: 320-480+ bytes scratch remain due to VKQ accumulator pressure
in the WMMA tile loop, even with Q_in_reg=false. Excluded until the inner
loop can be restructured to fit within gfx1151's 256 VGPR hard cap.
- DKQ=256: no throughput benefit on gfx1151 (Q_in_reg=false forces nbatch_fa=32,
giving 1024 inner iterations at 32K context — overhead dominates).
3. Replace __expf with exp2f in MMA softmax (fattn-mma-f16.cuh):
- Q values are pre-scaled by log2(e) on load into shared memory, converting the
softmax from base-e to base-2. All exp() -> exp2() with no approximation:
exp(x) == exp2(x * log2(e)) exactly.
- exp2f is hardware-native on AMD RDNA (single v_exp_f32 instruction). Captures
essentially 100% of the theoretical gain from removing exp overhead, which
accounts for ~18% of FA compute at 32K context.
- Attention sinks (sinks_f) scaled by log2(e) to stay consistent with the
shifted KQ_max tracking space.
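The base change in item 3 can be checked on the host with a small sketch. This is plain C++ rather than the kernel code, and the names `LOG2E`, `softmax_base_e`, `softmax_base_2`, and `max_abs_diff` are illustrative assumptions, not identifiers from fattn-mma-f16.cuh. The identity is exact mathematically; in single precision the two paths should agree to within rounding, since the multiplication by log2(e) is itself rounded.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// log2(e). The constant name is illustrative, not taken from the kernels.
constexpr float LOG2E = 1.4426950408889634f;

// Reference: normalized softmax computed in base e with the max subtracted.
std::vector<float> softmax_base_e(const std::vector<float>& scores) {
    assert(!scores.empty());
    float m = scores[0];
    for (float s : scores) m = std::fmax(m, s);
    std::vector<float> out(scores.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < scores.size(); ++i) {
        out[i] = std::exp(scores[i] - m);
        sum += out[i];
    }
    for (float& p : out) p /= sum;
    return out;
}

// Same softmax in base 2: scores are multiplied by log2(e) once up front
// (standing in for pre-scaling Q on load), after which only exp2 is used.
std::vector<float> softmax_base_2(const std::vector<float>& scores) {
    assert(!scores.empty());
    std::vector<float> scaled(scores.size());
    for (std::size_t i = 0; i < scores.size(); ++i) scaled[i] = scores[i] * LOG2E;
    float m = scaled[0];  // the running max now lives in the scaled space
    for (float s : scaled) m = std::fmax(m, s);
    std::vector<float> out(scaled.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < scaled.size(); ++i) {
        out[i] = std::exp2(scaled[i] - m);
        sum += out[i];
    }
    for (float& p : out) p /= sum;
    return out;
}

// Largest element-wise absolute difference between two equal-length vectors.
float max_abs_diff(const std::vector<float>& a, const std::vector<float>& b) {
    assert(a.size() == b.size());
    float d = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) d = std::fmax(d, std::fabs(a[i] - b[i]));
    return d;
}
```

Calling `max_abs_diff(softmax_base_e(s), softmax_base_2(s))` on a handful of scores gives a quick sanity check. Any value compared against the running max, such as the attention sinks mentioned above, has to be multiplied by the same factor so that it lives in the same scaled space.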
Benchmarks on Radeon 8060S (gfx1151, 40 CUs), Qwen3.5-27B Q4_K_M:
| context | FA=0 (rocBLAS) | FA=1 (before) | FA=1 (after) | delta |
|---------|----------------|---------------|--------------|--------|
| pp512 | 310 t/s | ~308 t/s | 323 t/s | +5% |
| pp4096 | 292 t/s | 304 t/s | 310 t/s | +2% |
| pp8192 | 284 t/s | 281 t/s | 301 t/s | +7% |
| pp32768 | 204 t/s | 196 t/s | 239 t/s | +22% |
FA=1 now outperforms FA=0 at all tested context lengths.
Hi @a-huk, thanks for your contribution! Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:
Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.
Contributor
As previously suggested, try
Also according to the llama.cpp AI usage policy:
Overview
Two fixes for flash attention on AMD RDNA3 GPUs (gfx1100/gfx1101/gfx1151):
1. Fix RDNA3 MMA dispatch to prevent register-spill regression
The AMD WMMA flash-attention path was dispatched for all `head_dim ≤ 128`, but only `head_dim = 64` configs fit within the 256-VGPR wavefront budget on RDNA3 wave32. Configs with `head_dim = 80–128` require 320–480+ bytes of scratch memory (91–114 spilled VGPRs for `head_dim = 128`), making them slower than the non-MMA tile path.

This PR tightens the dispatch guard from `Q->ne[0] <= 128` to `Q->ne[0] == 64`, preventing regression on models with `head_dim = 80–128` (Llama 3.x, Mistral, Phi, etc.). It also sets `Q_in_reg = false` for the `head_dim = 64`, `ncols = 32/64` configs, which were borderline at exactly 256 VGPRs; this saves 12 VGPRs and removes their minor scratch usage.

2. Replace `expf`/`__expf` with hardware-native `exp2f` in all flash attention kernels

All three FA kernel implementations (`fattn-mma-f16.cuh`, `fattn-tile.cuh`, `fattn-vec.cuh`) used `expf` or `__expf` in the softmax inner loop. On AMD hardware `exp2f` is native (single instruction), while `expf` requires a multi-step approximation.

The substitution is exact: `exp(x) = exp2(x · log₂e)`. Q values are pre-scaled by `log₂e` on load so all subsequent KQ dot products are already in base-2 space, and the softmax `exp(x − m)` calls become `exp2(x − m)` with no approximation error. Rescaling factors for the running-max update are adjusted to match.

The change is applied identically to all three kernels. On CUDA the difference is negligible (both paths use SFU hardware); on AMD it removes multi-step emulation from the softmax hotpath.
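The running-max rescaling described in item 2 can be illustrated with a host-side sketch of a tiled (online) softmax in base 2. This is not the kernel code: it uses one scalar value per key instead of V vectors, and the function names, the `tile` parameter, and `LOG2E` are hypothetical. It only shows that the per-tile rescale factor becomes `exp2(m_old − m_new)` once the scores are in the scaled space.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

constexpr float LOG2E = 1.4426950408889634f;  // log2(e); illustrative name

// Softmax-weighted sum of `values`, streamed over tiles of `tile` scores.
// Keeps a running max (in base-2 scaled space), denominator, and accumulator.
float online_weighted_sum_base2(const std::vector<float>& scores,
                                const std::vector<float>& values,
                                std::size_t tile) {
    assert(!scores.empty() && scores.size() == values.size() && tile > 0);
    float m = -std::numeric_limits<float>::infinity();
    float denom = 0.0f;
    float acc = 0.0f;
    for (std::size_t start = 0; start < scores.size(); start += tile) {
        const std::size_t end = std::min(start + tile, scores.size());
        // New running max after seeing this tile.
        float m_new = m;
        for (std::size_t i = start; i < end; ++i) m_new = std::fmax(m_new, scores[i] * LOG2E);
        // Rescale earlier partial results to the new max; exp2(-inf) == 0
        // on the first tile, where denom and acc are still zero.
        const float rescale = std::exp2(m - m_new);
        denom *= rescale;
        acc *= rescale;
        for (std::size_t i = start; i < end; ++i) {
            const float p = std::exp2(scores[i] * LOG2E - m_new);
            denom += p;
            acc += p * values[i];
        }
        m = m_new;
    }
    return acc / denom;
}

// Two-pass base-e reference for comparison.
float reference_weighted_sum(const std::vector<float>& scores,
                             const std::vector<float>& values) {
    assert(!scores.empty() && scores.size() == values.size());
    float m = scores[0];
    for (float s : scores) m = std::fmax(m, s);
    float denom = 0.0f;
    float acc = 0.0f;
    for (std::size_t i = 0; i < scores.size(); ++i) {
        const float p = std::exp(scores[i] - m);
        denom += p;
        acc += p * values[i];
    }
    return acc / denom;
}
```

Under these assumptions the streamed result should match the reference to within single-precision rounding for any tile size, because the rescale step only re-expresses earlier partial sums relative to the updated max.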
Additional information
VGPR budget analysis — RDNA3 wave32 (256-VGPR hard cap)
[Table: `Q_in_reg=true` vs. `Q_in_reg=false` (this PR); rows not recovered]

The inner loop body alone (KQ_C tiles, K/V load temporaries, WMMA compiler-allocated registers) accounts for ~300 of the ~370 VGPRs needed for `head_dim = 128`. Fixing `head_dim ≥ 80` requires kernel restructuring and is tracked separately.

exp2f benchmark (gfx1151, Qwen3.5-27B Q4_K_M)
Measured with `llama-bench -p <N> -n 0 -fa 0 -fa 1`. FA=1 here uses the TILE path (DKQ=256 is not MMA-dispatched); the exp2f gains shown below apply to the TILE and VEC kernels on any GPU.

The gain scales with context length because the softmax loop runs once per KV tile: more tiles means more exp calls. The 99.5% capture of the theoretical maximum (measured by temporarily removing all exp calls) confirms that `exp2f` is essentially free on this hardware.

MMA dispatch fix benchmark (gfx1151, Qwen3-0.6B BF16, head_dim=64, GQA ratio=2)
Related: issue #21284 (gfx1151 inefficient defaults).
Requirements