Skip to content

Commit 4455d17

Browse files
weireweireWeiliangl User
andauthored
[PD] Refactor Disagg Conn and Fix Hang with total_request/total_tokens Balancing (#21299)
Co-authored-by: Weiliangl User <weiliangl@login-node.hosted.internal>
1 parent acd37d8 commit 4455d17

8 files changed

Lines changed: 161 additions & 79 deletions

File tree

python/sglang/srt/disaggregation/base/conn.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,23 @@ def __init__(
122122

123123
@abstractmethod
124124
def init(
125+
self,
126+
prefill_dp_rank: int,
127+
):
128+
"""
129+
Resolve bootstrap metadata and mark the receiver ready for transfer metadata.
130+
"""
131+
...
132+
133+
@abstractmethod
134+
def send_metadata(
125135
self,
126136
kv_indices: npt.NDArray[np.int32],
127137
aux_index: Optional[int] = None,
128138
state_indices: Optional[List[int]] = None,
129139
):
130140
"""
131-
Set req's index metadata locally or notify the prefill server about the kv indices, aux index, and state_indices.
141+
Notify the prefill server about the kv indices, aux index, and state_indices.
132142
"""
133143
...
134144

python/sglang/srt/disaggregation/common/conn.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -489,20 +489,31 @@ def __init__(
489489
mgr: CommonKVManager,
490490
bootstrap_addr: str,
491491
bootstrap_room: Optional[int] = None,
492-
prefill_dp_rank: Optional[int] = None,
493492
):
494493
self.bootstrap_room = bootstrap_room
495494
self.bootstrap_addr = bootstrap_addr
496495
self.kv_mgr = mgr
496+
self.conclude_state: Optional[KVPoll] = None
497+
self.bootstrap_infos = None
498+
self.prefill_info = None
499+
self.prefill_dp_rank = None
500+
self.target_tp_rank = None
501+
self.target_tp_ranks = None
502+
self.target_cp_ranks = None
503+
self.target_pp_ranks = None
504+
self.required_dst_info_num = None
505+
self.required_prefill_response_num = None
506+
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
497507
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
498508

509+
def init(self, prefill_dp_rank: int):
499510
if self.bootstrap_addr not in self.kv_mgr.prefill_info_table:
500511
self.kv_mgr.record_failure(
501512
self.bootstrap_room,
502513
f"Prefill server with bootstrap_addr: {self.bootstrap_addr} is healthy before, but now it is down. Request (bootstrap_room: {self.bootstrap_room}) has been marked as failed.",
503514
)
515+
self.conclude_state = KVPoll.Failed
504516
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
505-
self.bootstrap_infos = None
506517
return
507518

508519
# Read pre-computed rank mapping from prefill_info (computed in try_ensure_parallel_info)
@@ -520,11 +531,9 @@ def __init__(
520531
self.required_prefill_response_num
521532
)
522533

523-
assert (
524-
prefill_dp_rank is not None
525-
), "prefill_dp_rank must be resolved before creating receiver"
526534
self.prefill_dp_rank = prefill_dp_rank
527535
self._setup_bootstrap_infos()
536+
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
528537

529538
def _setup_bootstrap_infos(self):
530539
all_bootstrap_infos = []
@@ -562,6 +571,7 @@ def _setup_bootstrap_infos(self):
562571
self.bootstrap_room,
563572
f"Could not fetch bootstrap info for: prefill_dp_rank: {self.prefill_dp_rank} prefill_cp_rank: {target_cp_rank} target_tp_rank: {target_tp_rank} and target_pp_rank {target_pp_rank}",
564573
)
574+
self.conclude_state = KVPoll.Failed
565575
self.kv_mgr.update_status(
566576
self.bootstrap_room, KVPoll.Failed
567577
)
@@ -645,6 +655,14 @@ def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
645655
def _register_kv_args(self):
646656
pass
647657

658+
def send_metadata(
659+
self,
660+
kv_indices: npt.NDArray[np.int32],
661+
aux_index: Optional[int] = None,
662+
state_indices: Optional[List[int]] = None,
663+
):
664+
raise NotImplementedError
665+
648666
def failure_exception(self):
649667
raise Exception("Fake KVReceiver Exception")
650668

python/sglang/srt/disaggregation/decode.py

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def __init__(
276276
# Queue for requests pending pre-allocation
277277
self.queue: List[DecodeRequest] = []
278278
self.retracted_queue: List[Req] = []
279-
self.pending_reqs: List[Req] = []
279+
self.pending_reqs: List[DecodeRequest] = []
280280
self._ensure_retry_count: Dict[str, int] = {}
281281
self._max_ensure_retries: int = 20 # scheduling cycles
282282
self._ensure_last_attempt_time: Dict[str, float] = {}
@@ -368,17 +368,20 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
368368
req.retraction_mb_id = None
369369
self.retracted_queue.append(req)
370370
else:
371+
decode_req = self._create_receiver_and_enqueue(req)
372+
371373
# NOTE: fake transfer does not need to resolve prefill dp rank in the pending queue
372374
if _is_fake_transfer(req, self.scheduler.server_args):
373-
self._create_receiver_and_enqueue(req, 0)
375+
decode_req.kv_receiver.init(0)
374376
return
375377

376378
# Fast path: cache-only lookup, no network calls
377379
prefill_dp_rank = self._resolve_prefill_dp_rank(req)
378380
if prefill_dp_rank is not None:
379-
self._create_receiver_and_enqueue(req, prefill_dp_rank)
380-
else:
381-
self.pending_reqs.append(req)
381+
decode_req.kv_receiver.init(prefill_dp_rank)
382+
return
383+
384+
self.pending_reqs.append(decode_req)
382385

383386
def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]:
384387
if req.disagg_prefill_dp_rank is not None:
@@ -396,7 +399,7 @@ def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]:
396399

397400
return None
398401

399-
def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None:
402+
def _create_receiver_and_enqueue(self, req: Req) -> DecodeRequest:
400403
backend = (
401404
TransferBackend.FAKE
402405
if _is_fake_transfer(req, self.scheduler.server_args)
@@ -408,12 +411,11 @@ def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None:
408411
mgr=self.kv_manager,
409412
bootstrap_addr=_bootstrap_addr(req),
410413
bootstrap_room=req.bootstrap_room,
411-
prefill_dp_rank=prefill_dp_rank,
412414
)
413415

414-
self.queue.append(
415-
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
416-
)
416+
decode_req = DecodeRequest(req=req, kv_receiver=kv_receiver)
417+
self.queue.append(decode_req)
418+
return decode_req
417419

418420
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
419421
if len(req.origin_input_ids) > self.max_total_num_tokens:
@@ -511,12 +513,12 @@ def _update_handshake_waiters(
511513
raise ValueError(f"Unexpected poll case: {poll}")
512514

513515
def _ensure_prefill_info(
514-
self, addr_to_reqs: Dict[str, List[Req]]
515-
) -> Tuple[Dict[str, List[Req]], List[Req]]:
516+
self, addr_to_reqs: Dict[str, List[DecodeRequest]]
517+
) -> Tuple[Dict[str, List[DecodeRequest]], List[DecodeRequest]]:
516518
"""Non-blocking ensure parallel info for each addr.
517519
Returns (ready_addrs, remaining_reqs)."""
518-
ready: Dict[str, List[Req]] = {}
519-
remaining: List[Req] = []
520+
ready: Dict[str, List[DecodeRequest]] = {}
521+
remaining: List[DecodeRequest] = []
520522

521523
now = time.monotonic()
522524
for bootstrap_addr, reqs in addr_to_reqs.items():
@@ -543,13 +545,17 @@ def _ensure_prefill_info(
543545
if count >= self._max_ensure_retries:
544546
error_msg = f"Could not fetch prefill parallel info from {bootstrap_addr} after {count} attempts"
545547
logger.error(error_msg)
546-
for req in reqs:
548+
for decode_req in reqs:
547549
prepare_abort(
548-
req, error_msg, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
550+
decode_req.req,
551+
error_msg,
552+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
549553
)
550554
if self.scheduler.enable_metrics:
551555
self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
552-
self.scheduler.stream_output([req], req.return_logprob)
556+
self.scheduler.stream_output(
557+
[decode_req.req], decode_req.req.return_logprob
558+
)
553559
del self._ensure_retry_count[bootstrap_addr]
554560
del self._ensure_last_attempt_time[bootstrap_addr]
555561
else:
@@ -558,46 +564,48 @@ def _ensure_prefill_info(
558564
return ready, remaining
559565

560566
def _resolve_pending_reqs(self) -> None:
561-
"""Batch-resolve prefill_dp_ranks for pending requests and create receivers."""
567+
"""Batch-resolve prefill_dp_ranks for pending requests and initialize receivers."""
562568
if not self.pending_reqs:
563569
return
564570

565571
# Group pending requests by bootstrap_addr
566-
addr_to_reqs: Dict[str, List[Req]] = {}
567-
for req in self.pending_reqs:
568-
addr = _bootstrap_addr(req)
569-
addr_to_reqs.setdefault(addr, []).append(req)
572+
addr_to_reqs: Dict[str, List[DecodeRequest]] = {}
573+
for decode_req in self.pending_reqs:
574+
addr = _bootstrap_addr(decode_req.req)
575+
addr_to_reqs.setdefault(addr, []).append(decode_req)
570576

571577
# Pass 1: ensure parallel info for each addr
572578
ready_addrs, remaining = self._ensure_prefill_info(addr_to_reqs)
573579

574-
# Pass 2: resolve dp rank for addrs whose info is available
575-
resolved = []
576-
for bootstrap_addr, reqs in ready_addrs.items():
577-
need_query: List[Req] = []
578-
for req in reqs:
579-
prefill_dp_rank = self._resolve_prefill_dp_rank(req)
580+
resolved: List[Tuple[DecodeRequest, int]] = []
581+
for bootstrap_addr, decode_reqs in ready_addrs.items():
582+
need_query: List[DecodeRequest] = []
583+
for decode_req in decode_reqs:
584+
prefill_dp_rank = self._resolve_prefill_dp_rank(decode_req.req)
580585
if prefill_dp_rank is not None:
581-
resolved.append((req, prefill_dp_rank))
586+
resolved.append((decode_req, prefill_dp_rank))
582587
else:
583-
need_query.append(req)
588+
need_query.append(decode_req)
584589

590+
# Pass 2: resolve dp rank for addrs whose info is available
585591
if need_query:
586-
rooms = [req.bootstrap_room for req in need_query]
592+
rooms = [decode_req.req.bootstrap_room for decode_req in need_query]
587593
room_to_rank = CommonKVReceiver.query_prefill_dp_ranks(
588594
bootstrap_addr, rooms
589595
)
590-
for req in need_query:
591-
prefill_dp_rank = room_to_rank.get(str(req.bootstrap_room))
596+
for decode_req in need_query:
597+
prefill_dp_rank = room_to_rank.get(
598+
str(decode_req.req.bootstrap_room)
599+
)
592600
if prefill_dp_rank is not None:
593-
resolved.append((req, int(prefill_dp_rank)))
601+
resolved.append((decode_req, int(prefill_dp_rank)))
594602
else:
595-
remaining.append(req)
603+
remaining.append(decode_req)
596604

597605
self.pending_reqs = remaining
598606

599-
for req, prefill_dp_rank in resolved:
600-
self._create_receiver_and_enqueue(req, prefill_dp_rank)
607+
for decode_req, prefill_dp_rank in resolved:
608+
decode_req.kv_receiver.init(prefill_dp_rank)
601609

602610
def pop_preallocated(
603611
self, rids_to_check: Optional[List[str]] = None
@@ -726,7 +734,7 @@ def pop_preallocated(
726734
)
727735
assert decode_req.metadata_buffer_index is not None
728736
page_indices = kv_to_page_indices(kv_indices, page_size)
729-
decode_req.kv_receiver.init(
737+
decode_req.kv_receiver.send_metadata(
730738
page_indices, decode_req.metadata_buffer_index, state_indices
731739
)
732740
preallocated_reqs.append(decode_req)

python/sglang/srt/disaggregation/fake/conn.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,28 +82,33 @@ def __init__(
8282
mgr: BaseKVManager,
8383
bootstrap_addr: str,
8484
bootstrap_room: Optional[int] = None,
85-
prefill_dp_rank: Optional[int] = None,
8685
):
87-
self.has_init = False
86+
self.bootstrap_done = False
87+
self.has_sent_metadata = False
8888

8989
def poll(self) -> KVPoll:
90-
if self.has_init is False:
91-
# Assume handshake completed instantly
90+
if not self.bootstrap_done:
91+
return KVPoll.Bootstrapping
92+
if not self.has_sent_metadata:
9293
return KVPoll.WaitingForInput
93-
else:
94-
# Assume transfer completed instantly
95-
logger.debug("FakeKVReceiver poll success")
96-
return KVPoll.Success
94+
logger.debug("FakeKVReceiver poll success")
95+
return KVPoll.Success
9796

9897
def init(
98+
self,
99+
prefill_dp_rank: int,
100+
):
101+
self.bootstrap_done = True
102+
103+
def send_metadata(
99104
self,
100105
kv_indices: list[int],
101106
aux_index: Optional[int] = None,
102107
state_indices: Optional[List[int]] = None,
103108
):
104-
self.has_init = True
109+
self.has_sent_metadata = True
105110
logger.debug(
106-
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
111+
f"FakeKVReceiver send_metadata with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
107112
)
108113

109114
def failure_exception(self):

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,15 +1238,10 @@ def __init__(
12381238
mgr: MooncakeKVManager,
12391239
bootstrap_addr: str,
12401240
bootstrap_room: Optional[int] = None,
1241-
prefill_dp_rank: Optional[int] = None,
12421241
):
12431242
self.session_id = mgr.get_session_id()
1244-
self.conclude_state = None
12451243
self.init_time = None
1246-
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
1247-
1248-
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
1249-
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
1244+
super().__init__(mgr, bootstrap_addr, bootstrap_room)
12501245

12511246
def _register_kv_args(self):
12521247
for bootstrap_info in self.bootstrap_infos:
@@ -1297,6 +1292,12 @@ def _register_kv_args(self):
12971292
)
12981293

12991294
def init(
1295+
self,
1296+
prefill_dp_rank: int,
1297+
):
1298+
super().init(prefill_dp_rank)
1299+
1300+
def send_metadata(
13001301
self,
13011302
kv_indices: npt.NDArray[np.int32],
13021303
aux_index: Optional[int] = None,

python/sglang/srt/disaggregation/mori/conn.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -985,17 +985,18 @@ def __init__(
985985
mgr: MoriKVManager,
986986
bootstrap_addr: str,
987987
bootstrap_room: Optional[int] = None,
988-
prefill_dp_rank: Optional[int] = None,
989988
):
990-
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
991-
self.conclude_state: Optional[KVPoll] = None
989+
super().__init__(mgr, bootstrap_addr, bootstrap_room)
992990
self.init_time: Optional[float] = None
993-
if self.bootstrap_room is None or self.bootstrap_infos is None:
991+
992+
def init(
993+
self,
994+
prefill_dp_rank: int,
995+
):
996+
super().init(prefill_dp_rank)
997+
if self.bootstrap_room is None:
994998
return
995-
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
996999
self.kv_mgr.room_to_bootstrap_addr[self.bootstrap_room] = self.bootstrap_addr
997-
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
998-
self._register_kv_args()
9991000

10001001
def _register_kv_args(self):
10011002
if self.bootstrap_infos is None:
@@ -1029,7 +1030,7 @@ def _register_kv_args(self):
10291030
]
10301031
)
10311032

1032-
def init(
1033+
def send_metadata(
10331034
self,
10341035
kv_indices: npt.NDArray[np.int32],
10351036
aux_index: Optional[int] = None,

0 commit comments

Comments
 (0)