[PD] Support decode overlap schedule #5608

Merged: 1 commit, Apr 21, 2025
43 changes: 43 additions & 0 deletions python/sglang/srt/disaggregation/decode.py
@@ -21,6 +21,7 @@
 from __future__ import annotations
 
 import logging
+from collections import deque
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
@@ -475,6 +476,48 @@ def event_loop_normal_disagg_decode(self):
 
         self.last_batch = batch
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_decode(self):
+        result_queue = deque()
+        self.last_batch: Optional[ScheduleBatch] = None
+        self.last_batch_is_extend = False  # the last batch is modified in-place, so we need a separate flag to track whether it was extend
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # Poll the transfer queue and allocate KV cache for arrived requests.
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+            last_batch_is_extend = False
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(batch.reqs, False)
+                    last_batch_is_extend = True
+                else:
+                    result = self.run_batch(batch)
+                    result_queue.append((batch.copy(), result))
+
+            # Process the results of the previous batch, but skip if the last batch was extend.
+            if self.last_batch and not self.last_batch_is_extend:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, run self-checks and re-initialize some state.
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            self.last_batch_is_extend = last_batch_is_extend
+
     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
     ) -> Optional[Tuple[ScheduleBatch, bool]]:
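For readers skimming the diff: the new loop applies one-iteration pipelining. A launched batch's result is enqueued and only consumed on the next pass, so GPU execution of batch N overlaps CPU post-processing (result handling, streaming) of batch N-1. A minimal, self-contained sketch of the pattern, with illustrative helper names that are not the Scheduler's real API:

    from collections import deque

    def overlap_loop(get_next_batch, run_batch, process_result):
        # One-batch-delayed result handling: consume batch N-1's result
        # while batch N is running on the GPU.
        result_queue = deque()
        last_batch = None
        while True:
            batch = get_next_batch()  # may be None when idle
            if batch is not None:
                # run_batch launches GPU work; the (batch, result) pair
                # is consumed on the next iteration.
                result_queue.append((batch, run_batch(batch)))
            if last_batch is not None:
                process_result(*result_queue.popleft())
            last_batch = batch

The extend branch is the decode-specific twist: a batch whose KV cache just arrived from the prefill engine only needs fake extend output streamed immediately, so nothing is enqueued for it, and the last_batch_is_extend flag keeps the next iteration from popping a result that was never pushed.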
5 changes: 4 additions & 1 deletion python/sglang/srt/managers/scheduler.py
@@ -2016,7 +2016,10 @@ def run_scheduler_process(
         elif disaggregation_mode == DisaggregationMode.PREFILL:
             scheduler.event_loop_normal_disagg_prefill()
         elif disaggregation_mode == DisaggregationMode.DECODE:
-            scheduler.event_loop_normal_disagg_decode()
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap_disagg_decode()
+            else:
+                scheduler.event_loop_normal_disagg_decode()
 
     except Exception:
         traceback = get_exception_traceback()
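This dispatch only takes effect together with the server_args.py change below: previously, decode mode forced disable_overlap_schedule to True, so the overlap branch could never be reached on a decode server. A hedged sketch of the assumed wiring (the derivation of enable_overlap from the CLI flag is an assumption about the Scheduler constructor, not something this diff shows):

    # Assumed derivation inside the Scheduler; the real constructor may add
    # further conditions (e.g. disabling overlap for non-generation models).
    enable_overlap = not server_args.disable_overlap_schedule

With the forced disable removed below, the overlap loop becomes the default event loop for decode servers.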
6 changes: 2 additions & 4 deletions python/sglang/srt/server_args.py
@@ -387,14 +387,12 @@ def __post_init__(self):
         # PD disaggregation
         if self.disaggregation_mode == "prefill":
             self.disable_cuda_graph = True
-            logger.warning("KV cache is forced as chunk cache for decode server")
+            logger.warning("Cuda graph is disabled for prefill server")
             self.disable_overlap_schedule = True
             logger.warning("Overlap scheduler is disabled for prefill server")
         elif self.disaggregation_mode == "decode":
             self.disable_radix_cache = True
-            logger.warning("Cuda graph is disabled for prefill server")
-            self.disable_overlap_schedule = True
-            logger.warning("Overlap scheduler is disabled for decode server")
+            logger.warning("KV cache is forced as chunk cache for decode server")
 
         os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
             "1" if self.enable_torch_compile else "0"
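Net effect of this hunk: a decode server still forces chunk cache (radix cache off), but it no longer force-disables the overlap scheduler, and the two warning messages that were attached to the wrong branches are corrected. A standalone sketch of the post-PR defaults, as an illustrative function rather than sglang's real ServerArgs:

    def apply_disaggregation_defaults(mode: str, disable_overlap_schedule: bool):
        # Illustrative only; mirrors the logic in __post_init__ above.
        disable_cuda_graph = False
        disable_radix_cache = False
        if mode == "prefill":
            disable_cuda_graph = True          # CUDA graph off for prefill
            disable_overlap_schedule = True    # overlap still off for prefill
        elif mode == "decode":
            disable_radix_cache = True         # KV cache forced to chunk cache
            # overlap schedule is no longer forced off for decode
        return disable_cuda_graph, disable_radix_cache, disable_overlap_schedule

With these defaults, a decode server launched without --disable-overlap-schedule now runs event_loop_overlap_disagg_decode.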