|
| 1 | +import asyncio |
| 2 | +import os |
| 3 | +import threading |
| 4 | +import traceback |
| 5 | + |
| 6 | +import uvicorn |
| 7 | +from fastapi import FastAPI, HTTPException |
| 8 | +from opentelemetry import context as otel_context |
| 9 | +from opentelemetry import propagate, trace |
| 10 | + |
| 11 | +from .store import get_store |
| 12 | +from .trainer import Datum, LoraConfig, TrainerEngine |
| 13 | + |
# Tracer used for the spans this module emits (named after the module itself).
tracer = trace.get_tracer(__name__)

# Single module-level engine instance; both the worker loop and main() call into it.
engine = TrainerEngine()
| 17 | + |
| 18 | + |
def _parse_datum(raw: dict) -> Datum:
    """Flatten a wire-format datum into the internal ``Datum`` type.

    The wire format nests tokens as ``raw["model_input"]["chunks"][i]["tokens"]``;
    the token lists of all chunks are concatenated in chunk order.
    ``raw["loss_fn_inputs"]`` is passed through unchanged. Any missing key is
    treated as empty.
    """
    model_input = raw.get("model_input", {})
    flat_tokens: list[int] = [
        token
        for chunk in model_input.get("chunks", [])
        for token in chunk.get("tokens", [])
    ]
    return Datum(model_input=flat_tokens, loss_fn_inputs=raw.get("loss_fn_inputs", {}))
| 28 | + |
| 29 | + |
| 30 | +async def clock_cycle_loop() -> None: |
| 31 | + store = get_store() |
| 32 | + |
| 33 | + print("[WORKER] Training worker started.") |
| 34 | + |
| 35 | + while True: |
| 36 | + try: |
| 37 | + batch = await store.get_requests() |
| 38 | + if not batch: |
| 39 | + await asyncio.sleep(0.1) |
| 40 | + continue |
| 41 | + |
| 42 | + m_id = batch[0].get("model_id", "default") |
| 43 | + |
| 44 | + with tracer.start_as_current_span("clock_cycle_batch") as batch_span: |
| 45 | + batch_span.set_attribute("batch_size", len(batch)) |
| 46 | + batch_span.set_attribute("model_id", m_id) |
| 47 | + |
| 48 | + print(f"\n[CLOCK CYCLE] Popped {len(batch)} requests for tenant: {m_id}") |
| 49 | + |
| 50 | + SKIP_ADAPTER_SWITCH = {"create_model", "create_model_from_state"} |
| 51 | + if not any(r.get("type") in SKIP_ADAPTER_SWITCH for r in batch): |
| 52 | + try: |
| 53 | + await asyncio.to_thread(engine.set_active_adapter, m_id) |
| 54 | + except Exception as e: |
| 55 | + print(f"Failed to set adapter {m_id}: {e}") |
| 56 | + for r in batch: |
| 57 | + await store.set_future(r["req_id"], {"type": "RequestFailedResponse", "error_message": str(e)}) |
| 58 | + continue |
| 59 | + |
| 60 | + for r in batch: |
| 61 | + req_id = r["req_id"] |
| 62 | + req_type = r["type"] |
| 63 | + |
| 64 | + carrier = r.get("trace_context", {}) |
| 65 | + ctx = propagate.extract(carrier) if carrier else None |
| 66 | + token = otel_context.attach(ctx) if ctx else None |
| 67 | + |
| 68 | + try: |
| 69 | + match req_type: |
| 70 | + case "create_model": |
| 71 | + base_model = r["base_model"] |
| 72 | + raw_config = r.get("lora_config") or {} |
| 73 | + lora_config = LoraConfig(**{k: v for k, v in raw_config.items() if k in LoraConfig.model_fields}) |
| 74 | + |
| 75 | + await asyncio.to_thread(engine.load_base_model, base_model) |
| 76 | + await asyncio.to_thread(engine.create_adapter, m_id, lora_config) |
| 77 | + |
| 78 | + await store.set_future( |
| 79 | + req_id, |
| 80 | + { |
| 81 | + "model_id": m_id, |
| 82 | + "is_lora": True, |
| 83 | + "lora_rank": lora_config.rank, |
| 84 | + "type": "create_model", |
| 85 | + }, |
| 86 | + ) |
| 87 | + |
| 88 | + case "forward_backward": |
| 89 | + raw_data = r["data"] |
| 90 | + loss_fn = r["loss_fn"] |
| 91 | + loss_config = r.get("loss_config") |
| 92 | + |
| 93 | + typed_data = [_parse_datum(item) for item in raw_data] |
| 94 | + |
| 95 | + result = await asyncio.to_thread(engine.forward_backward, typed_data, loss_fn, loss_config, m_id) |
| 96 | + result["type"] = "forward_backward" |
| 97 | + await store.set_future(req_id, result) |
| 98 | + |
| 99 | + case "optim_step": |
| 100 | + adam_params = r["adam_params"] |
| 101 | + result = await asyncio.to_thread(engine.optim_step, adam_params, m_id) |
| 102 | + result["type"] = "optim_step" |
| 103 | + await store.set_future(req_id, result) |
| 104 | + |
| 105 | + case "sample": |
| 106 | + prompt_tokens = r["prompt_tokens"] |
| 107 | + max_tokens = r["max_tokens"] |
| 108 | + num_samples = r["num_samples"] |
| 109 | + temperature = r.get("temperature", 0.0) |
| 110 | + |
| 111 | + result = await asyncio.to_thread( |
| 112 | + engine.generate, |
| 113 | + prompt_tokens, |
| 114 | + max_tokens, |
| 115 | + num_samples, |
| 116 | + temperature, |
| 117 | + m_id, |
| 118 | + ) |
| 119 | + result["type"] = "sample" |
| 120 | + await store.set_future(req_id, result) |
| 121 | + |
| 122 | + case "save_state": |
| 123 | + state_path = r["state_path"] |
| 124 | + include_optimizer = bool(r.get("include_optimizer", False)) |
| 125 | + kind = r.get("kind", "state") |
| 126 | + |
| 127 | + result = await asyncio.to_thread(engine.save_state, m_id, state_path, include_optimizer, kind) |
| 128 | + result["type"] = "save_state" |
| 129 | + await store.set_future(req_id, result) |
| 130 | + |
| 131 | + case "save_weights_for_sampler" | "save_weights": |
| 132 | + await asyncio.to_thread(engine.save_adapter, m_id) |
| 133 | + await store.set_future(req_id, {"status": "ok", "type": req_type}) |
| 134 | + |
| 135 | + case _: |
| 136 | + print(f"Warning: Unhandled request type: {req_type}") |
| 137 | + await store.set_future(req_id, {"type": "RequestFailedResponse", "error_message": f"Unknown request type: {req_type}"}) |
| 138 | + |
| 139 | + except Exception as e: |
| 140 | + traceback.print_exc() |
| 141 | + await store.set_future(req_id, {"type": "RequestFailedResponse", "error_message": str(e)}) |
| 142 | + finally: |
| 143 | + if token: |
| 144 | + otel_context.detach(token) |
| 145 | + |
| 146 | + except asyncio.CancelledError: |
| 147 | + break |
| 148 | + except Exception as e: |
| 149 | + print(f"Error in clock cycle loop: {e}") |
| 150 | + traceback.print_exc() |
| 151 | + |
| 152 | + import redis |
| 153 | + |
| 154 | + if isinstance(e, redis.exceptions.ConnectionError): |
| 155 | + print("[worker] Destroying StateStore singleton to force Redis reconnection...") |
| 156 | + from . import store as store_mod |
| 157 | + |
| 158 | + store_mod._store_instance = None |
| 159 | + store = store_mod.get_store() |
| 160 | + |
| 161 | + await asyncio.sleep(1) |
| 162 | + |
| 163 | + |
def main() -> None:
    """Entry point: start the health probe, optionally preload a model, run the worker loop.

    The probe server is started *before* the base model is loaded so that
    ``GET /healthz`` answers 503 ("Model Loading") during the load and switches
    to ``{"status": "ready"}`` once loading has finished or was skipped.
    The model to preload is read from ``OPEN_RL_BASE_MODEL``, falling back to
    ``VLLM_MODEL``; when neither is set the model is loaded lazily by the first
    request that needs it.
    """
    print("\n" + "=" * 50)
    print(" Open-RL PyTorch Training Worker")
    print("=" * 50)
    cuda_devs = os.getenv("CUDA_VISIBLE_DEVICES", "ALL")
    print(f"-> Hardware : CUDA_VISIBLE_DEVICES={cuda_devs}\n")

    # Event rather than a plain bool: it is set here and read from the probe server thread.
    ready = threading.Event()

    probe_app = FastAPI()

    @probe_app.get("/healthz")
    def healthz():
        if ready.is_set():
            return {"status": "ready"}
        raise HTTPException(status_code=503, detail="Model Loading")

    def run_probe_server():
        uvicorn.run(probe_app, host="0.0.0.0", port=8000, log_level="warning")

    # Daemon thread so the probe server never keeps the process alive on its own.
    threading.Thread(target=run_probe_server, daemon=True).start()

    preload_target = os.getenv("OPEN_RL_BASE_MODEL") or os.getenv("VLLM_MODEL")
    if preload_target:
        engine.load_base_model(preload_target)
    else:
        print("[WARNING] OPEN_RL_BASE_MODEL / VLLM_MODEL not provided. Cold-start penalty will apply on first request.")
    ready.set()

    asyncio.run(clock_cycle_loop())


if __name__ == "__main__":
    main()
0 commit comments