Skip to content

Commit 711efe7

Browse files
ch-wan and ByronHsu authored
Integrating PD disaggregation with DP attention and DeepEP (#5435)
Co-authored-by: Byron Hsu <[email protected]>
1 parent fbb5f22 commit 711efe7

File tree

3 files changed

+72
-8
lines changed

3 files changed

+72
-8
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,15 @@ def process_prebuilt_extend(
444444

445445
class SchedulerDisaggregationDecodeMixin:
446446

447+
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
448+
batch, _ = self.prepare_dp_attn_batch(batch)
449+
result = None
450+
if batch:
451+
result = self.run_batch(batch)
452+
if not delay_process:
453+
self.process_batch_result(batch, result)
454+
return batch, result
455+
447456
@torch.no_grad()
448457
def event_loop_normal_disagg_decode(self):
449458
"""A normal scheduler loop for decode worker in disaggregation mode."""
@@ -456,14 +465,25 @@ def event_loop_normal_disagg_decode(self):
456465
batch = self.get_next_disagg_decode_batch_to_run()
457466
self.cur_batch = batch
458467

468+
prepare_dp_attn_flag = (
469+
self.server_args.enable_dp_attention
470+
or self.server_args.enable_sp_layernorm
471+
)
472+
459473
if batch:
460474
# Generate fake extend output.
461475
if batch.forward_mode.is_extend():
462476
# Note: Logprobs should be handled on the prefill engine.
463477
self.stream_output(batch.reqs, False)
478+
if prepare_dp_attn_flag:
479+
self._prepare_idle_batch_and_run(None)
464480
else:
481+
if prepare_dp_attn_flag:
482+
self.prepare_dp_attn_batch(batch)
465483
result = self.run_batch(batch)
466484
self.process_batch_result(batch, result)
485+
elif prepare_dp_attn_flag:
486+
batch, _ = self._prepare_idle_batch_and_run(None)
467487

468488
if batch is None and (
469489
len(self.disagg_decode_transfer_queue.queue)
@@ -480,7 +500,7 @@ def event_loop_normal_disagg_decode(self):
480500
def event_loop_overlap_disagg_decode(self):
481501
result_queue = deque()
482502
self.last_batch: Optional[ScheduleBatch] = None
483-
self.last_batch_is_extend = False # last batch is modified in-place, so we need another variable to track if it's extend
503+
self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend
484504

485505
while True:
486506
recv_reqs = self.recv_requests()
@@ -489,20 +509,41 @@ def event_loop_overlap_disagg_decode(self):
489509
self.process_decode_queue()
490510
batch = self.get_next_disagg_decode_batch_to_run()
491511
self.cur_batch = batch
492-
last_batch_is_extend = False
512+
last_batch_in_queue = False
513+
514+
prepare_dp_attn_flag = (
515+
self.server_args.enable_dp_attention
516+
or self.server_args.enable_sp_layernorm
517+
)
493518

494519
if batch:
495520
# Generate fake extend output.
496521
if batch.forward_mode.is_extend():
497522
# Note: Logprobs should be handled on the prefill engine.
498523
self.stream_output(batch.reqs, False)
499-
last_batch_is_extend = True
524+
if prepare_dp_attn_flag:
525+
batch_, result = self._prepare_idle_batch_and_run(
526+
None, delay_process=True
527+
)
528+
if batch_:
529+
result_queue.append((batch_.copy(), result))
530+
last_batch_in_queue = True
500531
else:
532+
if prepare_dp_attn_flag:
533+
self.prepare_dp_attn_batch(batch)
501534
result = self.run_batch(batch)
502535
result_queue.append((batch.copy(), result))
536+
last_batch_in_queue = True
537+
elif prepare_dp_attn_flag:
538+
batch, result = self._prepare_idle_batch_and_run(
539+
None, delay_process=True
540+
)
541+
if batch:
542+
result_queue.append((batch.copy(), result))
543+
last_batch_in_queue = True
503544

504545
# Process the results of the previous batch but skip if the last batch is extend
505-
if self.last_batch and not self.last_batch_is_extend:
546+
if self.last_batch and self.last_batch_in_queue:
506547
tmp_batch, tmp_result = result_queue.popleft()
507548
self.process_batch_result(tmp_batch, tmp_result)
508549

@@ -516,7 +557,7 @@ def event_loop_overlap_disagg_decode(self):
516557
self.new_token_ratio = self.init_new_token_ratio
517558

518559
self.last_batch = batch
519-
self.last_batch_is_extend = last_batch_is_extend
560+
self.last_batch_in_queue = last_batch_in_queue
520561

521562
def get_next_disagg_decode_batch_to_run(
522563
self: Scheduler,

python/sglang/srt/disaggregation/prefill.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,14 @@ def event_loop_normal_disagg_prefill(self):
187187
)
188188
self.process_prefill_chunk()
189189
batch = self.get_new_batch_prefill()
190+
191+
# Handle DP attention
192+
if (
193+
self.server_args.enable_dp_attention
194+
or self.server_args.enable_sp_layernorm
195+
):
196+
batch, _ = self.prepare_dp_attn_batch(batch)
197+
190198
self.cur_batch = batch
191199

192200
if batch:
@@ -217,6 +225,14 @@ def event_loop_overlap_disagg_prefill(self):
217225
)
218226
self.process_prefill_chunk()
219227
batch = self.get_new_batch_prefill()
228+
229+
# Handle DP attention
230+
if (
231+
self.server_args.enable_dp_attention
232+
or self.server_args.enable_sp_layernorm
233+
):
234+
batch, _ = self.prepare_dp_attn_batch(batch)
235+
220236
self.cur_batch = batch
221237

222238
if batch:

python/sglang/srt/managers/data_parallel_controller.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
import setproctitle
2424
import zmq
2525

26+
from sglang.srt.disaggregation.utils import DisaggregationMode
2627
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
2728
from sglang.srt.managers.io_struct import (
2829
TokenizedEmbeddingReqInput,
2930
TokenizedGenerateReqInput,
3031
)
32+
from sglang.srt.managers.schedule_batch import Req
3133
from sglang.srt.managers.scheduler import run_scheduler_process
3234
from sglang.srt.server_args import PortArgs, ServerArgs
3335
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -226,9 +228,14 @@ def launch_tensor_parallel_group(
226228
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
227229
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
228230

229-
def round_robin_scheduler(self, req):
230-
self.workers[self.round_robin_counter].send_pyobj(req)
231-
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
231+
def round_robin_scheduler(self, req: Req):
232+
if self.server_args.disaggregation_mode == "null":
233+
self.workers[self.round_robin_counter].send_pyobj(req)
234+
self.round_robin_counter = (self.round_robin_counter + 1) % len(
235+
self.workers
236+
)
237+
else:
238+
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
232239

233240
def shortest_queue_scheduler(self, input_requests):
234241
raise NotImplementedError()

0 commit comments

Comments (0)