Optimized flash attention (FA) for OpenCL backend, and add Q4/Q8 KV cache quantization with FA for Adreno GPUs. #23501
wanghqc wants to merge 10 commits into
Rework the flash-attention kernels for the Adreno backend: a fused single-pass kernel with online softmax, GQA broadcast, and an N_SPLIT option that shares each query's DK/DV reduction across several threads to cut per-thread register pressure at large head dimensions. The QK reduction uses subgroup shuffle where the device supports it and falls back to a local-memory tree otherwise. flash_attn_f32_f16 is the mixed f32-query / f16-KV kernel for the common f16 KV-cache case; flash_attn_f16 / flash_attn_f32 cover the symmetric f16 and f32 paths. Each file provides a prefill kernel and a single-query (q1) decode kernel.
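The GQA broadcast mentioned above can be illustrated with a small host-side sketch. It assumes the usual grouped-query layout, where consecutive query heads share one KV head; the function name is hypothetical and the PR's kernels may compute the index differently.

```cpp
#include <cassert>

// Hypothetical helper: maps a query head to the KV head it reads under GQA.
// Assumes n_q_heads is a multiple of n_kv_heads and that consecutive query
// heads form one group.
int kv_head_for_query_head(int q_head, int n_q_heads, int n_kv_heads) {
    assert(n_kv_heads > 0 && n_q_heads % n_kv_heads == 0);
    assert(q_head >= 0 && q_head < n_q_heads);
    const int group_size = n_q_heads / n_kv_heads;  // query heads per KV head
    return q_head / group_size;
}
```

With 8 query heads and 2 KV heads, query heads 0 to 3 read KV head 0 and query heads 4 to 7 read KV head 1, so K and V are never duplicated in memory.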
Add flash_attn_pre_f16.cl with three prepass kernels run before the
main attention kernel on the prefill path:
- flash_attn_kv_pad_f16 pads the tail KV tile to a BLOCK_N multiple
- flash_attn_mask_pad_f16 pads the matching mask tile
- flash_attn_blk_f16 classifies each KV tile per query block as fully masked / mixed / fully unmasked, so the main kernel can skip fully-masked tiles and the mask lookup for fully-unmasked ones
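The per-tile classification can be sketched on the host as follows. This is a CPU illustration only: the enum, the function name, and the mask layout (row-major f32, with -INFINITY for masked entries and 0 for visible ones) are assumptions, not the layout used by flash_attn_blk_f16.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical names; the real classification runs in an OpenCL kernel.
enum class TileClass { FullyMasked, Mixed, FullyUnmasked };

// mask is row-major [n_q x n_kv]. Classifies the tile covering query rows
// [q0, q0 + block_m) and KV columns [k0, k0 + block_n), clipped to the mask.
TileClass classify_tile(const std::vector<float>& mask, std::size_t n_q, std::size_t n_kv,
                        std::size_t q0, std::size_t k0,
                        std::size_t block_m, std::size_t block_n) {
    assert(mask.size() == n_q * n_kv);
    const float neg_inf = -std::numeric_limits<float>::infinity();
    bool all_masked = true;  // every entry seen so far is -inf
    bool all_zero   = true;  // every entry seen so far is exactly 0
    for (std::size_t q = q0; q < q0 + block_m && q < n_q; ++q) {
        for (std::size_t k = k0; k < k0 + block_n && k < n_kv; ++k) {
            const float v = mask[q * n_kv + k];
            if (v != neg_inf) all_masked = false;
            if (v != 0.0f)    all_zero   = false;
        }
    }
    if (all_masked) return TileClass::FullyMasked;    // tile can be skipped
    if (all_zero)   return TileClass::FullyUnmasked;  // mask lookup can be skipped
    return TileClass::Mixed;                          // needs the per-element mask
}
```

For a causal mask, tiles above the diagonal come out as fully masked, tiles below it as fully unmasked, and only the tiles that straddle the diagonal need the per-element lookup.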
Add flash_attn_f32_q8_0.cl and flash_attn_f32_q4_0.cl: flash-attention kernels that read a q8_0 / q4_0 quantised KV cache directly instead of dequantising it to f16/f32 first. Each file provides a prefill kernel, a single-query (q1) decode kernel, and N_SPLIT variants, mirroring the f32_f16 kernel structure.
Add kernel_set_rows_q8_0_* and kernel_set_rows_q4_0_* to set_rows.cl: set_rows variants that quantise f32 source rows into a q8_0 / q4_0 destination. These let the KV cache be stored quantised so the native quant flash-attention kernels can read it directly.
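As a rough illustration of what quantising a row into q8_0 involves, the sketch below processes blocks of 32 values with one scale per block (absolute maximum divided by 127). It is a CPU reference under stated assumptions: the struct and function names are hypothetical, and the scale is kept as f32 here, whereas the real q8_0 format stores it as f16.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::size_t kBlock = 32;  // values per q8_0 block

// Illustrative block layout; the real format stores d as a 16-bit float.
struct BlockQ8 {
    float       d;           // per-block scale
    std::int8_t qs[kBlock];  // quantised values in [-127, 127]
};

// Quantises one f32 row whose length is a multiple of the block size.
std::vector<BlockQ8> quantize_row_q8(const std::vector<float>& row) {
    assert(row.size() % kBlock == 0);
    std::vector<BlockQ8> out(row.size() / kBlock);
    for (std::size_t b = 0; b < out.size(); ++b) {
        const float* x = row.data() + b * kBlock;
        float amax = 0.0f;
        for (std::size_t j = 0; j < kBlock; ++j) amax = std::max(amax, std::fabs(x[j]));
        const float d   = amax / 127.0f;
        const float inv = (d != 0.0f) ? 1.0f / d : 0.0f;  // all-zero block maps to zeros
        out[b].d = d;
        for (std::size_t j = 0; j < kBlock; ++j)
            out[b].qs[j] = static_cast<std::int8_t>(std::lround(x[j] * inv));
    }
    return out;
}

// Reconstructs element j of a block.
float dequantize_q8(const BlockQ8& blk, std::size_t j) {
    return blk.d * static_cast<float>(blk.qs[j]);
}
```

The reconstruction error per element is bounded by about half the block scale, which is why a native q8_0 attention kernel can read the cache directly with little accuracy loss.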
Add view-aware AoS dequant kernels to cvt.cl that dequantise a q8_0 / q4_0 tensor to f32/f16 while honouring the source tensor's view offset and row/slice strides. Flash attention uses these to dequantise a quantised KV cache on the GPU for the f16/f32 fallback path, instead of round-tripping the cache through host memory.
Add fa_tune.h: the per-(dk,dv) flash-attention tile configuration (BLOCK_M, BLOCK_N, N_SPLIT, split threshold) used by both the FA dispatch and supports_op. The table is kept in its own header so the tuning numbers are easy to find and edit separately from the dispatch logic. GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,...] patches matching entries at backend init, before the first FA kernel compiles, for on-device tuning without a rebuild.
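A parser for the dk:dv:bm:bn:nsplit:thr[,...] format might look like the sketch below. The struct, field names, and function are hypothetical, and skipping malformed entries is an assumption; the PR's actual parsing and error handling are not shown in the source, and the unit of the split threshold is not stated.

```cpp
#include <array>
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical mirror of one tuning-table entry.
struct FaTuneEntry {
    int dk, dv, block_m, block_n, n_split, split_threshold;
};

// Parses "dk:dv:bm:bn:nsplit:thr[,...]". Entries that do not have exactly six
// integer fields are skipped (an assumption about the desired behaviour).
std::vector<FaTuneEntry> parse_fa_tune(const std::string& spec) {
    std::vector<FaTuneEntry> entries;
    std::istringstream items(spec);
    std::string item;
    while (std::getline(items, item, ',')) {
        std::istringstream fields(item);
        std::string field;
        std::array<int, 6> v{};
        std::size_t n = 0;
        bool ok = true;
        while (ok && std::getline(fields, field, ':')) {
            if (n == v.size()) { ok = false; break; }  // too many fields
            try {
                std::size_t used = 0;
                v[n] = std::stoi(field, &used);
                if (used != field.size()) ok = false;  // trailing junk
            } catch (const std::exception&) {
                ok = false;                            // empty or non-numeric
            }
            if (ok) ++n;
        }
        if (ok && n == v.size())
            entries.push_back(FaTuneEntry{v[0], v[1], v[2], v[3], v[4], v[5]});
    }
    return entries;
}
```

The backend would read the string with std::getenv("GGML_OPENCL_FA_TUNE") and patch table entries whose (dk, dv) match. A value such as 128:128:32:64:2:512 is a made-up example, not a recommended setting.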
Wire the flash-attention kernels into the OpenCL backend:
- Per-(dk,dv,variant) lazy kernel compile (ggml_opencl_ensure_fa_*),
so only the FA variants a model actually uses are built, keeping
the Adreno host-memory footprint and startup cost down. Optional
variants compile non-fatally (build_program_from_source_ex) and
the dispatch falls back when one is unavailable.
- ggml_cl_flash_attn dispatch: prefill prepass (KV-pad, mask-pad,
per-tile mask classification), flash-decoding K-split for the
single-query path, native q8_0/q4_0 KV, and a GPU-dequant path
for asymmetric / unsupported KV types.
- supports_op for GGML_OP_FLASH_ATTN_EXT against the tuning table.
- GGML_OPENCL_FA_LOG_SPILL diagnostic to report per-kernel private
memory (Adreno register spill) at compile time.
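The flash-decoding K-split in the dispatch above relies on merging per-split softmax partials. A minimal CPU sketch of that merge follows, assuming each split reports its running max, its exponential sum, and an unnormalised output accumulator. All names are hypothetical, and the sentinel reuses the finite FA_M_INIT value described later in this series.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr float FA_M_INIT = -3.0e38f;  // finite "minus infinity" sentinel

// Hypothetical per-split result: max score, sum of exp(score - m), and the
// unnormalised weighted sum of V rows.
struct Partial {
    float m;
    float l;
    std::vector<float> o;
};

// Attention partial for KV positions [begin, end); values is row-major [n_kv x dv].
Partial partial_attention(const std::vector<float>& scores, const std::vector<float>& values,
                          std::size_t begin, std::size_t end, std::size_t dv) {
    assert(end <= scores.size() && values.size() == scores.size() * dv);
    Partial p{FA_M_INIT, 0.0f, std::vector<float>(dv, 0.0f)};
    for (std::size_t k = begin; k < end; ++k) p.m = std::max(p.m, scores[k]);
    for (std::size_t k = begin; k < end; ++k) {
        const float w = std::exp(scores[k] - p.m);
        p.l += w;
        for (std::size_t d = 0; d < dv; ++d) p.o[d] += w * values[k * dv + d];
    }
    return p;
}

// Merges split partials by rescaling each to the global max, then normalising.
std::vector<float> merge_partials(const std::vector<Partial>& parts, std::size_t dv) {
    float m = FA_M_INIT;
    for (const Partial& p : parts) m = std::max(m, p.m);
    float l = 0.0f;
    std::vector<float> out(dv, 0.0f);
    for (const Partial& p : parts) {
        const float w = std::exp(p.m - m);  // 0 for empty splits, never NaN
        l += w * p.l;
        for (std::size_t d = 0; d < dv; ++d) out[d] += w * p.o[d];
    }
    if (l > 0.0f)
        for (float& x : out) x /= l;
    return out;
}
```

Splitting the KV range lets several work-groups cooperate on one decode query; the merge gives the same result as a single pass, up to floating-point rounding.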
Also follow view_src to the SoA-converted extra in mul_mat_id, which
otherwise reads a view tensor's pre-SoA extra.
kernel_exp_f16 / kernel_expm1_f16 evaluated exp() directly in half precision, which loses accuracy for inputs of moderate magnitude. Convert to f32, evaluate exp() / expm1() there, and convert the result back to half. Note: this is a general half-precision fix, independent of the flash-attention work in this series.
The use_adreno_moe_kernels q4_0 SoA-conversion path set tensor->extra to a ggml_tensor_extra_cl_q4_0 but skipped the q4_0_soa_tensors.insert that the non-MoE conversion path performs. ggml_cl_is_q4_0_soa() keys off that set, so MoE expert weights were reported as AoS. ggml_backend_opencl_buffer_get_tensor then reinterpreted the SoA extra as a plain AoS extra, yielding a garbage buffer offset and a clEnqueueReadBuffer CL_INVALID_VALUE abort when test-backend-ops copies the graph back to the CPU backend.
The one-pass online-softmax FA kernels initialised the running max m_i = -INFINITY and then evaluated exp(m_i - m_new) = exp(-INFINITY) on the first loop iteration. These kernels build with -cl-finite-math-only, under which an infinite operand to exp() is undefined behaviour: the Adreno X2 driver returns 0 (so it works by luck), but the Adreno X1 driver returns NaN, which propagates to all-NaN attention output and garbage tokens. Replace the -INFINITY running-max init -- and the masked-score and flash-decoding merge sentinels -- with a finite FA_M_INIT (-3e38): below any real attention score, with headroom under FLT_MAX so (m_i - score) cannot overflow, and well-defined under -cl-finite-math-only. No GPU-generation gate -- it removes real UB and is correct on X1 and X2. Found by an Adreno X1-85 validation run; X2 test-backend-ops FLASH_ATTN_EXT remains 2564/2564.
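The effect of the finite sentinel can be seen in a scalar CPU sketch of the one-pass online-softmax update. This is an illustration only, not the kernel code: the function name and data layout are assumptions, and only the FA_M_INIT value comes from the description above.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr float FA_M_INIT = -3.0e38f;  // finite stand-in for -INFINITY

// One-pass attention for a single query: scores holds QK^T for each KV
// position, values is row-major [n_kv x dv].
std::vector<float> online_softmax_attention(const std::vector<float>& scores,
                                            const std::vector<float>& values,
                                            std::size_t dv) {
    assert(values.size() == scores.size() * dv);
    float m = FA_M_INIT;  // running max
    float l = 0.0f;       // running sum of exp(score - m)
    std::vector<float> acc(dv, 0.0f);
    for (std::size_t k = 0; k < scores.size(); ++k) {
        const float m_new = std::max(m, scores[k]);
        // On the first iteration m - m_new is about -3e38: a finite operand,
        // so exp() is well defined and returns 0.
        const float alpha = std::exp(m - m_new);
        const float p     = std::exp(scores[k] - m_new);
        l = l * alpha + p;
        for (std::size_t d = 0; d < dv; ++d)
            acc[d] = acc[d] * alpha + p * values[k * dv + d];
        m = m_new;
    }
    if (l > 0.0f)
        for (float& x : acc) x /= l;
    return acc;
}
```

Because every operand passed to exp() stays finite, the result no longer depends on how a driver treats infinities under -cl-finite-math-only.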
Hi @wanghqc, thanks for your contribution! Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:
Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.
Overview
This PR optimizes flash attention for the OpenCL backend, targeting the Qualcomm Adreno X1 and X2 GPUs in Windows on Snapdragon (WoS) devices. Besides the prefill optimizations, it adds Q4/Q8 KV cache quantization for token generation, which gives very good performance and memory efficiency at long context for select models, e.g., the OpenAI GPT-OSS-20B model.
Additional information
- Fused FA kernels: f16, f32, and mixed f32-query/f16-KV, each with a prefill kernel and a single-query (q1) decode kernel. Single-pass online softmax, GQA broadcast, and an N_SPLIT option that shares each query's DK/DV reduction across threads to cut per-thread register pressure at large head dims. The QK reduction uses subgroup shuffle where available and falls back to a local-memory tree otherwise.
- Native q8_0/q4_0 KV-cache kernels: read a quantised KV cache directly instead of dequantising it first, plus set_rows variants that quantize the cache on write and view-aware GPU dequant kernels for the fallback path.
- Prefill prepass: pads the tail KV/mask tiles and classifies each KV tile per query block (fully-masked / mixed / fully-unmasked) so the main kernel can skip fully-masked tiles and the mask lookup for unmasked ones.
- Flash-decoding K-split for the single-query decode path.
- Per-(dk,dv,variant) lazy kernel compile: only the FA variants a model actually uses are built, keeping Adreno host-memory footprint and startup cost down. Note that more tiers will be added later, and these may have different optimal parameters for the FA kernels.
- Tile tuning table (fa_tune.h) with a GGML_OPENCL_FA_TUNE env override for on-device tuning without a rebuild.
Hardware
Developed and tested on a commercial Adreno X2-90 device (Asus) and a commercial Adreno X1 device (HP). It should also run on other Adreno 7xx/8xx-class GPUs and other OpenCL devices, but this is yet to be confirmed.
Benchmark examples
gpt-oss-20b-MXFP4, Adreno X2-90, llama-bench:
For this model, long-context decode throughput roughly doubles; FA + quantised KV also cuts KV-cache memory enough to fit longer contexts.
Functionality tests also pass with small models that can run on Qualcomm X1 devices with 16 GB of DDR.
Testing
test-backend-ops -o FLASH_ATTN_EXT: all supported (dk,dv)/type combinations pass against the CPU reference (f16, f32, q8_0/q4_0 KV, symmetric and asymmetric, masked and with attention sinks).
Build / usage
-fa 1 enables it; the KV cache may be f16, q8_0, or q4_0. GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,...] for tuning, GGML_OPENCL_FA_LOG_SPILL=1 for a per-kernel register-spill diagnostic.
Commit structure
Nine focused commits (kernels, prepass, native quant KV, quantising set_rows, GPU dequant, tuning table, host integration) for incremental review.
Known issues
Requirements