feat: wire in FFT worker #114
Add OPEN_RL_WORKER_MODE=full as the gateway-selected path for launching one FFT training worker process per created model. The gateway owns the FFTWorkerLauncher lifecycle, starts a run-scoped clock_cycle process before enqueueing create_model/create_model_from_state, and shuts down launched workers with the FastAPI lifespan. The worker loop now chooses FFTTrainingWorker in full mode, drains only its model queue through Redis, saves full checkpoints for sampler exports, and keeps public Tinker SDK metadata LoRA-shaped until the client has native full fine-tuning support. Also moves create_model loading into the trainer workers and covers the LoRA/FFT create_model, checkpoint path, sampler-save, and optimizer behavior in tests.
| from tinker_cookbook.supervised.train import Config as TrainConfig | ||
| from tinker_cookbook.supervised.train import main as train | ||
| from tinker_cookbook.supervised.types import SupervisedDatasetBuilder | ||
| from tinker_cookbook.tokenizer_utils import get_tokenizer |
So nice to see FFT working with the tinker-cookbook :)
| param.requires_grad_(True) | ||
| self.trainable_params = trainable_model_parameters(self.model) | ||
| if ENABLE_GRADIENT_CHECKPOINTING: |
This is applicable to FFT as well.
| return {"model_id": model_id, "is_lora": False, "base_model": base_model} | ||
| # SDK compatibility: the public client currently expects LoRA-shaped training metadata, | ||
| # even when this worker loaded full fine-tuned weights. | ||
| return {"model_id": model_id, "is_lora": True, "lora_rank": 16, "base_model": base_model} |
Workers are backend processors so this code here feels like tight coupling between the API and the backend workers.
We should have an internal Result message between API <---> Worker and API should translate that appropriately to maintain compatibility with different types of clients (like tinker compatible in this case).
Also an opportunity to define typed messages between API <---> Worker because eventually they will communicate over the network.
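A minimal sketch of what such a typed internal message and an API-side translation could look like. Every name here (`TrainingMode`, `CreateModelResult`, `to_tinker_metadata`, `COMPAT_LORA_RANK`) is hypothetical and not part of this PR; only the shape of the returned dict and the rank of 16 come from the diff above.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class TrainingMode(str, Enum):
    LORA = "lora"
    FULL = "full"


@dataclass(frozen=True)
class CreateModelResult:
    """Hypothetical internal message a worker returns to the API layer."""

    model_id: str
    base_model: str
    mode: TrainingMode
    lora_rank: Optional[int] = None


# Placeholder rank reported to clients that only understand LoRA metadata.
# 16 mirrors the value hard-coded in the diff above.
COMPAT_LORA_RANK = 16


def to_tinker_metadata(result: CreateModelResult) -> dict:
    """Translate the internal result into the LoRA-shaped dict the SDK expects."""
    if result.mode is TrainingMode.LORA and result.lora_rank is not None:
        rank = result.lora_rank
    else:
        # Full fine-tuning has no rank, so report the compatibility placeholder.
        rank = COMPAT_LORA_RANK
    return {
        "model_id": result.model_id,
        "is_lora": True,
        "lora_rank": rank,
        "base_model": result.base_model,
    }
```

With this split, the worker would report `mode=TrainingMode.FULL` truthfully, and only the API layer would know that the Tinker-compatible client needs LoRA-shaped metadata.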
| def get_default_model_name() -> str | None: | ||
| if is_single_process_mode(): | ||
| if is_single_process_mode() and not is_full_worker_mode(): | ||
| import clock_cycle |
A conditional import is typically a code smell. We should have explicit initialization, and I think it's related to the worker being initialized on package import, which I saw earlier.
| env = {**os.environ, "OPEN_RL_WORKER_MODE": "full", "OPEN_RL_WORKER_MODEL_ID": model_id} | ||
| self.processes[model_id] = subprocess.Popen( | ||
| [sys.executable, "-m", "clock_cycle"], |
We will have to rename clock_cycle to something else; it's actually a request processor. But not in this PR.
| # Resolve relative names under TMP_DIR/checkpoints, leave absolute paths alone. | ||
| resolved_path = state_path if os.path.isabs(state_path) else os.path.join(TMP_DIR, "checkpoints", state_path) | ||
| model_id = str(uuid.uuid4()) | ||
| if is_full_worker_mode(): |
Still thinking... One improvement we can make to keep the architecture decoupled: launching the worker should be done in the request processing part instead of the API server. Let the API server only enqueue requests.
I haven't thought it through completely yet, so it's okay to do that as a follow-up.
Changed this so create_model no longer launches a worker directly from the API handler. In FFT mode the gateway writes the create-model payload to a separate worker-launch queue. WorkerLaunchProcessor drains that queue, asks FFTWorkerManager to start the per-model worker process, and only then enqueues the original request onto the normal per-model training queue.
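The ordering described in this reply can be sketched roughly as follows. This is an in-memory illustration only: the PR routes through Redis, and `LaunchQueueSketch`, `submit_create_model`, and `drain_once` are hypothetical names, not the actual `WorkerLaunchProcessor` or `FFTWorkerManager` API.

```python
from collections import defaultdict, deque
from typing import Callable, Deque, Dict


class LaunchQueueSketch:
    """In-memory stand-in for the launch queue and per-model training queues."""

    def __init__(self, start_worker: Callable[[str], None]) -> None:
        self.launch_queue: Deque[dict] = deque()
        self.training_queues: Dict[str, Deque[dict]] = defaultdict(deque)
        self._start_worker = start_worker

    def submit_create_model(self, payload: dict) -> None:
        """Gateway side: only enqueue the payload, never launch a process."""
        self.launch_queue.append(payload)

    def drain_once(self) -> int:
        """Processor side: start each worker, then forward the original request."""
        handled = 0
        while self.launch_queue:
            payload = self.launch_queue.popleft()
            model_id = payload["model_id"]
            # Start the per-model worker first ...
            self._start_worker(model_id)
            # ... and only then put the request on that model's training queue.
            self.training_queues[model_id].append(payload)
            handled += 1
        return handled
```

The point of the ordering is that a request never sits on a per-model training queue before something has been asked to start the worker that drains it.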
This PR builds on the worker split from the last PR and wires up the first usable full fine tuning mode. When OPEN_RL_WORKER_MODE=full is set, creating a model launches a dedicated FFT worker process for that job, and requests for that model are routed through Redis to that worker. LoRA keeps using the existing multi-tenant path. The worker launch logic lives in a small FFTWorkerLauncher class so the gateway owns the subprocess lifecycle in one place.
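For illustration, a launcher that owns the subprocess lifecycle in one place might look roughly like this. It is a sketch under stated assumptions, not the `FFTWorkerLauncher` from this PR: the class name, the injectable `command`, the reuse of an already-running worker, and the terminate-then-kill shutdown are all assumptions. Only the two environment variables and the `python -m clock_cycle` command are taken from the diff.

```python
import os
import subprocess
import sys
from typing import Dict, List, Optional


class WorkerLauncherSketch:
    """Hypothetical launcher: one worker subprocess per model_id."""

    def __init__(self, command: Optional[List[str]] = None) -> None:
        # The diff launches `python -m clock_cycle`; the command is injectable
        # here so the sketch can run without that module installed.
        self.command = command or [sys.executable, "-m", "clock_cycle"]
        self.processes: Dict[str, subprocess.Popen] = {}

    def launch(self, model_id: str) -> subprocess.Popen:
        existing = self.processes.get(model_id)
        if existing is not None and existing.poll() is None:
            return existing  # assumption: reuse a worker that is still running
        env = {
            **os.environ,
            "OPEN_RL_WORKER_MODE": "full",
            "OPEN_RL_WORKER_MODEL_ID": model_id,
        }
        proc = subprocess.Popen(self.command, env=env)
        self.processes[model_id] = proc
        return proc

    def shutdown(self, timeout: float = 5.0) -> None:
        # Ask every live worker to stop, then wait; kill any that ignore it.
        for proc in self.processes.values():
            if proc.poll() is None:
                proc.terminate()
        for proc in self.processes.values():
            try:
                proc.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        self.processes.clear()
```

In a setup like the one described here, `shutdown()` would be the call made from the FastAPI lifespan teardown so that no worker outlives the gateway.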
This also adds a GSM8K example so we have a concrete script for running and evaluating FFT. The baseline shows ~1-5% accuracy and the final model shows ~34-35%.
No snapshot integration yet.
Also, while writing this, I think we can start leveraging more of FastAPI's features for typing the shape of incoming requests. This seems like a good time to do it (not a high priority, though).