Skip to content

Commit deded17

Browse files
authored
[PD] Fix edge case and simplify large page size + chunked prefill (#5589)
1 parent f29a718 commit deded17

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

python/sglang/srt/disaggregation/prefill.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,16 @@ def send_kv_chunk(
287287
"""
288288
Send a prefilled chunk to the decode server
289289
"""
290+
page_size = self.token_to_kv_pool_allocator.page_size
290291
start_idx = req.start_send_idx
291292
end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
293+
last_chunk = token_id is not None
294+
295+
if (not last_chunk) and (
296+
end_idx % page_size != 0
297+
): # todo: remove the second condition
298+
# if not the last chunk and the last page is partial, delay the last partial page to the next send
299+
end_idx = end_idx - end_idx % page_size
292300

293301
# Update next start_send_idx
294302
req.start_send_idx = end_idx
@@ -298,18 +306,21 @@ def send_kv_chunk(
298306
.cpu()
299307
.numpy()
300308
)
301-
if token_id is not None:
309+
if last_chunk is True:
302310
self.disagg_prefill_pending_queue.store_prefill_results(
303311
req.metadata_buffer_index, token_id
304312
)
305-
is_last = token_id is not None
306-
page_indices = kv_to_page_indices(
307-
kv_indices, self.token_to_kv_pool_allocator.page_size
308-
)
313+
page_indices = kv_to_page_indices(kv_indices, page_size)
309314

310-
page_start_idx = start_idx // self.token_to_kv_pool_allocator.page_size
315+
page_start_idx = start_idx // page_size
311316
page_end_idx = page_start_idx + len(page_indices)
312317

318+
if len(page_indices) == 0:
319+
logger.info(
320+
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
321+
)
322+
return
323+
313324
req.disagg_kv_sender.send(
314-
page_indices, slice(page_start_idx, page_end_idx), is_last
325+
page_indices, slice(page_start_idx, page_end_idx), last_chunk
315326
)

python/sglang/srt/disaggregation/utils.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,14 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
7676
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
7777

7878

79-
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int, is_last: bool = True):
79+
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
8080
# 1. The page is guaranteed to be full except the last page.
8181
# 2. page index = kv_index // page_size
82-
82+
# The return vector is kv_indices[::page_size] // page_size
8383
if page_size == 1: # shortcut
8484
return kv_indices
8585

86-
# if last chunk, send the last partial page
87-
# if not last chunk, delay the last partial page to the next send
88-
if is_last:
89-
return kv_indices[::page_size] // page_size
90-
else:
91-
if len(kv_indices) % page_size == 0: # no partial page
92-
return kv_indices[::page_size] // page_size
93-
else: # partial page
94-
return kv_indices[::page_size][:-1] // page_size
86+
return kv_indices[::page_size] // page_size
9587

9688

9789
def kv_to_page_num(num_kv_indices: int, page_size: int):
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
prompt = "Hello " * 16000
1+
prompt = [0] * 431
22

33
import json
44

55
import requests
66

77
response = requests.post(
88
"http://0.0.0.0:8000/generate",
9-
json={"text": prompt, "sampling_params": {"temperature": 0}},
9+
json={"input_ids": [prompt] * 32, "sampling_params": {"temperature": 0}},
1010
)
1111

1212

13-
print("Response content (raw):", response.content)
13+
# print("Response content (raw):", response.content)

0 commit comments

Comments
 (0)