2929from sglang .srt .environ import envs
3030from sglang .srt .layers .dp_attention import compute_dp_attention_world_info
3131from sglang .srt .managers .io_struct import (
32+ ActiveRanksOutput ,
3233 BlockReqInput ,
3334 TokenizedEmbeddingReqInput ,
3435 TokenizedGenerateReqInput ,
@@ -158,6 +159,7 @@ def __init__(
158159 # Launch data parallel workers
159160 self .scheduler_procs = []
160161 self .workers : List [zmq .Socket ] = [None ] * server_args .dp_size
162+ self .status : List [int ] = [1 ] * server_args .dp_size
161163
162164 if server_args .enable_dp_attention :
163165 self .launch_dp_attention_schedulers (server_args , port_args )
@@ -179,8 +181,9 @@ def __init__(
179181 start_cpu_monitor_thread ("data_parallel_controller" )
180182
181183 def send_to_all_workers (self , obj ):
182- for worker in self .workers :
183- worker .send_pyobj (obj )
184+ for i , worker in enumerate (self .workers ):
185+ if self .status [i ] == 1 :
186+ worker .send_pyobj (obj )
184187
185188 def send_control_message (self , obj ):
186189 # Send control messages to first worker of tp group
@@ -190,6 +193,9 @@ def send_control_message(self, obj):
190193 def handle_load_update_req (self , obj ):
191194 self .dp_budget .update_budget (obj )
192195
196+ def update_active_ranks (self , ranks : ActiveRanksOutput ):
197+ self .status = ranks .status
198+
193199 def dispatching_with_trace (self , req : Req ):
194200 if self .server_args .enable_trace :
195201 trace_set_proc_propagate_context (req .rid , req .trace_context )
@@ -208,6 +214,7 @@ def init_dispatcher(self):
208214 (TokenizedEmbeddingReqInput , self .dispatching_with_trace ),
209215 (BlockReqInput , self .send_to_all_workers ),
210216 (WatchLoadUpdateReq , self .handle_load_update_req ),
217+ (ActiveRanksOutput , self .update_active_ranks ),
211218 ]
212219 )
213220 self ._request_dispatcher .add_fallback_fn (self .send_control_message )
@@ -479,8 +486,17 @@ def round_robin_scheduler(self, req: Req):
479486 if self .maybe_external_dp_rank_routing (req ):
480487 return
481488
482- self .workers [self .round_robin_counter ].send_pyobj (req )
483- self .round_robin_counter = (self .round_robin_counter + 1 ) % len (self .workers )
489+ while True :
490+ if self .status [self .round_robin_counter ] == 1 :
491+ logger .info (f"Choose worker { self .round_robin_counter } " )
492+ self .workers [self .round_robin_counter ].send_pyobj (req )
493+ self .round_robin_counter = (self .round_robin_counter + 1 ) % len (
494+ self .workers
495+ )
496+ break
497+ self .round_robin_counter = (self .round_robin_counter + 1 ) % len (
498+ self .workers
499+ )
484500
485501 def follow_bootstrap_room_scheduler (self , req : Req ):
486502 if self .maybe_external_dp_rank_routing (req ):
0 commit comments