Commit a6f892e

Revert "Avoid computing lse in Ragged Prefill when there's no prefix.… (#5544)
1 parent: 08b518d

File tree

3 files changed: +12 additions, -19 deletions

docs/backend/server_arguments.md

Lines changed: 1 addition & 1 deletion

@@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
-* `flashinfer_mla_disable_ragged`: Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend.
+* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
 * `disable_chunked_prefix_cache`: Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend.
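For context on how the documented flag is typically set, here is a minimal sketch, assuming sglang's usual convention that each documented server argument is a field of its `ServerArgs` dataclass; the model path and backend value are illustrative placeholders, not part of this commit.

```python
# Minimal sketch (assumption: `flashinfer_mla_disable_ragged` is a ServerArgs
# field, as the documentation entry above implies). The model path is only a
# placeholder for an MLA model.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
    attention_backend="flashinfer",             # use FlashInfer as the MLA backend
    flashinfer_mla_disable_ragged=True,         # skip the ragged prefill wrapper
)
```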

python/sglang/srt/layers/attention/flashinfer_backend.py

Lines changed: 10 additions & 17 deletions

@@ -428,25 +428,18 @@ def forward_extend(
                 v_scale=v_scale,
             )
         else:
+            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
             if self.forward_metadata.extend_no_prefix:
-                o = prefill_wrapper_paged.forward(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=not layer.is_cross_attention,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                    k_scale=k_scale,
-                    v_scale=v_scale,
-                )
+                o = o1
             else:
-                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=logits_soft_cap,
-                )
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
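The lse returned by `forward_return_lse` matters in the prefix-hit branch because the ragged result (new tokens) and the paged result (cached prefix) are partial softmax outputs over disjoint key sets; they can only be combined if each carries its log-sum-exp. Below is a hedged, generic sketch of that merge in plain PyTorch, not sglang's actual code (which presumably consumes `s1`/`s2` just past the end of this hunk, e.g. via FlashInfer's state-merging utilities); it also assumes natural-log lse, which may differ from the library's convention.

```python
import torch

def merge_attention_states(o1: torch.Tensor, s1: torch.Tensor,
                           o2: torch.Tensor, s2: torch.Tensor) -> torch.Tensor:
    """Combine two partial attention results computed over disjoint key sets.

    o1, o2: [num_tokens, num_heads, head_dim] partial attention outputs
    s1, s2: [num_tokens, num_heads] log-sum-exp of the corresponding logits
    """
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)   # unnormalized softmax mass of key set 1
    w2 = torch.exp(s2 - s_max)   # unnormalized softmax mass of key set 2
    denom = (w1 + w2).unsqueeze(-1)
    return (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / denom
```

When `extend_no_prefix` is true there is nothing to merge with, so the reverted code simply takes `o = o1` and the lse is computed but unused, which is exactly the cost the original (now reverted) commit tried to avoid.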

python/sglang/srt/layers/attention/flashinfer_mla_backend.py

Lines changed: 1 addition & 1 deletion

@@ -348,7 +348,7 @@ def forward_extend(
 
         if self.forward_metadata.use_ragged:
             # ragged prefill
-            o = self.prefill_wrapper_ragged.forward(
+            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
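In this MLA-backend hunk the ragged result is not merged with anything, so the lse that `forward_return_lse` now returns is immediately discarded. A minimal sketch of the two call shapes, using only the wrapper methods and keyword arguments already visible in the diffs above (argument values are placeholders):

```python
# Pre-revert call shape: output only.
o = self.prefill_wrapper_ragged.forward(qall, k, v, causal=True, sm_scale=layer.scaling)

# Post-revert call shape: output plus log-sum-exp; here the lse is computed
# and then thrown away, the extra work the reverted commit had skipped.
o, _ = self.prefill_wrapper_ragged.forward_return_lse(
    qall, k, v, causal=True, sm_scale=layer.scaling
)
```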
