@@ -444,6 +444,15 @@ def process_prebuilt_extend(
444
444
445
445
class SchedulerDisaggregationDecodeMixin :
446
446
447
+ def _prepare_idle_batch_and_run (self , batch , delay_process = False ):
448
+ batch , _ = self .prepare_dp_attn_batch (batch )
449
+ result = None
450
+ if batch :
451
+ result = self .run_batch (batch )
452
+ if not delay_process :
453
+ self .process_batch_result (batch , result )
454
+ return batch , result
455
+
447
456
@torch .no_grad ()
448
457
def event_loop_normal_disagg_decode (self ):
449
458
"""A normal scheduler loop for decode worker in disaggregation mode."""
@@ -456,14 +465,25 @@ def event_loop_normal_disagg_decode(self):
456
465
batch = self .get_next_disagg_decode_batch_to_run ()
457
466
self .cur_batch = batch
458
467
468
+ prepare_dp_attn_flag = (
469
+ self .server_args .enable_dp_attention
470
+ or self .server_args .enable_sp_layernorm
471
+ )
472
+
459
473
if batch :
460
474
# Generate fake extend output.
461
475
if batch .forward_mode .is_extend ():
462
476
# Note: Logprobs should be handled on the prefill engine.
463
477
self .stream_output (batch .reqs , False )
478
+ if prepare_dp_attn_flag :
479
+ self ._prepare_idle_batch_and_run (None )
464
480
else :
481
+ if prepare_dp_attn_flag :
482
+ self .prepare_dp_attn_batch (batch )
465
483
result = self .run_batch (batch )
466
484
self .process_batch_result (batch , result )
485
+ elif prepare_dp_attn_flag :
486
+ batch , _ = self ._prepare_idle_batch_and_run (None )
467
487
468
488
if batch is None and (
469
489
len (self .disagg_decode_transfer_queue .queue )
@@ -480,7 +500,7 @@ def event_loop_normal_disagg_decode(self):
480
500
def event_loop_overlap_disagg_decode (self ):
481
501
result_queue = deque ()
482
502
self .last_batch : Optional [ScheduleBatch ] = None
483
- self .last_batch_is_extend = False # last batch is modifed in-place, so we need another variable to track if it's extend
503
+ self .last_batch_in_queue = False # last batch is modifed in-place, so we need another variable to track if it's extend
484
504
485
505
while True :
486
506
recv_reqs = self .recv_requests ()
@@ -489,20 +509,41 @@ def event_loop_overlap_disagg_decode(self):
489
509
self .process_decode_queue ()
490
510
batch = self .get_next_disagg_decode_batch_to_run ()
491
511
self .cur_batch = batch
492
- last_batch_is_extend = False
512
+ last_batch_in_queue = False
513
+
514
+ prepare_dp_attn_flag = (
515
+ self .server_args .enable_dp_attention
516
+ or self .server_args .enable_sp_layernorm
517
+ )
493
518
494
519
if batch :
495
520
# Generate fake extend output.
496
521
if batch .forward_mode .is_extend ():
497
522
# Note: Logprobs should be handled on the prefill engine.
498
523
self .stream_output (batch .reqs , False )
499
- last_batch_is_extend = True
524
+ if prepare_dp_attn_flag :
525
+ batch_ , result = self ._prepare_idle_batch_and_run (
526
+ None , delay_process = True
527
+ )
528
+ if batch_ :
529
+ result_queue .append ((batch_ .copy (), result ))
530
+ last_batch_in_queue = True
500
531
else :
532
+ if prepare_dp_attn_flag :
533
+ self .prepare_dp_attn_batch (batch )
501
534
result = self .run_batch (batch )
502
535
result_queue .append ((batch .copy (), result ))
536
+ last_batch_in_queue = True
537
+ elif prepare_dp_attn_flag :
538
+ batch , result = self ._prepare_idle_batch_and_run (
539
+ None , delay_process = True
540
+ )
541
+ if batch :
542
+ result_queue .append ((batch .copy (), result ))
543
+ last_batch_in_queue = True
503
544
504
545
# Process the results of the previous batch but skip if the last batch is extend
505
- if self .last_batch and not self .last_batch_is_extend :
546
+ if self .last_batch and self .last_batch_in_queue :
506
547
tmp_batch , tmp_result = result_queue .popleft ()
507
548
self .process_batch_result (tmp_batch , tmp_result )
508
549
@@ -516,7 +557,7 @@ def event_loop_overlap_disagg_decode(self):
516
557
self .new_token_ratio = self .init_new_token_ratio
517
558
518
559
self .last_batch = batch
519
- self .last_batch_is_extend = last_batch_is_extend
560
+ self .last_batch_in_queue = last_batch_in_queue
520
561
521
562
def get_next_disagg_decode_batch_to_run (
522
563
self : Scheduler ,
0 commit comments