Skip to content

[PD] Fix edge case and simplify large page size + chunked prefill #5589

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,16 @@ def send_kv_chunk(
"""
Send a prefilled chunk to the decode server
"""
page_size = self.token_to_kv_pool_allocator.page_size
start_idx = req.start_send_idx
end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
last_chunk = token_id is not None

if (not last_chunk) and (
end_idx % page_size != 0
): # todo: remove the second condition
# if not the last chunk and the last page is partial, delay the last partial page to the next send
end_idx = end_idx - end_idx % page_size

# Update next start_send_idx
req.start_send_idx = end_idx
Expand All @@ -298,18 +306,21 @@ def send_kv_chunk(
.cpu()
.numpy()
)
if token_id is not None:
if last_chunk is True:
self.disagg_prefill_pending_queue.store_prefill_results(
req.metadata_buffer_index, token_id
)
is_last = token_id is not None
page_indices = kv_to_page_indices(
kv_indices, self.token_to_kv_pool_allocator.page_size
)
page_indices = kv_to_page_indices(kv_indices, page_size)

page_start_idx = start_idx // self.token_to_kv_pool_allocator.page_size
page_start_idx = start_idx // page_size
page_end_idx = page_start_idx + len(page_indices)

if len(page_indices) == 0:
logger.info(
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
)
return

req.disagg_kv_sender.send(
page_indices, slice(page_start_idx, page_end_idx), is_last
page_indices, slice(page_start_idx, page_end_idx), last_chunk
)
14 changes: 3 additions & 11 deletions python/sglang/srt/disaggregation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,14 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")


def kv_to_page_indices(kv_indices: np.ndarray, page_size: int, is_last: bool = True):
    """Convert per-token KV slot indices into page indices.

    Preconditions (not checked here):
      1. Every page covered by ``kv_indices`` is full, except possibly the
         last one.
      2. The page index of a slot is ``kv_index // page_size``.

    Under those preconditions the first slot of each page identifies the
    page, so the result is ``kv_indices[::page_size] // page_size``.

    Args:
        kv_indices: 1-D array of KV slot indices, ordered by token position.
        page_size: Number of KV slots per page.
        is_last: When True (the default), a trailing partial page is
            included in the result. When False, a trailing partial page is
            left out so the caller can send it with a later chunk. Callers
            that already trim their chunk to a page boundary can rely on
            the default.

    Returns:
        Array with one page index per page. For ``page_size == 1`` the
        input array itself is returned unchanged.
    """
    if page_size == 1:  # shortcut: each slot is its own page
        return kv_indices

    if not is_last:
        # Hold back the trailing partial page (if any) for the next send.
        full_len = len(kv_indices) - len(kv_indices) % page_size
        kv_indices = kv_indices[:full_len]

    # One representative slot per page, mapped to its page number.
    return kv_indices[::page_size] // page_size


def kv_to_page_num(num_kv_indices: int, page_size: int):
Expand Down
6 changes: 3 additions & 3 deletions scripts/playground/disaggregation/cli.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Playground client: POSTs one generate request to an HTTP server on port
# 8000 and keeps the raw response in `response`.
# NOTE(review): this region is a rendered diff without +/- markers. Each
# pair of near-duplicate lines below appears to be the pre-change line
# followed by its replacement -- confirm against the repository.

# Text prompt: the string "Hello " repeated 16000 times.
prompt = "Hello " * 16000
# Rebinds `prompt` to a list of 431 zeros; it is passed as "input_ids" below.
prompt = [0] * 431

# `json` is not referenced in the visible lines.
import json

import requests

response = requests.post(
"http://0.0.0.0:8000/generate",
# Variant that sends `prompt` under the "text" key.
json={"text": prompt, "sampling_params": {"temperature": 0}},
# Variant that sends 32 copies of `prompt` under the "input_ids" key.
json={"input_ids": [prompt] * 32, "sampling_params": {"temperature": 0}},
)


# Dumps the raw response body; the commented-out copy is the same statement
# disabled.
print("Response content (raw):", response.content)
# print("Response content (raw):", response.content)
Loading