forked from openai/parameter-golf
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodal_train.py
More file actions
85 lines (73 loc) · 2.17 KB
/
modal_train.py
File metadata and controls
85 lines (73 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# modal launcher for parameter-golf autoresearch.
#
# usage:
# modal run modal_train.py
#
# custom env vars (comma-separated KEY=VALUE pairs, applied to the training run's environment):
# modal run modal_train.py --env "ITERATIONS=5000,VAL_LOSS_EVERY=200"
import modal
# Modal app that the remote function and local entrypoint are registered on.
app = modal.App("parameter-golf")

# Python packages installed into the container image.
_PIP_PACKAGES = (
    "numpy",
    "tqdm",
    "torch==2.10",
    "huggingface-hub",
    "setuptools",
    "typing-extensions==4.15.0",
    "datasets",
    "tiktoken",
    "sentencepiece",
    "zstandard",
)

# Build-time commands: clone the upstream repo into /opt/parameter-golf, then
# run its data/cached_challenge_fineweb.py script (sp1024 variant, 80 train
# shards) so the data is part of the image.
_SETUP_COMMANDS = (
    "git clone https://github.com/openai/parameter-golf.git /opt/parameter-golf",
    "cd /opt/parameter-golf && python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80",
)

# Base image with deps + cached data. The local train_gpt.py is added as the
# final step, overwriting the cloned copy, so agent edits get picked up each run.
image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(*_PIP_PACKAGES)
    .apt_install("git")
    .run_commands(*_SETUP_COMMANDS)
    .add_local_file("train_gpt.py", "/opt/parameter-golf/train_gpt.py")
)
@app.function(
    image=image,
    gpu="H100:8",
    timeout=3600,
)
def train(env_overrides: dict[str, str] | None = None):
    """Run train_gpt.py under torchrun with 8 processes and return its exit code.

    Args:
        env_overrides: Optional environment variables applied on top of the
            launcher defaults (DATA_PATH, TOKENIZER_PATH, VOCAB_SIZE, RUN_ID)
            before torchrun is started. ``None`` or an empty dict keeps the
            defaults.

    Returns:
        The exit code of the torchrun process.
    """
    import os
    import subprocess

    # Best-effort flash-attn install at runtime. subprocess.run raises
    # TimeoutExpired once the 120s budget is exceeded; without the handler
    # below a slow install would abort the whole training call instead of
    # simply running without flash-attn.
    try:
        install = subprocess.run(
            ["pip", "install", "flash-attn", "--no-build-isolation", "-q"],
            capture_output=True,
            timeout=120,
        )
    except subprocess.TimeoutExpired:
        print("flash-attn install timed out after 120s; continuing without it")
    except OSError as err:
        # e.g. the pip executable could not be launched
        print(f"flash-attn install could not be started ({err}); continuing without it")
    else:
        # Output is captured, so surface a failure explicitly.
        if install.returncode != 0:
            print(
                f"flash-attn install failed with exit code {install.returncode}; "
                "continuing without it"
            )

    # Relative paths below are resolved from the repo checkout.
    os.chdir("/opt/parameter-golf")

    env = os.environ.copy()
    env.update({
        "DATA_PATH": "./data/datasets/fineweb10B_sp1024",
        "TOKENIZER_PATH": "./data/tokenizers/fineweb_1024_bpe.model",
        "VOCAB_SIZE": "1024",
        "RUN_ID": "modal_run",
    })
    # Caller-supplied values win over the defaults above.
    if env_overrides:
        env.update(env_overrides)

    result = subprocess.run(
        ["torchrun", "--standalone", "--nproc_per_node=8", "train_gpt.py"],
        env=env,
    )
    return result.returncode
def _parse_env_overrides(env: str) -> dict[str, str]:
    """Parse a comma-separated ``KEY=VALUE`` string into a dict.

    Whitespace around each entry and around the key is ignored, and empty
    entries (empty input, trailing or doubled commas) are skipped. Only the
    first ``=`` separates key from value, so values may contain ``=``; they
    cannot contain ``,``.

    Raises:
        ValueError: If an entry has no ``=`` or an empty key.
    """
    overrides: dict[str, str] = {}
    for raw_entry in env.split(","):
        entry = raw_entry.strip()
        if not entry:
            continue
        key, sep, value = entry.partition("=")
        key = key.strip()
        if not sep or not key:
            raise ValueError(
                f"invalid --env entry {entry!r}: expected KEY=VALUE"
            )
        overrides[key] = value
    return overrides


@app.local_entrypoint()
def main(
    env: str = "",
):
    """Launch the remote training run and print its exit code.

    Args:
        env: Comma-separated ``KEY=VALUE`` pairs forwarded to ``train`` as
            environment overrides, e.g. ``"ITERATIONS=5000,VAL_LOSS_EVERY=200"``.

    Raises:
        ValueError: If ``env`` contains an entry that is not ``KEY=VALUE``.
    """
    env_overrides = _parse_env_overrides(env)
    print("launching 8xh100 training...")
    # Pass None rather than an empty dict when there is nothing to override.
    rc = train.remote(env_overrides or None)
    print(f"training finished with exit code: {rc}")