Commit 25e3392

lifuhuang authored and tarinkk committed
FA3 speed up: skip len operation and get batch size directly from forward batch (sgl-project#5969)
Signed-off-by: Lifu Huang <[email protected]>
1 parent 1d0edc2 commit 25e3392

1 file changed: +1 -1 lines changed
python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 1 addition & 1 deletion
@@ -342,7 +342,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
-        batch_size = len(seqlens_in_batch)
+        batch_size = forward_batch.batch_size
         device = seqlens_in_batch.device
 
         if forward_batch.forward_mode.is_decode_or_idle():
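
The change above swaps a computed length for a stored field. A minimal sketch of that idea follows, assuming the batch object keeps `batch_size` consistent with the number of entries in `seq_lens`. `ToyForwardBatch` and both helper functions are hypothetical stand-ins written for illustration, not the real `ForwardBatch` class, and a plain list is used in place of the tensor. The sketch only shows that the two reads are interchangeable under that assumption; it does not measure speed.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyForwardBatch:
    """Hypothetical stand-in for ForwardBatch with only the two fields used here."""

    batch_size: int
    seq_lens: List[int]  # a tensor in the real code; a plain list in this sketch


def batch_size_via_len(batch: ToyForwardBatch) -> int:
    # Old approach: derive the batch size from the sequence-length container.
    return len(batch.seq_lens)


def batch_size_via_field(batch: ToyForwardBatch) -> int:
    # New approach: read the value already stored on the batch object.
    return batch.batch_size


batch = ToyForwardBatch(batch_size=3, seq_lens=[5, 9, 2])

# Both reads agree as long as the producer keeps the field in sync.
assert batch_size_via_len(batch) == batch_size_via_field(batch)
```

The replacement is only safe if whatever builds the batch always sets the field to match the sequence-length container, which this sketch assumes and does not verify against the real code.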