Question: KV Cache Size in GGML_FLASH_ATTN_EXT Operator #13816
-
I've been tracing the KV cache tensor that is passed into the GGML_FLASH_ATTN_EXT operator, and its size appears to be aligned to 256 rather than growing with the decode sequence length. I have a hypothesis about this: could it be to facilitate parallel processing? If so, this approach would primarily benefit GPUs, so does it hurt performance on the CPU? Could someone clarify why the KV cache input for FLASH_ATTN_EXT appears to be aligned to 256 instead of growing with the decode sequence length?
-
The KV cache is padded here: llama.cpp/src/llama-kv-cache.cpp, lines 19 to 24 in 4f81b33. The main reason is to be able to write more efficient GPU kernels that don't have to worry about out-of-bounds access. You are correct that this is technically an overhead for the CPU, since it currently uses a scalar implementation. But in the future, the FA CPU implementation could be vectorized and would also benefit from this padding. This padding is also a necessary step for reusing compute graphs during generation.
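For illustration, here is a minimal C++ sketch of the padding idea (the names `kv_pad` and `n_kv_used` are made up for this example and are not the actual llama.cpp identifiers): the number of KV cells presented to the attention kernel is rounded up to a multiple of 256, so a GPU kernel can iterate over fixed-size blocks without per-block bounds checks.

```cpp
#include <cstdint>
#include <cstdio>

// Round n up to the next multiple of align (illustrative helper).
static uint32_t kv_pad(uint32_t n, uint32_t align = 256) {
    return ((n + align - 1) / align) * align;
}

int main() {
    // During decoding, the number of KV cells actually in use grows by one per
    // token, but the tensor handed to the attention op stays at the padded size.
    const uint32_t samples[] = {1, 200, 256, 257, 1000};
    for (uint32_t n_kv_used : samples) {
        printf("n_kv_used = %4u -> padded n_kv = %4u\n", n_kv_used, kv_pad(n_kv_used));
    }
    return 0;
}
```

The extra padded cells are masked out during attention (e.g. with -INF in the KQ mask), so they don't change the result, only the shape the kernel sees.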
-
I generally understand the purpose of the 256 alignment, but based on the code below, isn't the graph still built multiple times on the CPU side? llama.cpp/src/llama-context.cpp, lines 977 to 985 in 1e8659e. Furthermore, does this typically introduce a significant performance overhead?
Yes, the graph is currently constructed for every ubatch and this introduces some overhead.
The padding is a necessary step to be able to reuse the graphs, but it's not everything. Currently, there is a dynamic offset `head` in the KV cache that changes for every ubatch: llama.cpp/src/llama-kv-cache.cpp, lines 542 to 554 in 05f6ac6.
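To make that concrete, here is a hedged sketch of why a per-ubatch offset blocks graph reuse; the struct and variable names are made up for illustration and are not llama.cpp's actual API. If the byte offset of the KV view is baked into the graph when the view node is created, a graph built for one value of `head` is not valid for the next ubatch, so the graph has to be rebuilt.

```cpp
#include <cstddef>
#include <cstdio>

// Illustrative stand-in for a view node: it records a fixed byte offset into
// the KV buffer at the time the graph is built.
struct kv_view {
    size_t offs_bytes;
};

int main() {
    const size_t row_size = 4096; // bytes per KV cell (illustrative value)

    size_t head = 0; // index of the first free cell in the KV cache
    for (int ubatch = 0; ubatch < 3; ++ubatch) {
        const size_t n_tokens = 32;

        // The view where this ubatch writes its K/V data; the offset is
        // frozen into the graph node here.
        kv_view v = { head * row_size };
        printf("ubatch %d: view offset = %zu bytes\n", ubatch, v.offs_bytes);

        head += n_tokens; // the dynamic offset moves -> the old graph no longer matches
    }
    return 0;
}
```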