[Feat] Support FlashMLA backend with MTP and FP8 KV cache #6109
Conversation
Force-pushed from 9550a0a to f7227c9
Hi @quinnrong94, can you take a look at this CI failure?
Hi @Fridge003, I saw the FlashMLA test failed in CI; I wonder if it's due to the same reason as #5587?
@quinnrong94 Let us rerun the CI for you, no need to rebase. Thanks.
It seems to be due to a bug caused by this PR. It should be fixed by now.
For Future PRs:
Hi @quinnrong94, I have a question: can FlashMLA be used with MTP 3-1-4? I tested on H20 141G, and the benchmark reported a memory leak at the end: "Scheduler hit an exception: Traceback (most recent call last):". My sglang start parameters are:
python3 -m sglang.launch_server --model-path $R1_MODEL_PATH --tp $TP --trust-remote-code --port $PORT --host 0.0.0.0 --mem-fraction-static 0.85 --max-running-requests $max_running_requests --disable-radix-cache --attention-backend flashmla --speculative-algorithm NEXTN --speculative-draft $NextN_MODEL_PATH --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --kv-cache-dtype fp8_e4m3
@mahaocong90 MTP 3,1,4 is fixed by neiltian-tencent@0987047 |
@neiltian-tencent Does it work with flashinfer MLA? I merged it into the main branch and tested using flashmla, but it still reports a memory leak.
@mahaocong90 The bugs of MTP 3, 1, 4 have been fixed; the memory leak will be fixed later.
…t#6109) Co-authored-by: Yingyi <[email protected]> Co-authored-by: neiltian <[email protected]> Co-authored-by: lukec <[email protected]> Co-authored-by: kexueyu <[email protected]> Co-authored-by: vincentmeng <[email protected]> Co-authored-by: pengmeng <[email protected]>
Motivation
This PR improves the FlashMLA backend by accelerating the decode stage with MTP (multi-token prediction). The implementation exploits FlashMLA's ability to handle seq_len_q > 1. To use FlashMLA with MTP, example server args are:
--attention-backend flashmla --speculative-algorithm NEXTN --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2
This PR also adds FP8 KV cache support to the FlashMLA backend, which halves KV cache usage and enables higher concurrency with longer input sequences when memory is limited. To enable the FP8 KV cache, add one additional server arg:
--kv-cache-dtype fp8_e4m3
The combined speedup of MTP + FP8 KV cache is about 30%, with KV cache usage reduced by 50%.
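Putting the two features together, a full launch command might look like the sketch below. The model paths, TP size, and port are placeholders (MODEL_PATH, DRAFT_PATH, TP, and PORT are not from this PR); only the attention-backend, speculative, and kv-cache-dtype flags come from the description above.

```shell
# Sketch of a combined MTP + FP8 KV cache launch, assuming a DeepSeek-style
# target model plus a NextN draft model; adjust paths and TP to your setup.
python3 -m sglang.launch_server \
  --model-path "$MODEL_PATH" \
  --speculative-draft "$DRAFT_PATH" \
  --tp "$TP" \
  --trust-remote-code \
  --port "$PORT" \
  --attention-backend flashmla \
  --speculative-algorithm NEXTN \
  --speculative-num-steps 1 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 2 \
  --kv-cache-dtype fp8_e4m3
```

Note that this sketch uses the 1-1-2 speculative configuration from the PR description; per the conversation above, the 3-1-4 configuration required a follow-up fix.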
Modifications
Checklist