@@ -110,7 +110,7 @@ def compute_split_indices_for_cuda_graph_replay(
110
110
111
111
class TboCudaGraphRunnerPlugin :
112
112
def __init__ (self ):
113
- pass # TODO add logic here
113
+ self . _tbo_children_num_token_non_padded = torch . zeros (( 2 ,), dtype = torch . int32 )
114
114
115
115
def capture_one_batch_size (self , batch : ForwardBatch , num_tokens : int ):
116
116
if not global_server_args_dict ["enable_two_batch_overlap" ]:
@@ -124,15 +124,35 @@ def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
124
124
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
125
125
assert batch .tbo_split_seq_index is not None , f"{ num_tokens = } "
126
126
127
- TboForwardBatchPreparer .prepare (batch )
127
+ self ._tbo_children_num_token_non_padded [...] = (
128
+ TboForwardBatchPreparer .compute_tbo_children_num_token_non_padded (batch )
129
+ )
130
+
131
+ TboForwardBatchPreparer .prepare_raw (
132
+ batch ,
133
+ tbo_children_num_token_non_padded = self ._tbo_children_num_token_non_padded ,
134
+ )
128
135
129
136
def replay_prepare(
    self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
):
    """Refresh the per-child non-padded token counts before a CUDA-graph replay.

    No-op unless two-batch overlap is enabled in the server args.
    """
    if not global_server_args_dict["enable_two_batch_overlap"]:
        return

    # TODO support bs != num_tokens
    _split_seq_index, split_token_index = compute_split_indices_for_cuda_graph_replay(
        forward_mode=forward_mode,
        cuda_graph_num_tokens=bs,
    )

    counts = TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(
        tbo_split_token_index=split_token_index,
        num_token_non_padded=num_token_non_padded,
    )
    # Write in-place ([...] assignment) so the tensor captured by the graph
    # keeps its storage and the replay observes the new values.
    self._tbo_children_num_token_non_padded[...] = counts
136
156
137
157
138
158
class TboDPAttentionPreparer :
@@ -207,16 +227,23 @@ def _is_all_same(x):
207
227
class TboForwardBatchPreparer :
208
228
@classmethod
def prepare(cls, batch: ForwardBatch):
    """Split *batch* into its two TBO children.

    Does nothing when ``batch.tbo_split_seq_index`` is unset (no split requested).
    """
    if batch.tbo_split_seq_index is None:
        return

    cls.prepare_raw(
        batch,
        tbo_children_num_token_non_padded=(
            cls.compute_tbo_children_num_token_non_padded(batch)
        ),
    )
239
+
240
+ @classmethod
241
+ def prepare_raw (
242
+ cls , batch : ForwardBatch , tbo_children_num_token_non_padded : torch .Tensor
243
+ ):
244
+ from sglang .srt .layers .attention .tbo_backend import TboAttnBackend
245
+
246
+ tbo_split_token_index = cls ._compute_split_token_index (batch )
220
247
221
248
if _tbo_debug :
222
249
logger .info (
@@ -229,13 +256,18 @@ def prepare(cls, batch: ForwardBatch):
229
256
assert isinstance (batch .attn_backend , TboAttnBackend )
230
257
attn_backend_child_a , attn_backend_child_b = batch .attn_backend .children
231
258
259
+ [out_num_token_non_padded_a , out_num_token_non_padded_b ] = (
260
+ tbo_children_num_token_non_padded
261
+ )
262
+
232
263
child_a = cls .filter_batch (
233
264
batch ,
234
265
start_token_index = 0 ,
235
266
end_token_index = tbo_split_token_index ,
236
267
start_seq_index = 0 ,
237
268
end_seq_index = batch .tbo_split_seq_index ,
238
269
output_attn_backend = attn_backend_child_a ,
270
+ out_num_token_non_padded = out_num_token_non_padded_a ,
239
271
)
240
272
child_b = cls .filter_batch (
241
273
batch ,
@@ -244,6 +276,7 @@ def prepare(cls, batch: ForwardBatch):
244
276
start_seq_index = batch .tbo_split_seq_index ,
245
277
end_seq_index = batch .batch_size ,
246
278
output_attn_backend = attn_backend_child_b ,
279
+ out_num_token_non_padded = out_num_token_non_padded_b ,
247
280
)
248
281
249
282
assert batch .tbo_children is None
@@ -259,9 +292,8 @@ def filter_batch(
259
292
start_seq_index : int ,
260
293
end_seq_index : int ,
261
294
output_attn_backend : AttentionBackend ,
295
+ out_num_token_non_padded : torch .Tensor ,
262
296
):
263
- from sglang .srt .managers .schedule_batch import global_server_args_dict
264
-
265
297
num_tokens = batch .input_ids .shape [0 ]
266
298
num_seqs = batch .batch_size
267
299
@@ -342,6 +374,7 @@ def filter_batch(
342
374
),
343
375
extend_num_tokens = extend_num_tokens ,
344
376
attn_backend = output_attn_backend ,
377
+ num_token_non_padded = out_num_token_non_padded ,
345
378
tbo_split_seq_index = None ,
346
379
tbo_parent_token_range = (start_token_index , end_token_index ),
347
380
tbo_children = None ,
@@ -357,7 +390,6 @@ def filter_batch(
357
390
top_p_normalized_logprobs = False ,
358
391
top_p = None ,
359
392
mm_inputs = None ,
360
- num_token_non_padded = None ,
361
393
)
362
394
)
363
395
@@ -372,6 +404,32 @@ def filter_batch(
372
404
373
405
return ForwardBatch (** output_dict )
374
406
407
@classmethod
def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
    """Non-padded token counts of the two TBO children derived from *batch*."""
    split_index = cls._compute_split_token_index(batch)
    total_tokens = len(batch.input_ids)
    return cls.compute_tbo_children_num_token_non_padded_raw(
        tbo_split_token_index=split_index,
        num_token_non_padded=total_tokens,
    )
413
+
414
@classmethod
def compute_tbo_children_num_token_non_padded_raw(
    cls, tbo_split_token_index: int, num_token_non_padded: int
):
    """Split *num_token_non_padded* tokens at *tbo_split_token_index*.

    Returns an int32 tensor ``[count_a, count_b]`` on the configured device
    (copied with ``non_blocking=True``). Each count is clamped so that child A
    never exceeds the split index and child B is never negative.
    """
    # TODO we may make padding on both sub-batches to make it slightly more balanced
    count_a = min(tbo_split_token_index, num_token_non_padded)
    count_b = num_token_non_padded - tbo_split_token_index
    if count_b < 0:
        count_b = 0
    cpu_counts = torch.tensor([count_a, count_b], dtype=torch.int32)
    return cpu_counts.to(device=global_server_args_dict["device"], non_blocking=True)
424
+
425
@classmethod
def _compute_split_token_index(cls, batch: ForwardBatch):
    """Token index at which *batch* is split, via the module-level helper."""
    return compute_split_token_index(
        forward_mode=batch.forward_mode,
        split_seq_index=batch.tbo_split_seq_index,
        extend_seq_lens=batch.extend_seq_lens_cpu,
    )
432
+
375
433
376
434
def _compute_extend_num_tokens (input_ids , forward_mode : ForwardMode ):
377
435
if forward_mode .is_extend ():
0 commit comments