
Commit 31589e1

fzyzcjy and ch-wan authored
Speed up when having padding tokens two-batch overlap (#6668)
Co-authored-by: Cheng Wan <[email protected]>
1 parent ae6a5b2 commit 31589e1

File tree

2 files changed, +71 -12 lines changed


python/sglang/srt/models/deepseek_v2.py

Lines changed: 1 addition & 0 deletions
@@ -454,6 +454,7 @@ def op_select_experts(self, state):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=state.forward_batch.num_token_non_padded,
             expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                 layer_id=self.layer_id,
             ),
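Note: the only model-side change is to forward the batch's num_token_non_padded into expert selection. As a purely illustrative sketch (not the actual select_experts implementation), this is one way a MoE router can use that count: mark the top-k expert ids of CUDA-graph padding rows as invalid so the dispatch stage does no work for them. mask_padded_topk and invalid_id are made-up names.

import torch

def mask_padded_topk(topk_ids: torch.Tensor, num_token_non_padded: int, invalid_id: int = -1):
    # topk_ids: [num_tokens, top_k] expert ids for every (possibly padded) token slot.
    # Rows at index >= num_token_non_padded are padding; point them at an invalid
    # expert id so a dispatch/combine stage can skip them.
    token_index = torch.arange(topk_ids.shape[0], device=topk_ids.device)
    is_padding = (token_index >= num_token_non_padded).unsqueeze(-1)
    return torch.where(is_padding, torch.full_like(topk_ids, invalid_id), topk_ids)

# Example: 6 CUDA-graph token slots, only 4 real tokens.
topk_ids = torch.randint(0, 8, (6, 2))
print(mask_padded_topk(topk_ids, num_token_non_padded=4))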

python/sglang/srt/two_batch_overlap.py

Lines changed: 70 additions & 12 deletions
@@ -110,7 +110,7 @@ def compute_split_indices_for_cuda_graph_replay(

 class TboCudaGraphRunnerPlugin:
     def __init__(self):
-        pass  # TODO add logic here
+        self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)

     def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
         if not global_server_args_dict["enable_two_batch_overlap"]:
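The new two-element counter is allocated once in __init__ and later refreshed with an in-place write (buf[...] = ...) rather than rebound, presumably so the child batches built at capture time keep referencing the same storage and replay only has to refresh its contents. A minimal sketch of that buffer pattern (assumed rationale, not the runner itself):

import torch

buf = torch.zeros((2,), dtype=torch.int32)             # allocated once, as in __init__
view_a, view_b = buf[0], buf[1]                         # views handed out at capture time

buf[...] = torch.tensor([80, 70], dtype=torch.int32)    # in-place refresh at replay time
print(int(view_a), int(view_b))                          # -> 80 70; the old views see the update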
@@ -124,15 +124,35 @@ def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
         # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
         assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"

-        TboForwardBatchPreparer.prepare(batch)
+        self._tbo_children_num_token_non_padded[...] = (
+            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch)
+        )
+
+        TboForwardBatchPreparer.prepare_raw(
+            batch,
+            tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded,
+        )

     def replay_prepare(
         self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
     ):
         if not global_server_args_dict["enable_two_batch_overlap"]:
             return

-        pass  # TODO add logic here
+        tbo_split_seq_index, tbo_split_token_index = (
+            compute_split_indices_for_cuda_graph_replay(
+                forward_mode=forward_mode,
+                # TODO support bs!=num_tokens
+                cuda_graph_num_tokens=bs,
+            )
+        )
+
+        self._tbo_children_num_token_non_padded[...] = (
+            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(
+                tbo_split_token_index=tbo_split_token_index,
+                num_token_non_padded=num_token_non_padded,
+            )
+        )


 class TboDPAttentionPreparer:
@@ -207,16 +227,23 @@ def _is_all_same(x):
 class TboForwardBatchPreparer:
     @classmethod
     def prepare(cls, batch: ForwardBatch):
-        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
-
         if batch.tbo_split_seq_index is None:
             return

-        tbo_split_token_index = compute_split_token_index(
-            split_seq_index=batch.tbo_split_seq_index,
-            forward_mode=batch.forward_mode,
-            extend_seq_lens=batch.extend_seq_lens_cpu,
+        tbo_children_num_token_non_padded = (
+            cls.compute_tbo_children_num_token_non_padded(batch)
         )
+        cls.prepare_raw(
+            batch, tbo_children_num_token_non_padded=tbo_children_num_token_non_padded
+        )
+
+    @classmethod
+    def prepare_raw(
+        cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor
+    ):
+        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
+
+        tbo_split_token_index = cls._compute_split_token_index(batch)

         if _tbo_debug:
             logger.info(
@@ -229,13 +256,18 @@ def prepare(cls, batch: ForwardBatch):
         assert isinstance(batch.attn_backend, TboAttnBackend)
         attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children

+        [out_num_token_non_padded_a, out_num_token_non_padded_b] = (
+            tbo_children_num_token_non_padded
+        )
+
         child_a = cls.filter_batch(
             batch,
             start_token_index=0,
             end_token_index=tbo_split_token_index,
             start_seq_index=0,
             end_seq_index=batch.tbo_split_seq_index,
             output_attn_backend=attn_backend_child_a,
+            out_num_token_non_padded=out_num_token_non_padded_a,
         )
         child_b = cls.filter_batch(
             batch,
@@ -244,6 +276,7 @@ def prepare(cls, batch: ForwardBatch):
             start_seq_index=batch.tbo_split_seq_index,
             end_seq_index=batch.batch_size,
             output_attn_backend=attn_backend_child_b,
+            out_num_token_non_padded=out_num_token_non_padded_b,
         )

         assert batch.tbo_children is None
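For orientation, a rough sketch of how the split indices relate to the two filter_batch calls above: child A takes tokens [0, tbo_split_token_index) and sequences [0, tbo_split_seq_index), child B takes the rest. The prefix-sum rule below is an assumption about extend/prefill mode, and the names are illustrative rather than the sglang helpers.

extend_seq_lens = [3, 5, 4]        # three sequences, 12 tokens total
tbo_split_seq_index = 2            # child A gets the first two sequences

# Assumption: for extend mode, the token split is the prefix sum of extend lengths.
tbo_split_token_index = sum(extend_seq_lens[:tbo_split_seq_index])  # -> 8

child_a = dict(token_range=(0, tbo_split_token_index), seq_range=(0, tbo_split_seq_index))
child_b = dict(
    token_range=(tbo_split_token_index, sum(extend_seq_lens)),
    seq_range=(tbo_split_seq_index, len(extend_seq_lens)),
)
print(child_a, child_b)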
@@ -259,9 +292,8 @@ def filter_batch(
         start_seq_index: int,
         end_seq_index: int,
         output_attn_backend: AttentionBackend,
+        out_num_token_non_padded: torch.Tensor,
     ):
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
-
         num_tokens = batch.input_ids.shape[0]
         num_seqs = batch.batch_size

@@ -342,6 +374,7 @@ def filter_batch(
                 ),
                 extend_num_tokens=extend_num_tokens,
                 attn_backend=output_attn_backend,
+                num_token_non_padded=out_num_token_non_padded,
                 tbo_split_seq_index=None,
                 tbo_parent_token_range=(start_token_index, end_token_index),
                 tbo_children=None,
@@ -357,7 +390,6 @@ def filter_batch(
                 top_p_normalized_logprobs=False,
                 top_p=None,
                 mm_inputs=None,
-                num_token_non_padded=None,
             )
         )

@@ -372,6 +404,32 @@ def filter_batch(

         return ForwardBatch(**output_dict)

+    @classmethod
+    def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
+        return cls.compute_tbo_children_num_token_non_padded_raw(
+            tbo_split_token_index=cls._compute_split_token_index(batch),
+            num_token_non_padded=len(batch.input_ids),
+        )
+
+    @classmethod
+    def compute_tbo_children_num_token_non_padded_raw(
+        cls, tbo_split_token_index: int, num_token_non_padded: int
+    ):
+        # TODO we may make padding on both sub-batches to make it slightly more balanced
+        value_a = min(tbo_split_token_index, num_token_non_padded)
+        value_b = max(0, num_token_non_padded - tbo_split_token_index)
+        return torch.tensor([value_a, value_b], dtype=torch.int32).to(
+            device=global_server_args_dict["device"], non_blocking=True
+        )
+
+    @classmethod
+    def _compute_split_token_index(cls, batch: ForwardBatch):
+        return compute_split_token_index(
+            split_seq_index=batch.tbo_split_seq_index,
+            forward_mode=batch.forward_mode,
+            extend_seq_lens=batch.extend_seq_lens_cpu,
+        )
+

 def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
     if forward_mode.is_extend():
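A worked example of the new count-splitting arithmetic, using the same min/max formulas as compute_tbo_children_num_token_non_padded_raw (plain ints here instead of the int32 tensor the method returns; split_non_padded is an illustrative name):

def split_non_padded(tbo_split_token_index: int, num_token_non_padded: int):
    # Child A keeps every real token before the split; child B keeps the remainder.
    value_a = min(tbo_split_token_index, num_token_non_padded)
    value_b = max(0, num_token_non_padded - tbo_split_token_index)
    return value_a, value_b

# CUDA graph padded to 8 token slots, 5 real tokens, split at token 4:
print(split_non_padded(4, 5))   # -> (4, 1)
# If only 3 tokens are real, all padding falls into child B:
print(split_non_padded(4, 3))   # -> (3, 0)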

0 commit comments
