@@ -276,7 +276,7 @@ def __init__(
276276 # Queue for requests pending pre-allocation
277277 self .queue : List [DecodeRequest ] = []
278278 self .retracted_queue : List [Req ] = []
279- self .pending_reqs : List [Req ] = []
279+ self .pending_reqs : List [DecodeRequest ] = []
280280 self ._ensure_retry_count : Dict [str , int ] = {}
281281 self ._max_ensure_retries : int = 20 # scheduling cycles
282282 self ._ensure_last_attempt_time : Dict [str , float ] = {}
@@ -368,17 +368,20 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
368368 req .retraction_mb_id = None
369369 self .retracted_queue .append (req )
370370 else :
371+ decode_req = self ._create_receiver_and_enqueue (req )
372+
371373 # NOTE: fake transfer does not need to resolve prefill dp rank in the pending queue
372374 if _is_fake_transfer (req , self .scheduler .server_args ):
373- self . _create_receiver_and_enqueue ( req , 0 )
375+ decode_req . kv_receiver . init ( 0 )
374376 return
375377
376378 # Fast path: cache-only lookup, no network calls
377379 prefill_dp_rank = self ._resolve_prefill_dp_rank (req )
378380 if prefill_dp_rank is not None :
379- self ._create_receiver_and_enqueue (req , prefill_dp_rank )
380- else :
381- self .pending_reqs .append (req )
381+ decode_req .kv_receiver .init (prefill_dp_rank )
382+ return
383+
384+ self .pending_reqs .append (decode_req )
382385
383386 def _resolve_prefill_dp_rank (self , req : Req ) -> Optional [int ]:
384387 if req .disagg_prefill_dp_rank is not None :
@@ -396,7 +399,7 @@ def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]:
396399
397400 return None
398401
399- def _create_receiver_and_enqueue (self , req : Req , prefill_dp_rank : int ) -> None :
402+ def _create_receiver_and_enqueue (self , req : Req ) -> DecodeRequest :
400403 backend = (
401404 TransferBackend .FAKE
402405 if _is_fake_transfer (req , self .scheduler .server_args )
@@ -408,12 +411,11 @@ def _create_receiver_and_enqueue(self, req: Req, prefill_dp_rank: int) -> None:
408411 mgr = self .kv_manager ,
409412 bootstrap_addr = _bootstrap_addr (req ),
410413 bootstrap_room = req .bootstrap_room ,
411- prefill_dp_rank = prefill_dp_rank ,
412414 )
413415
414- self . queue . append (
415- DecodeRequest ( req = req , kv_receiver = kv_receiver , waiting_for_input = False )
416- )
416+ decode_req = DecodeRequest ( req = req , kv_receiver = kv_receiver )
417+ self . queue . append ( decode_req )
418+ return decode_req
417419
418420 def _check_if_req_exceed_kv_capacity (self , req : Req ) -> bool :
419421 if len (req .origin_input_ids ) > self .max_total_num_tokens :
@@ -511,12 +513,12 @@ def _update_handshake_waiters(
511513 raise ValueError (f"Unexpected poll case: { poll } " )
512514
513515 def _ensure_prefill_info (
514- self , addr_to_reqs : Dict [str , List [Req ]]
515- ) -> Tuple [Dict [str , List [Req ]], List [Req ]]:
516+ self , addr_to_reqs : Dict [str , List [DecodeRequest ]]
517+ ) -> Tuple [Dict [str , List [DecodeRequest ]], List [DecodeRequest ]]:
516518 """Non-blocking ensure parallel info for each addr.
517519 Returns (ready_addrs, remaining_reqs)."""
518- ready : Dict [str , List [Req ]] = {}
519- remaining : List [Req ] = []
520+ ready : Dict [str , List [DecodeRequest ]] = {}
521+ remaining : List [DecodeRequest ] = []
520522
521523 now = time .monotonic ()
522524 for bootstrap_addr , reqs in addr_to_reqs .items ():
@@ -543,13 +545,17 @@ def _ensure_prefill_info(
543545 if count >= self ._max_ensure_retries :
544546 error_msg = f"Could not fetch prefill parallel info from { bootstrap_addr } after { count } attempts"
545547 logger .error (error_msg )
546- for req in reqs :
548+ for decode_req in reqs :
547549 prepare_abort (
548- req , error_msg , status_code = HTTPStatus .INTERNAL_SERVER_ERROR
550+ decode_req .req ,
551+ error_msg ,
552+ status_code = HTTPStatus .INTERNAL_SERVER_ERROR ,
549553 )
550554 if self .scheduler .enable_metrics :
551555 self .scheduler .metrics_collector .increment_bootstrap_failed_reqs ()
552- self .scheduler .stream_output ([req ], req .return_logprob )
556+ self .scheduler .stream_output (
557+ [decode_req .req ], decode_req .req .return_logprob
558+ )
553559 del self ._ensure_retry_count [bootstrap_addr ]
554560 del self ._ensure_last_attempt_time [bootstrap_addr ]
555561 else :
@@ -558,46 +564,48 @@ def _ensure_prefill_info(
558564 return ready , remaining
559565
def _resolve_pending_reqs(self) -> None:
    """Batch-resolve prefill_dp_ranks for pending requests and initialize receivers.

    Pending entries are ``DecodeRequest`` objects whose KV receiver has not
    yet been initialized with a prefill DP rank. For every entry whose rank
    can be determined, ``kv_receiver.init(prefill_dp_rank)`` is called;
    entries that cannot be resolved yet stay in ``self.pending_reqs`` and
    are retried on a later call.
    """
    if not self.pending_reqs:
        return

    # Group pending requests by bootstrap_addr so that each address is
    # checked and queried once per call.
    addr_to_reqs: Dict[str, List[DecodeRequest]] = {}
    for decode_req in self.pending_reqs:
        addr = _bootstrap_addr(decode_req.req)
        addr_to_reqs.setdefault(addr, []).append(decode_req)

    # Pass 1: ensure parallel info for each addr. Requests whose address is
    # not ready come back in `remaining`.
    ready_addrs, remaining = self._ensure_prefill_info(addr_to_reqs)

    # Pass 2: resolve dp rank for addrs whose info is available.
    resolved: List[Tuple[DecodeRequest, int]] = []
    for bootstrap_addr, decode_reqs in ready_addrs.items():
        # Try the local resolution first; collect the misses so they can be
        # looked up with a single batched query for this address.
        need_query: List[DecodeRequest] = []
        for decode_req in decode_reqs:
            prefill_dp_rank = self._resolve_prefill_dp_rank(decode_req.req)
            if prefill_dp_rank is not None:
                resolved.append((decode_req, prefill_dp_rank))
            else:
                need_query.append(decode_req)

        if not need_query:
            continue

        rooms = [pending.req.bootstrap_room for pending in need_query]
        room_to_rank = CommonKVReceiver.query_prefill_dp_ranks(
            bootstrap_addr, rooms
        )
        for decode_req in need_query:
            # The query result is looked up by the stringified bootstrap room.
            prefill_dp_rank = room_to_rank.get(str(decode_req.req.bootstrap_room))
            if prefill_dp_rank is not None:
                resolved.append((decode_req, int(prefill_dp_rank)))
            else:
                remaining.append(decode_req)

    # Publish the still-unresolved requests before touching any receiver.
    self.pending_reqs = remaining

    for decode_req, prefill_dp_rank in resolved:
        decode_req.kv_receiver.init(prefill_dp_rank)
601609
602610 def pop_preallocated (
603611 self , rids_to_check : Optional [List [str ]] = None
@@ -726,7 +734,7 @@ def pop_preallocated(
726734 )
727735 assert decode_req .metadata_buffer_index is not None
728736 page_indices = kv_to_page_indices (kv_indices , page_size )
729- decode_req .kv_receiver .init (
737+ decode_req .kv_receiver .send_metadata (
730738 page_indices , decode_req .metadata_buffer_index , state_indices
731739 )
732740 preallocated_reqs .append (decode_req )
0 commit comments