Commit 22fa690

ByronHsu authored and tarinkk committed
[PD] Support prefill overlap + Ensure no race condition (sgl-project#5609)
1 parent: 5a7368a

File tree

5 files changed: +107 -18 lines changed

python/sglang/srt/disaggregation/prefill.py

Lines changed: 78 additions & 11 deletions
```diff
@@ -20,6 +20,7 @@
 from __future__ import annotations
 
 import logging
+from collections import deque
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -204,6 +205,40 @@ def event_loop_normal_disagg_prefill(self):
             # Otherwise, it hangs under high concurrency
             self.running_batch.batch_is_full = False
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_prefill(self):
+        self.result_queue = deque()
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = self.result_queue.popleft()
+                self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
+
+            if len(self.disagg_prefill_inflight_queue) > 0:
+                self.process_disagg_prefill_inflight_queue()
+
+            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch, which resets it.
+            # Otherwise, it hangs under high concurrency.
+            self.running_batch.batch_is_full = False
+
     def process_batch_result_disagg_prefill(
         self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
     ) -> None:
@@ -212,7 +247,26 @@ def process_batch_result_disagg_prefill(
         Adapted from process_batch_result_prefill
         """
 
-        next_token_ids = result.next_token_ids.tolist()
+        (
+            logits_output,
+            next_token_ids,
+            extend_input_len_per_req,
+            extend_logprob_start_len_per_req,
+            bid,
+        ) = (
+            result.logits_output,
+            result.next_token_ids,
+            result.extend_input_len_per_req,
+            result.extend_logprob_start_len_per_req,
+            result.bid,
+        )
+
+        # Transfer kv for prefill-completed requests and add them into the disagg_prefill_inflight_queue
+        if self.enable_overlap:
+            # wait for the batch result to be resolved
+            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+        else:
+            next_token_ids = result.next_token_ids.tolist()
 
         for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
             req: Req
@@ -226,12 +280,8 @@ def process_batch_result_disagg_prefill(
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
 
-            # TODO: Not sure if this is necessary
-            if batch.next_batch_sampling_info:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                # We need to remove this for overlap schedule.
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
+                if self.enable_overlap:
+                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
 
     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
@@ -276,20 +326,37 @@ def process_prefill_chunk(self: Scheduler) -> None:
            # only finished requests to running_batch.
            self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
-            self.send_kv_chunk(self.chunked_req)
+            if (
+                self.enable_overlap
+            ):  # Delay the KV transfer to process_batch_result_disagg_prefill when overlap is enabled, to ensure results are resolved
+                self.chunked_req.tmp_end_idx = min(
+                    len(self.chunked_req.fill_ids),
+                    len(self.chunked_req.origin_input_ids),
+                )
+            else:
+                self.send_kv_chunk(self.chunked_req)
             # chunked request keeps its rid but will get a new req_pool_idx
             self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
             self.running_batch.batch_is_full = False
 
     def send_kv_chunk(
-        self: Scheduler, req: Req, token_id: Optional[int] = None
+        self: Scheduler,
+        req: Req,
+        token_id: Optional[int] = None,
+        end_idx: Optional[int] = None,
     ) -> None:
         """
         Send a prefilled chunk to the decode server
         """
         page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
+        # If end_idx is specified, use it as the end index of the kv chunk: in the overlap
+        # schedule, the resolved length is not the same as the length of fill_ids.
+        end_idx = (
+            end_idx
+            if end_idx is not None
+            else min(len(req.fill_ids), len(req.origin_input_ids))
+        )
         last_chunk = token_id is not None
 
         if (not last_chunk) and (
@@ -302,7 +369,7 @@ def send_kv_chunk(
         req.start_send_idx = end_idx
 
         kv_indices = (
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
```
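
The core of this change is `event_loop_overlap_disagg_prefill`, which pipelines scheduling against execution: iteration N launches `run_batch` without waiting and pushes the `(batch, result)` pair onto `result_queue`; only iteration N+1 pops it and blocks on resolution (via `tp_worker.resolve_batch_result(bid)` in `process_batch_result_disagg_prefill`). Below is a toy, runnable illustration of that one-iteration-delayed queue; the executor and stub functions merely stand in for the asynchronous GPU worker and are not sglang APIs.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor
import time

executor = ThreadPoolExecutor(max_workers=1)  # stand-in for the async GPU worker

def run_batch(batch_id: int) -> str:
    time.sleep(0.01)  # pretend this is the prefill forward pass
    return f"tokens-for-batch-{batch_id}"

result_queue = deque()
last_batch = None
for step in range(4):
    batch = step if step < 3 else None  # no new work arrives on the last step
    if batch is not None:
        # Launch now, but do NOT block: the result is resolved one step later.
        result_queue.append((batch, executor.submit(run_batch, batch)))
    if last_batch is not None:
        done_batch, fut = result_queue.popleft()
        # Only here do we wait, mirroring resolve_batch_result(bid).
        print(f"step {step}: resolved {fut.result()!r} launched at step {done_batch}")
    last_batch = batch
executor.shutdown()
```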

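A smaller detail in the last hunk: the `kv_indices` lookup now uses a single fused index, `req_to_token[req.req_pool_idx, start_idx:end_idx]`, instead of the chained `[...][start_idx:end_idx]`. For a 2-D tensor the two forms select the same values; the fused form just does it in one indexing call without materializing the intermediate row view first. A toy comparison (shapes are illustrative, not sglang's real pool dimensions):

```python
import torch

req_to_token = torch.arange(12).reshape(3, 4)  # 3 requests x 4 token slots
i, a, b = 1, 1, 3

chained = req_to_token[i][a:b]  # row view first, then slice
fused = req_to_token[i, a:b]    # one indexing operation
assert torch.equal(chained, fused)
print(fused.tolist())           # [5, 6]
```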
python/sglang/srt/managers/schedule_batch.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -539,6 +539,11 @@ def __init__(
         # The first output_id transferred from prefill instance.
         self.transferred_output_id: Optional[int] = None
 
+        # For overlap scheduling, we delay the kv transfer until `process_batch_result_disagg_prefill`
+        # rather than `process_prefill_chunk` (as in the non-overlap case), because the kv is not ready in `process_prefill_chunk`.
+        # We use `tmp_end_idx` to store the end index of the kv cache to send.
+        self.tmp_end_idx: int = -1
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
```
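
`tmp_end_idx` is the handoff between the two scheduler phases: `process_prefill_chunk` records the chunk boundary at schedule time, and `process_batch_result_disagg_prefill` consumes it once the overlap result has been resolved. A toy sketch of that handoff, with a stub `Req` in place of sglang's real class:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Req:
    # Stub mirroring only the fields this commit touches, not sglang's real Req.
    origin_input_ids: List[int]
    fill_ids: List[int] = field(default_factory=list)
    tmp_end_idx: int = -1  # -1 means "no delayed chunk pending"

req = Req(origin_input_ids=list(range(10)))
req.fill_ids = req.origin_input_ids[:6]  # first chunk has been scheduled

# Schedule time (process_prefill_chunk, overlap enabled): snapshot the boundary
# instead of sending the kv chunk immediately.
req.tmp_end_idx = min(len(req.fill_ids), len(req.origin_input_ids))

# Result time (process_batch_result_disagg_prefill): the kv is now ready, so
# send the slice up to the recorded boundary.
print(f"send kv cache for tokens [0:{req.tmp_end_idx})")  # [0:6)
```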

python/sglang/srt/managers/scheduler.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -2014,7 +2014,10 @@ def run_scheduler_process(
     else:
         scheduler.event_loop_normal()
 elif disaggregation_mode == DisaggregationMode.PREFILL:
-    scheduler.event_loop_normal_disagg_prefill()
+    if scheduler.enable_overlap:
+        scheduler.event_loop_overlap_disagg_prefill()
+    else:
+        scheduler.event_loop_normal_disagg_prefill()
 elif disaggregation_mode == DisaggregationMode.DECODE:
     if scheduler.enable_overlap:
         scheduler.event_loop_overlap_disagg_decode()
```

python/sglang/srt/server_args.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -388,8 +388,6 @@ def __post_init__(self):
         if self.disaggregation_mode == "prefill":
             self.disable_cuda_graph = True
             logger.warning("Cuda graph is disabled for prefill server")
-            self.disable_overlap_schedule = True
-            logger.warning("Overlap scheduler is disabled for prefill server")
         elif self.disaggregation_mode == "decode":
             self.disable_radix_cache = True
             logger.warning("KV cache is forced as chunk cache for decode server")
```
Lines changed: 20 additions & 4 deletions
```diff
@@ -1,13 +1,29 @@
-prompt = [0] * 431
-
 import json
 
 import requests
 
+prompt = """
+According to CNBC's Faber, the investors present on the call interpreted this statement as an indication of an upcoming funding round. While speculative, Faber believes the funding round could be as large as $25 billion, and bestow a valuation of between $150 billion and $200 billion on xAI.
+
+For the benefit of those who might not be aware, xAI recently acquired the social media platform X in an all-stock deal that valued the former at $80 billion and the latter at $33 billion, inclusive of $12 billion in liabilities. This meant that the deal bestowed a gross valuation of $45 billion on X before factoring in its debt load of $12 billion.
+
+Bear in mind that Elon Musk took X (then called Twitter) private back in 2022 in a $44 billion deal. Since then, Musk has managed to stem X's cash bleed, with the company reportedly generating $1.2 billion in adjusted EBITDA in 2024.
+
+According to the investors present on the call, xAI is currently generating around $1 billion in annual revenue. This contrasts sharply with the erstwhile muted expectations of many investors, who did not expect the startup to generate any material revenue this year.
+
+Elsewhere, Faber also alludes to the fact that xAI is already working on its next big training supercluster, officially dubbed the Colossus 2, which is expected to eventually house as many as 1 million NVIDIA GPUs at a cost of between $35 billion and $40 billion.
+
+
+Even though xAI's Grok LLM is already largely comparable with OpenAI's cutting-edge models, the Colossus 2 would significantly up the ante, and could feasibly challenge OpenAI's apex position in the AI sphere.
+
+Give your honest take on the above text:
+"""
+
 response = requests.post(
     "http://0.0.0.0:8000/generate",
-    json={"input_ids": [prompt] * 32, "sampling_params": {"temperature": 0}},
+    json={"text": prompt, "sampling_params": {"temperature": 0}},
 )
 
 
-# print("Response content (raw):", response.content)
+response_json = response.json()
+print(response_json["text"])
```
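
For completeness, a slightly more defensive variant of the updated script; it assumes the same locally running server on port 8000, and the `max_new_tokens` key follows sglang's usual sampling-parameter naming:

```python
import requests

resp = requests.post(
    "http://0.0.0.0:8000/generate",
    json={
        "text": "Give your honest take: overlap scheduling hides CPU latency.",
        "sampling_params": {"temperature": 0, "max_new_tokens": 64},
    },
    timeout=120,  # generous timeout for long prefills
)
resp.raise_for_status()  # fail loudly on HTTP errors instead of printing junk
print(resp.json()["text"])
```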

0 commit comments