Skip to content

Commit caccc9e

Browse files
committed
Apply Tier 2 research findings to train_gpt.py
- Warmdown 1200 → 3500 (proven by both our research and openai#2 leaderboard entry) - Muon weight decay WD=0.04 (proven at both Tier 1 and Tier 2 scales) - Adam embedding weight decay WD=0.01 (proven to stack with Muon WD) - LeakyReLU(0.5) activation (used by openai#1 leaderboard entry) Made-with: Cursor
1 parent 48dd9ac commit caccc9e

1 file changed

Lines changed: 14 additions & 5 deletions

File tree

train_gpt.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class Hyperparameters:
5252

5353
# Training length.
5454
iterations = int(os.environ.get("ITERATIONS", 20000))
55-
warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200))
55+
warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500))
5656
warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
5757
train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
5858
train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
@@ -84,6 +84,8 @@ class Hyperparameters:
8484
beta1 = float(os.environ.get("BETA1", 0.9))
8585
beta2 = float(os.environ.get("BETA2", 0.95))
8686
adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
87+
muon_wd = float(os.environ.get("MUON_WD", 0.04))
88+
adam_embed_wd = float(os.environ.get("ADAM_EMBED_WD", 0.01))
8789
grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))
8890

8991
# -----------------------------
@@ -110,10 +112,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -
110112

111113

112114
class Muon(torch.optim.Optimizer):
113-
def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
115+
def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, wd: float = 0.0):
114116
super().__init__(
115117
params,
116-
dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
118+
dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd),
117119
)
118120

119121
@torch.no_grad()
@@ -135,6 +137,7 @@ def step(self, closure=None):
135137
momentum = group["momentum"]
136138
backend_steps = group["backend_steps"]
137139
nesterov = group["nesterov"]
140+
wd = group.get("wd", 0.0)
138141

139142
total_params = sum(int(p.numel()) for p in params)
140143
updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
@@ -162,6 +165,8 @@ def step(self, closure=None):
162165
curr = 0
163166
for p in params:
164167
g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
168+
if wd > 0:
169+
p.mul_(1.0 - lr * wd)
165170
p.add_(g, alpha=-lr)
166171
curr += p.numel()
167172

@@ -613,8 +618,8 @@ def __init__(self, dim: int, mlp_mult: int):
613618
self.proj._zero_init = True
614619

615620
def forward(self, x: Tensor) -> Tensor:
616-
x = torch.relu(self.fc(x))
617-
return self.proj(x.square())
621+
x = F.leaky_relu(self.fc(x), negative_slope=0.5)
622+
return self.proj(x * x)
618623

619624

620625
class Block(nn.Module):
@@ -873,6 +878,7 @@ def log0(msg: str, console: bool = True) -> None:
873878
lr=args.matrix_lr,
874879
momentum=args.muon_momentum,
875880
backend_steps=args.muon_backend_steps,
881+
wd=args.muon_wd,
876882
)
877883
for group in optimizer_muon.param_groups:
878884
group["base_lr"] = args.matrix_lr
@@ -1031,6 +1037,9 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
10311037
torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
10321038
for opt in optimizers:
10331039
opt.step()
1040+
if args.adam_embed_wd > 0:
1041+
with torch.no_grad():
1042+
base_model.tok_emb.weight.mul_(1.0 - token_lr * scale * args.adam_embed_wd)
10341043
zero_grad_all()
10351044

10361045
step += 1

0 commit comments

Comments
 (0)