Skip to content

Commit 4f6eb01

Browse files
committed
Add locks etc. in scaling controller
1 parent fb5b72f commit 4f6eb01

File tree

1 file changed

+49
-38
lines changed

1 file changed

+49
-38
lines changed

areal/launcher/scaler/scaling_controller.py

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib.util
12
import sys
23
import threading
34
from pathlib import Path
@@ -31,8 +32,6 @@ def run_func(file_path: str, func_name: str, argv: list[str]):
3132
Import module by path and invoke the named function with a single `argv` list.
3233
This matches vllm_server.main(argv) which expects sys.argv[2:]-style args.
3334
"""
34-
import importlib.util
35-
3635
module_name = file_path.replace("/", "_").replace(".", "_")
3736
spec = importlib.util.spec_from_file_location(module_name, file_path)
3837
module = importlib.util.module_from_spec(spec)
@@ -53,13 +52,12 @@ def scale_up_vllm(
5352
experiment_name = cfg.experiment_name
5453
trial_name = cfg.trial_name
5554

56-
# allocation_mode
5755
allocation_mode = AllocationMode.from_str(cfg.allocation_mode)
5856
vllm_tp_size = allocation_mode.gen.tp_size
59-
n_existing_servers = allocation_mode.gen.dp_size
57+
n_existing_servers = expected - n_new_servers
6058

6159
cpus_per_gpu = cfg.launcher.inference_server_cpus_per_gpu
62-
mem_per_gpu = cfg.launcher.inference_server_mem_per_gpu # MB per GPU
60+
mem_per_gpu = cfg.launcher.inference_server_mem_per_gpu
6361

6462
# Submit new servers
6563
remote_runner = None # we’ll bind ray.remote per device type
@@ -117,74 +115,84 @@ def scale_up_vllm(
117115
"num_rollout": None,
118116
"vllm_entry_point": None,
119117
}
118+
shared_state_lock = threading.Lock()
120119

121120

122121
@app.post("/scale_up")
123122
async def http_scale_up(request: Request):
124123
"""
125-
Manual scale-up endpoint.
124+
Scaling controller endpoint.
126125
Example usage:
127126
curl -X POST localhost:8899/scale_up \
128127
-H "Content-Type: application/json" \
129128
-d '{"scaled_k": 1}'
130129
"""
131130
body = await request.json()
132131
scaled_k = int(body.get("scaled_k", 1))
133-
cfg = shared_state["cfg"]
134-
config_path = shared_state["config_path"]
135-
num_rollout = shared_state["num_rollout"]
136132

137-
if cfg is None or config_path is None:
138-
return {"status": "error", "msg": "Scale server not initialized yet"}
133+
with shared_state_lock:
134+
cfg = shared_state["cfg"]
135+
config_path = shared_state["config_path"]
136+
num_rollout = shared_state["num_rollout"]
137+
vllm_entry_point = shared_state["vllm_entry_point"]
138+
139+
# More complete initialization check
140+
if (
141+
cfg is None
142+
or config_path is None
143+
or num_rollout is None
144+
or vllm_entry_point is None
145+
):
146+
return {"status": "error", "msg": "Scale server not initialized yet"}
147+
148+
new_total = num_rollout + scaled_k
149+
shared_state["num_rollout"] = new_total
139150

140151
try:
141152
logger.info(f"[HTTP] Received manual scale-up request: {scaled_k}")
142-
shared_state["num_rollout"] = num_rollout + scaled_k
143-
144153
name_resolve.add("scale_up_request", {"scaled_k": int(scaled_k)}, replace=True)
154+
145155
scale_up_vllm(
146156
cfg,
147157
config_path,
148158
scaled_k,
149-
num_rollout + scaled_k,
150-
shared_state["vllm_entry_point"],
159+
new_total,
160+
vllm_entry_point,
151161
)
152162
try:
153163
name_resolve.delete("scale_up_done")
154164
except NameEntryNotFoundError:
155165
pass
156166

157-
name_resolve.add("scale_up_done", {"step": 0})
158-
logger.info(
159-
f"[HTTP] Scale-up done. Total rollout={shared_state['num_rollout']}"
160-
)
167+
name_resolve.add("scale_up_done", {"done": 1})
168+
logger.info(f"[HTTP] Scale-up done. Total rollout={new_total}")
161169
return {
162170
"status": "ok",
163171
"scaled_k": scaled_k,
164-
"new_total": shared_state["num_rollout"],
172+
"new_total": new_total,
165173
}
166174
except Exception as e:
167175
logger.error(f"[HTTP] Scale-up failed: {e}")
168176
return {"status": "error", "msg": str(e)}
169177

170178

171-
def run_http_server():
179+
def run_http_server(port: int):
172180
"""Run FastAPI server in background thread (non-blocking)."""
173-
config = Config(app, host="0.0.0.0", port=HTTP_SCALE_PORT, log_level="info")
181+
config = Config(app, host="0.0.0.0", port=port, log_level="info")
174182
server = Server(config)
175183

176184
def _serve():
177-
logger.info(f"[HTTP] Starting manual scale-up server on port {HTTP_SCALE_PORT}")
185+
logger.info(f"[HTTP] Starting scaling controller server on port {port}")
178186
server.run()
179187

180188
t = threading.Thread(target=_serve, daemon=False)
181189
t.start()
182-
logger.info("[HTTP] Manual scale-up service started in background.")
190+
logger.info("[HTTP] Scaling controller server started in background.")
183191

184192

185193
if __name__ == "__main__":
186194
if len(sys.argv) < 2:
187-
logger.info("Usage: python scaling_controller.py <config.yaml>")
195+
logger.info("Usage: python scaling_controller <config.yaml> ")
188196
sys.exit(1)
189197

190198
config_path = sys.argv[1]
@@ -193,35 +201,38 @@ def _serve():
193201
experiment_name = cfg.experiment_name
194202
trial_name = cfg.trial_name
195203

196-
# allocation_mode
197204
allocation_mode = AllocationMode.from_str(cfg.allocation_mode)
198-
# Set-the-experiments-configs for rollout ------------------
199205
num_rollout = allocation_mode.gen.dp_size
200206

201207
# Remove all the keys related to scaling before start the experiment
202208
try:
203209
name_resolve.delete("scale_up_request")
204210
except NameEntryNotFoundError:
205-
logger.info("no delete")
211+
pass
206212

207213
try:
208214
name_resolve.delete("scale_up_done")
209215
except NameEntryNotFoundError:
210216
pass
211-
# Init the ray and conncet it to existing cluster
217+
218+
# Init ray and connect it to existing cluster
212219
ray.init(address="auto", namespace=f"{experiment_name}_{trial_name}")
213220

214221
# Get port for scale up
215222
cfg.scaling = to_structured_cfg(cfg.scaling, ScalingConfig)
216-
HTTP_SCALE_PORT = cfg.scaling.scaling_controller_port
217-
218-
# Run http for scale-up
219-
run_http_server()
223+
port = cfg.scaling.scaling_controller_port
220224

221-
logger.info("[HTTP] Manual scale-up service started in background.")
225+
# Resolve vLLM entry point
222226
vllm_entry_point = str(Path(__file__).resolve().parent.parent / "vllm_server.py")
223-
shared_state["cfg"] = cfg
224-
shared_state["config_path"] = config_path
225-
shared_state["num_rollout"] = num_rollout
226-
shared_state["vllm_entry_point"] = vllm_entry_point
227+
228+
# Initialize shared_state atomically before starting HTTP server
229+
with shared_state_lock:
230+
shared_state["cfg"] = cfg
231+
shared_state["config_path"] = config_path
232+
shared_state["num_rollout"] = num_rollout
233+
shared_state["vllm_entry_point"] = vllm_entry_point
234+
227235
logger.info(f"[HTTP] num_rollout initialized to {num_rollout}")
236+
237+
# Run http for scale-up (after shared_state is fully initialized)
238+
run_http_server(port)

0 commit comments

Comments (0)