Skip to content

Commit 8258eeb

Browse files
authored
refactor: split trainer into two files and clean up gateway (#47)
1 parent fc1344b commit 8258eeb

7 files changed

Lines changed: 1108 additions & 1217 deletions

File tree

server/kubernetes/distributed-lustre/05-trainer-worker.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ spec:
2121
- name: trainer-worker
2222
image: gcr.io/cdrollouts-sunilarora/open-rl-server:latest
2323
imagePullPolicy: Always
24-
command: ["uv", "run", "python", "-m", "src.trainer"]
24+
command: ["uv", "run", "python", "-m", "src.clock_cycle"]
2525
env:
2626
- name: ENABLE_GCP_TRACE
2727
value: "1"

server/kubernetes/distributed-shared/05-trainer-worker.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ spec:
2121
- name: trainer-worker
2222
image: gcr.io/cdrollouts-sunilarora/open-rl-server:latest
2323
imagePullPolicy: Always
24-
command: ["uv", "run", "python", "-m", "src.trainer"]
24+
command: ["uv", "run", "python", "-m", "src.clock_cycle"]
2525
env:
2626
- name: ENABLE_GCP_TRACE
2727
value: "1"

server/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ conflicts = [
4747
{ extra = "cpu" },
4848
{ extra = "vllm" },
4949
],
50+
[
51+
{ extra = "gpu" },
52+
{ extra = "vllm" },
53+
],
5054
]
5155

5256
[[tool.uv.index]]

server/src/clock_cycle.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import asyncio
2+
import os
3+
import threading
4+
import traceback
5+
6+
import uvicorn
7+
from fastapi import FastAPI, HTTPException
8+
from opentelemetry import context as otel_context
9+
from opentelemetry import propagate, trace
10+
11+
from .store import get_store
12+
from .trainer import Datum, LoraConfig, TrainerEngine
13+
14+
# Module-level tracer; clock_cycle_loop() uses it to open one span per popped batch.
tracer = trace.get_tracer(__name__)
15+
16+
# Single TrainerEngine instance shared by main() (optional preload) and the worker loop.
engine = TrainerEngine()
17+
18+
19+
def _parse_datum(raw: dict) -> Datum:
    """Flatten one wire-format datum into a ``Datum``.

    The ``tokens`` lists of every entry under ``raw["model_input"]["chunks"]``
    are concatenated in order, and ``raw["loss_fn_inputs"]`` is passed through
    unchanged. Any of those keys may be absent, in which case an empty
    container is used instead.
    """
    model_input = raw.get("model_input", {})
    flat_tokens: list[int] = [
        tok
        for piece in model_input.get("chunks", [])
        for tok in piece.get("tokens", [])
    ]
    return Datum(model_input=flat_tokens, loss_fn_inputs=raw.get("loss_fn_inputs", {}))
28+
29+
30+
async def clock_cycle_loop() -> None:
31+
store = get_store()
32+
33+
print("[WORKER] Training worker started.")
34+
35+
while True:
36+
try:
37+
batch = await store.get_requests()
38+
if not batch:
39+
await asyncio.sleep(0.1)
40+
continue
41+
42+
m_id = batch[0].get("model_id", "default")
43+
44+
with tracer.start_as_current_span("clock_cycle_batch") as batch_span:
45+
batch_span.set_attribute("batch_size", len(batch))
46+
batch_span.set_attribute("model_id", m_id)
47+
48+
print(f"\n[CLOCK CYCLE] Popped {len(batch)} requests for tenant: {m_id}")
49+
50+
SKIP_ADAPTER_SWITCH = {"create_model", "create_model_from_state"}
51+
if not any(r.get("type") in SKIP_ADAPTER_SWITCH for r in batch):
52+
try:
53+
await asyncio.to_thread(engine.set_active_adapter, m_id)
54+
except Exception as e:
55+
print(f"Failed to set adapter {m_id}: {e}")
56+
for r in batch:
57+
await store.set_future(r["req_id"], {"type": "RequestFailedResponse", "error_message": str(e)})
58+
continue
59+
60+
for r in batch:
61+
req_id = r["req_id"]
62+
req_type = r["type"]
63+
64+
carrier = r.get("trace_context", {})
65+
ctx = propagate.extract(carrier) if carrier else None
66+
token = otel_context.attach(ctx) if ctx else None
67+
68+
try:
69+
match req_type:
70+
case "create_model":
71+
base_model = r["base_model"]
72+
raw_config = r.get("lora_config") or {}
73+
lora_config = LoraConfig(**{k: v for k, v in raw_config.items() if k in LoraConfig.model_fields})
74+
75+
await asyncio.to_thread(engine.load_base_model, base_model)
76+
await asyncio.to_thread(engine.create_adapter, m_id, lora_config)
77+
78+
await store.set_future(
79+
req_id,
80+
{
81+
"model_id": m_id,
82+
"is_lora": True,
83+
"lora_rank": lora_config.rank,
84+
"type": "create_model",
85+
},
86+
)
87+
88+
case "forward_backward":
89+
raw_data = r["data"]
90+
loss_fn = r["loss_fn"]
91+
loss_config = r.get("loss_config")
92+
93+
typed_data = [_parse_datum(item) for item in raw_data]
94+
95+
result = await asyncio.to_thread(engine.forward_backward, typed_data, loss_fn, loss_config, m_id)
96+
result["type"] = "forward_backward"
97+
await store.set_future(req_id, result)
98+
99+
case "optim_step":
100+
adam_params = r["adam_params"]
101+
result = await asyncio.to_thread(engine.optim_step, adam_params, m_id)
102+
result["type"] = "optim_step"
103+
await store.set_future(req_id, result)
104+
105+
case "sample":
106+
prompt_tokens = r["prompt_tokens"]
107+
max_tokens = r["max_tokens"]
108+
num_samples = r["num_samples"]
109+
temperature = r.get("temperature", 0.0)
110+
111+
result = await asyncio.to_thread(
112+
engine.generate,
113+
prompt_tokens,
114+
max_tokens,
115+
num_samples,
116+
temperature,
117+
m_id,
118+
)
119+
result["type"] = "sample"
120+
await store.set_future(req_id, result)
121+
122+
case "save_state":
123+
state_path = r["state_path"]
124+
include_optimizer = bool(r.get("include_optimizer", False))
125+
kind = r.get("kind", "state")
126+
127+
result = await asyncio.to_thread(engine.save_state, m_id, state_path, include_optimizer, kind)
128+
result["type"] = "save_state"
129+
await store.set_future(req_id, result)
130+
131+
case "save_weights_for_sampler" | "save_weights":
132+
await asyncio.to_thread(engine.save_adapter, m_id)
133+
await store.set_future(req_id, {"status": "ok", "type": req_type})
134+
135+
case _:
136+
print(f"Warning: Unhandled request type: {req_type}")
137+
await store.set_future(req_id, {"type": "RequestFailedResponse", "error_message": f"Unknown request type: {req_type}"})
138+
139+
except Exception as e:
140+
traceback.print_exc()
141+
await store.set_future(req_id, {"type": "RequestFailedResponse", "error_message": str(e)})
142+
finally:
143+
if token:
144+
otel_context.detach(token)
145+
146+
except asyncio.CancelledError:
147+
break
148+
except Exception as e:
149+
print(f"Error in clock cycle loop: {e}")
150+
traceback.print_exc()
151+
152+
import redis
153+
154+
if isinstance(e, redis.exceptions.ConnectionError):
155+
print("[worker] Destroying StateStore singleton to force Redis reconnection...")
156+
from . import store as store_mod
157+
158+
store_mod._store_instance = None
159+
store = store_mod.get_store()
160+
161+
await asyncio.sleep(1)
162+
163+
164+
def main() -> None:
    """Start the training worker: health probe, optional model preload, request loop.

    The ``/healthz`` probe server is started on port 8000 in a daemon thread
    *before* the base model is loaded, so it answers 503 ("Model Loading")
    during the preload and 200 once the worker is ready. The model to preload
    is taken from ``OPEN_RL_BASE_MODEL``, falling back to ``VLLM_MODEL``; when
    neither is set the worker is marked ready immediately and the model is
    loaded on the first request that needs it.

    This call blocks in ``clock_cycle_loop()`` until that loop ends.
    """
    print("\n" + "=" * 50)
    print(" Open-RL PyTorch Training Worker")
    print("=" * 50)
    cuda_devs = os.getenv("CUDA_VISIBLE_DEVICES", "ALL")
    print(f"-> Hardware : CUDA_VISIBLE_DEVICES={cuda_devs}\n")

    # An Event rather than a plain bool: it is set from this thread and read
    # from the probe server's thread.
    ready = threading.Event()

    probe_app = FastAPI()

    @probe_app.get("/healthz")
    def healthz():
        if ready.is_set():
            return {"status": "ready"}
        raise HTTPException(status_code=503, detail="Model Loading")

    def run_probe_server():
        uvicorn.run(probe_app, host="0.0.0.0", port=8000, log_level="warning")

    # Daemon thread so the probe server never keeps the process alive on its own.
    threading.Thread(target=run_probe_server, daemon=True).start()

    preload_target = os.getenv("OPEN_RL_BASE_MODEL") or os.getenv("VLLM_MODEL")
    if preload_target:
        engine.load_base_model(preload_target)
    else:
        print("[WARNING] OPEN_RL_BASE_MODEL / VLLM_MODEL not provided. Cold-start penalty will apply on first request.")
    ready.set()

    asyncio.run(clock_cycle_loop())
193+
194+
195+
# Start the worker when this module is run as a script or via ``python -m``.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)