1+ import importlib .util
12import sys
23import threading
34from pathlib import Path
@@ -31,8 +32,6 @@ def run_func(file_path: str, func_name: str, argv: list[str]):
3132 Import module by path and invoke the named function with a single `argv` list.
3233 This matches vllm_server.main(argv) which expects sys.argv[2:]-style args.
3334 """
34- import importlib .util
35-
3635 module_name = file_path .replace ("/" , "_" ).replace ("." , "_" )
3736 spec = importlib .util .spec_from_file_location (module_name , file_path )
3837 module = importlib .util .module_from_spec (spec )
@@ -53,13 +52,12 @@ def scale_up_vllm(
5352 experiment_name = cfg .experiment_name
5453 trial_name = cfg .trial_name
5554
56- # allocation_mode
5755 allocation_mode = AllocationMode .from_str (cfg .allocation_mode )
5856 vllm_tp_size = allocation_mode .gen .tp_size
59- n_existing_servers = allocation_mode . gen . dp_size
57+ n_existing_servers = expected - n_new_servers
6058
6159 cpus_per_gpu = cfg .launcher .inference_server_cpus_per_gpu
62- mem_per_gpu = cfg .launcher .inference_server_mem_per_gpu # MB per GPU
60+ mem_per_gpu = cfg .launcher .inference_server_mem_per_gpu
6361
6462 # Submit new servers
6563 remote_runner = None # we’ll bind ray.remote per device type
@@ -117,74 +115,84 @@ def scale_up_vllm(
117115 "num_rollout" : None ,
118116 "vllm_entry_point" : None ,
119117}
118+ shared_state_lock = threading .Lock ()
120119
121120
@app.post("/scale_up")
async def http_scale_up(request: Request):
    """
    Scaling controller endpoint.

    Reads ``scaled_k`` from the JSON body, atomically reserves that many new
    rollout slots in ``shared_state``, and launches the additional vLLM
    servers via ``scale_up_vllm``.  If launching fails, the reservation is
    rolled back so ``num_rollout`` does not drift from the number of servers
    actually running.

    Example usage:
        curl -X POST localhost:8899/scale_up \
            -H "Content-Type: application/json" \
            -d '{"scaled_k": 1}'
    """
    body = await request.json()
    try:
        scaled_k = int(body.get("scaled_k", 1))
    except (TypeError, ValueError):
        # BUG FIX: a non-numeric payload previously raised out of the
        # handler and surfaced as an unhandled HTTP 500.
        return {"status": "error", "msg": "scaled_k must be an integer"}
    if scaled_k < 1:
        # BUG FIX: zero/negative amounts would corrupt the rollout count
        # (and n_existing_servers = expected - n_new_servers downstream).
        return {"status": "error", "msg": "scaled_k must be >= 1"}

    with shared_state_lock:
        cfg = shared_state["cfg"]
        config_path = shared_state["config_path"]
        num_rollout = shared_state["num_rollout"]
        vllm_entry_point = shared_state["vllm_entry_point"]

        # Refuse requests until __main__ has populated every field.
        if (
            cfg is None
            or config_path is None
            or num_rollout is None
            or vllm_entry_point is None
        ):
            return {"status": "error", "msg": "Scale server not initialized yet"}

        # Optimistically reserve the new total while holding the lock so
        # concurrent requests see a consistent counter.
        new_total = num_rollout + scaled_k
        shared_state["num_rollout"] = new_total

    try:
        logger.info(f"[HTTP] Received manual scale-up request: {scaled_k}")
        name_resolve.add("scale_up_request", {"scaled_k": int(scaled_k)}, replace=True)

        scale_up_vllm(
            cfg,
            config_path,
            scaled_k,
            new_total,
            vllm_entry_point,
        )
        # Reset any stale completion marker before publishing the new one.
        try:
            name_resolve.delete("scale_up_done")
        except NameEntryNotFoundError:
            pass

        name_resolve.add("scale_up_done", {"done": 1})
        logger.info(f"[HTTP] Scale-up done. Total rollout={new_total}")
        return {
            "status": "ok",
            "scaled_k": scaled_k,
            "new_total": new_total,
        }
    except Exception as e:
        # BUG FIX: undo the optimistic reservation.  The original left
        # num_rollout inflated after a failed scale-up, so every later
        # request computed totals against servers that never started.
        # Decrement (rather than restore a snapshot) so interleaved
        # successful requests are not clobbered.
        with shared_state_lock:
            shared_state["num_rollout"] -= scaled_k
        logger.error(f"[HTTP] Scale-up failed: {e}")
        return {"status": "error", "msg": str(e)}
169177
170178
171- def run_http_server ():
def run_http_server(port: int):
    """Launch the FastAPI app on a background (non-daemon) thread and return immediately."""
    uvicorn_server = Server(Config(app, host="0.0.0.0", port=port, log_level="info"))

    def _run_forever():
        # Runs until the process exits; the non-daemon thread keeps the
        # interpreter alive after __main__ finishes.
        logger.info(f"[HTTP] Starting scaling controller server on port {port}")
        uvicorn_server.run()

    worker = threading.Thread(target=_run_forever, daemon=False)
    worker.start()
    logger.info("[HTTP] Scaling controller server started in background.")
183191
184192
if __name__ == "__main__":
    # Require exactly one positional argument: the experiment config file.
    if len(sys.argv) < 2:
        # BUG FIX: the usage string carried a stray trailing space.
        logger.info("Usage: python scaling_controller.py <config.yaml>")
        sys.exit(1)

    config_path = sys.argv[1]
@@ -193,35 +201,38 @@ def _serve():
193201 experiment_name = cfg .experiment_name
194202 trial_name = cfg .trial_name
195203
196- # allocation_mode
197204 allocation_mode = AllocationMode .from_str (cfg .allocation_mode )
198- # Set-the-experiments-configs for rollout ------------------
199205 num_rollout = allocation_mode .gen .dp_size
200206
201207 # Remove all the keys related to scaling before start the experiment
202208 try :
203209 name_resolve .delete ("scale_up_request" )
204210 except NameEntryNotFoundError :
205- logger . info ( "no delete" )
211+ pass
206212
207213 try :
208214 name_resolve .delete ("scale_up_done" )
209215 except NameEntryNotFoundError :
210216 pass
211- # Init the ray and conncet it to existing cluster
217+
218+ # Init ray and connect it to existing cluster
212219 ray .init (address = "auto" , namespace = f"{ experiment_name } _{ trial_name } " )
213220
214221 # Get port for scale up
215222 cfg .scaling = to_structured_cfg (cfg .scaling , ScalingConfig )
216- HTTP_SCALE_PORT = cfg .scaling .scaling_controller_port
217-
218- # Run http for scale-up
219- run_http_server ()
223+ port = cfg .scaling .scaling_controller_port
220224
221- logger . info ( "[HTTP] Manual scale-up service started in background." )
225+ # Resolve vLLM entry point
222226 vllm_entry_point = str (Path (__file__ ).resolve ().parent .parent / "vllm_server.py" )
223- shared_state ["cfg" ] = cfg
224- shared_state ["config_path" ] = config_path
225- shared_state ["num_rollout" ] = num_rollout
226- shared_state ["vllm_entry_point" ] = vllm_entry_point
227+
228+ # Initialize shared_state atomically before starting HTTP server
229+ with shared_state_lock :
230+ shared_state ["cfg" ] = cfg
231+ shared_state ["config_path" ] = config_path
232+ shared_state ["num_rollout" ] = num_rollout
233+ shared_state ["vllm_entry_point" ] = vllm_entry_point
234+
227235 logger .info (f"[HTTP] num_rollout initialized to { num_rollout } " )
236+
237+ # Run http for scale-up (after shared_state is fully initialized)
238+ run_http_server (port )
0 commit comments