Skip to content

Commit 5f3957e

Browse files
authored
feat: wire in FFT worker (#114)
* feat: add full worker mode support Add OPEN_RL_WORKER_MODE=full as the gateway-selected path for launching one FFT training worker process per created model. The gateway owns the FFTWorkerLauncher lifecycle, starts a run-scoped clock_cycle process before enqueueing create_model/create_model_from_state, and shuts down launched workers with the FastAPI lifespan. The worker loop now chooses FFTTrainingWorker in full mode, drains only its model queue through Redis, saves full checkpoints for sampler exports, and keeps public Tinker SDK metadata LoRA-shaped until the client has native full fine-tuning support. Also moves create_model loading into the trainer workers and covers the LoRA/FFT create_model, checkpoint path, sampler-save, and optimizer behavior in tests. * docs: add gsm8k fft example * fix: decouple fft worker launch
1 parent fed9fd7 commit 5f3957e

12 files changed

Lines changed: 1121 additions & 210 deletions

examples/sft/gsm8k/README.md

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# GSM8K full fine-tuning
2+
3+
Full-parameter SFT of a small model on GSM8K, driven through the OpenRL gateway
4+
with the Tinker SDK.
5+
6+
## Why full fine-tuning goes through dedicated workers
7+
8+
The public Tinker SDK entrypoint is still `create_lora_training_client()`. For
9+
now, OpenRL routes that same client flow to a full fine-tuning worker when the
10+
gateway is started with `OPEN_RL_ENABLE_FFT=true`.
11+
12+
## Run
13+
14+
This branch launches one full fine-tuning worker process per created model. That
15+
worker shares requests and futures with the gateway through Redis.
16+
17+
Start from the repository root in separate terminals.
18+
19+
### Terminal 1: Redis
20+
21+
```bash
22+
redis-server --port 6379 --save "" --appendonly no
23+
```
24+
25+
### Terminal 2: Gateway
26+
27+
```bash
28+
cd src/server
29+
REDIS_URL=redis://127.0.0.1:6379 \
30+
OPEN_RL_ENABLE_FFT=true \
31+
BASE_MODEL=Qwen/Qwen2.5-0.5B \
32+
SAMPLING_BACKEND=torch \
33+
uv run --extra gpu python -m uvicorn gateway:app --host 127.0.0.1 --port 9003
34+
```
35+
36+
### Terminal 3: SFT Job
37+
38+
```bash
39+
uv --project examples run python examples/sft/gsm8k/gsm8k_sft.py \
40+
--log-path=examples/sft/gsm8k/artifacts/job_a \
41+
--max-steps=20 \
42+
--base-model=Qwen/Qwen2.5-0.5B
43+
```
44+
45+
Training uses `tinker_cookbook.supervised.train`, so batching, LR scheduling,
46+
metric logging, and final checkpoint export are handled by the cookbook loop.
47+
The example deletes an existing log directory by default so stale checkpoint
48+
metadata does not trigger resume. The training script prints
49+
`eval_model_path=...` when it can resolve a final checkpoint path.
50+
51+
## Eval
52+
53+
Eval is decoupled from OpenRL. Point vLLM at the saved Hugging Face checkpoint; `--data` must be a JSON list of objects with `prompt` and `gold` fields:
54+
55+
```bash
56+
python examples/sft/gsm8k/vllm_eval.py \
57+
--path <eval_model_path> \
58+
--data gsm8k_test.json
59+
```
60+
61+
## Result
62+
63+
Single-job result from the original FFT prototype run:
64+
65+
| Setup | GSM8K |
66+
| --- | --- |
67+
| Qwen2.5-0.5B base, 0-shot exact match on 250 examples | ~1.5% |
68+
| Qwen2.5-0.5B after full-FT SFT, 1 epoch, lr 2e-5 | ~36% |
69+
70+
Files:
71+
72+
- `gsm8k_sft.py`: training via the OpenRL/Tinker server.
73+
- `vllm_eval.py`: fast eval of the saved checkpoint directory.

examples/sft/gsm8k/gsm8k_sft.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import asyncio
2+
import os
3+
from pathlib import Path
4+
from typing import Any, cast
5+
6+
import chz
7+
import tinker
8+
from datasets import load_dataset
9+
from tinker import types
10+
from tinker_cookbook import checkpoint_utils, cli_utils
11+
from tinker_cookbook.supervised.data import SupervisedDatasetFromHFDataset
12+
from tinker_cookbook.supervised.train import Config as TrainConfig
13+
from tinker_cookbook.supervised.train import main as train
14+
from tinker_cookbook.supervised.types import SupervisedDatasetBuilder
15+
from tinker_cookbook.tokenizer_utils import get_tokenizer
16+
17+
# Install a placeholder API key only when the caller has not set one; a real
# TINKER_API_KEY in the environment is left untouched.
# NOTE(review): presumably the Tinker SDK requires the variable to exist even
# when talking to a local OpenRL gateway -- confirm.
os.environ.setdefault("TINKER_API_KEY", "tml-dummy-key")
18+
19+
20+
@chz.chz
class GSM8KDataset(SupervisedDatasetBuilder):
    """Dataset builder that turns the GSM8K train split into supervised datums."""

    model_name: str
    batch_size: int = 16
    max_length: int = 640
    seed: int = 0

    def __call__(self):
        """Load, shuffle and wrap GSM8K for supervised training.

        Each row becomes a ``Question: ...\\nAnswer:`` prompt followed by the
        answer and an EOS token. Prompt positions get weight 0 and completion
        positions weight 1, and both sequences are cut to ``max_length``
        before being shifted into next-token inputs/targets.

        Returns:
            A ``(dataset, None)`` tuple; the second slot is always ``None``
            here (presumably the optional eval dataset -- confirm against
            ``SupervisedDatasetBuilder``).
        """
        tok = get_tokenizer(self.model_name)
        limit = self.max_length
        train_split = load_dataset("openai/gsm8k", "main", split="train")
        shuffled = train_split.shuffle(seed=self.seed)

        def to_datum(example: dict) -> tinker.Datum:
            question_ids = tok.encode(f"Question: {example['question']}\nAnswer:", add_special_tokens=False)
            answer_ids = tok.encode(" " + example["answer"].strip(), add_special_tokens=False)
            answer_ids = answer_ids + [tok.eos_token_id]

            # Truncate tokens and loss mask identically so they stay aligned.
            ids = (question_ids + answer_ids)[:limit]
            mask = ([0.0] * len(question_ids) + [1.0] * len(answer_ids))[:limit]

            # Shift by one: the model sees ids[:-1] and is trained to predict ids[1:].
            inputs = types.ModelInput.from_ints(tokens=ids[:-1])
            targets = {"target_tokens": ids[1:], "weights": mask[1:]}
            return types.Datum(model_input=inputs, loss_fn_inputs=cast(Any, targets))

        train_dataset = SupervisedDatasetFromHFDataset(shuffled, self.batch_size, map_fn=to_datum)
        return train_dataset, None
42+
43+
44+
@chz.chz
class Config:
    """Command-line configuration for the GSM8K SFT example.

    ``main`` forwards these values to ``GSM8KDataset`` and to the cookbook
    ``TrainConfig``.
    """

    # Used both as the tokenizer name for the dataset and as TrainConfig.model_name.
    base_model: str = "Qwen/Qwen2.5-0.5B"
    # Server URL. The default is computed once when the class is defined:
    # TINKER_BASE_URL, then BASE_URL, then the local gateway address.
    base_url: str = os.getenv("TINKER_BASE_URL", os.getenv("BASE_URL", "http://127.0.0.1:9003"))
    # Log directory; defaults to artifacts/gsm8k_sft next to this script.
    log_path: str = str(Path(__file__).with_name("artifacts") / "gsm8k_sft")
    # Passed to TrainConfig as num_epochs.
    epochs: int = 1
    # Passed to GSM8KDataset as batch_size.
    batch: int = 16
    # Passed to TrainConfig as learning_rate.
    lr: float = 2e-5
    # Passed to TrainConfig as lora_rank.
    rank: int = 32
    # Passed to GSM8KDataset as max_length (token-count cap per example).
    max_len: int = 640
    # Shuffle seed for the dataset.
    seed: int = 0
    # Passed to TrainConfig as max_steps; None leaves it unset.
    max_steps: int | None = None
    # Passed to TrainConfig as save_every.
    # NOTE(review): 0 presumably disables periodic checkpoints -- confirm.
    save_every: int = 0
    # Handed to cli_utils.check_log_dir when log_path already exists.
    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "delete"
58+
59+
60+
def main(config: Config) -> None:
    """Run GSM8K supervised training and print the resulting checkpoint path.

    Steps:
      1. Apply the configured log-directory policy via ``cli_utils.check_log_dir``.
      2. Run the cookbook training loop to completion.
      3. Look up the last checkpoint that has a ``sampler_path``, falling back
         to one with a ``state_path``, and print ``eval_model_path=<path>``.
         ``tinker://`` paths are rewritten to a local directory under
         ``$OPEN_RL_TMP_DIR/sampler_full`` (default root ``/tmp/open-rl``).

    Nothing is printed when no checkpoint path can be resolved.
    """
    cli_utils.check_log_dir(config.log_path, behavior_if_exists=config.behavior_if_log_dir_exists)

    dataset_builder = GSM8KDataset(
        model_name=config.base_model,
        batch_size=config.batch,
        max_length=config.max_len,
        seed=config.seed,
    )
    train_config = TrainConfig(
        log_path=config.log_path,
        model_name=config.base_model,
        dataset_builder=dataset_builder,
        learning_rate=config.lr,
        lr_schedule="cosine",
        num_epochs=config.epochs,
        lora_rank=config.rank,
        base_url=config.base_url,
        save_every=config.save_every,
        eval_every=0,
        infrequent_eval_every=0,
        max_steps=config.max_steps,
    )
    asyncio.run(train(train_config))

    # Prefer a sampler export; fall back to a training-state checkpoint.
    record = checkpoint_utils.get_last_checkpoint(config.log_path, required_key="sampler_path")
    if record is None:
        record = checkpoint_utils.get_last_checkpoint(config.log_path, required_key="state_path")
    if record is None:
        return

    resolved = record.sampler_path or record.state_path
    if not resolved:
        return

    if resolved.startswith("tinker://"):
        # Map the tinker:// URI onto the local directory tree.
        tmp_root = Path(os.getenv("OPEN_RL_TMP_DIR", "/tmp/open-rl"))
        resolved = str(tmp_root / "sampler_full" / resolved.removeprefix("tinker://"))
    print(f"eval_model_path={resolved}")
94+
95+
96+
if __name__ == "__main__":
    # Build a Config from the command line and hand it to main().
    # NOTE(review): allow_hyphens=True presumably lets hyphenated flags such as
    # --max-steps map onto fields like max_steps (the README invokes the
    # script that way) -- confirm against the chz documentation.
    chz.nested_entrypoint(main, allow_hyphens=True)

examples/sft/gsm8k/vllm_eval.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import argparse
2+
import json
3+
import re
4+
import time
5+
6+
ANS_RE = re.compile(r"-?\d[\d,]*")
7+
8+
9+
def extract(text: str) -> str | None:
10+
text = re.split(r"\n\s*Question:", text)[0]
11+
if "####" in text:
12+
match = ANS_RE.search(text.split("####")[-1])
13+
if match:
14+
return match.group(0).replace(",", "")
15+
numbers = ANS_RE.findall(text)
16+
return numbers[-1].replace(",", "") if numbers else None
17+
18+
19+
def main() -> None:
    """Evaluate a saved checkpoint on GSM8K-style problems with vLLM.

    Command-line arguments:
        --path: model/checkpoint location passed to ``vllm.LLM`` (required).
        --data: JSON file holding a list of objects with ``prompt`` and
            ``gold`` keys (default ``gsm8k_test.json``).

    Generates greedily (temperature 0, up to 256 tokens, stopping at
    ``"\\nQuestion:"``), extracts the answer from each completion with
    ``extract`` and prints the exact-match accuracy and generation time.

    Raises:
        SystemExit: if the arguments are invalid or the data file is empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", required=True)
    parser.add_argument("--data", default="gsm8k_test.json")
    args = parser.parse_args()

    # Explicit encoding so the result does not depend on the platform locale.
    with open(args.data, encoding="utf-8") as f:
        data = json.load(f)
    if not data:
        # Fail before loading the model; otherwise the accuracy division
        # below would raise ZeroDivisionError after an expensive start-up.
        parser.error(f"{args.data} contains no problems to evaluate")

    # Imported lazily so the module (and argument validation) works without
    # vLLM installed.
    from vllm import LLM, SamplingParams

    llm = LLM(model=args.path, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=1024, enforce_eager=True)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=256, stop=["\nQuestion:"])

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    start = time.perf_counter()
    outputs = llm.generate([datum["prompt"] for datum in data], sampling_params)
    elapsed = time.perf_counter() - start

    correct = sum(
        int(extract(output.outputs[0].text) == datum["gold"])
        for datum, output in zip(data, outputs, strict=True)
    )

    print("***************************************************************")
    print(f"[VLLM] {args.path} 0-shot GSM8K acc = {correct / len(data):.1%} on {len(data)} problems in {elapsed:.1f}s")
    print("***************************************************************")
40+
41+
42+
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, so the module can be
    # imported (e.g. to reuse extract) without side effects.
    main()

0 commit comments

Comments
 (0)