Skip to content

Commit 99d9970

Browse files
Tarkeshwar and claude committed
Weight decay on Muon (MUON_WEIGHT_DECAY=0.09 default in frontier run)
Frontier records (PR openai#1285 MuonEq-R + WD=0.090, PR openai#1218 WD=0.085) use AdamW-style decoupled weight decay on the Muon optimizer. Add the knob with default 0.0 (backward-compatible). Applied as p.data.mul_(1 - lr * wd) before the Muon matrix update.

The MuonEq-R (row-normalized) variant is not ported — it would need more line budget than we have on this branch. WD alone accounts for the majority of that record's improvement per the commit notes.

dev/run_frontier.sh sets MUON_WEIGHT_DECAY=0.09 by default.

Also inlined restore_low_dim_params_to_fp32 at its single call site to free lines for this change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0794f04 commit 99d9970

3 files changed

Lines changed: 23 additions & 30 deletions

File tree

dev/run_frontier.sh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ NUM_HEADS="${NUM_HEADS:-8}" \
2020
NUM_KV_HEADS="${NUM_KV_HEADS:-4}" \
2121
MLP_MULT="${MLP_MULT:-2}" \
2222
TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-1024}" \
23+
MUON_WEIGHT_DECAY="${MUON_WEIGHT_DECAY:-0.09}" \
2324
TIE_EMBEDDINGS="${TIE_EMBEDDINGS:-1}" \
2425
QK_GAIN_INIT="${QK_GAIN_INIT:-5.25}" \
2526
PARALLEL_RESIDUALS="${PARALLEL_RESIDUALS:-1}" \

records/track_10min_16mb/2026-04_tns15june_v1/train_gpt.py

Lines changed: 11 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,7 @@ class Hyperparameters:
123123
matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
124124
scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
125125
muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
126+
muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.0)) # frontier records use ~0.09
126127
muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
127128
muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
128129
muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
@@ -155,11 +156,8 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -
155156

156157

157158
class Muon(torch.optim.Optimizer):
158-
def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
159-
super().__init__(
160-
params,
161-
dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
162-
)
159+
def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
160+
super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay))
163161

164162
@torch.no_grad()
165163
def step(self, closure=None):
@@ -204,9 +202,12 @@ def step(self, closure=None):
204202
if distributed:
205203
dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
206204

205+
wd = group.get("weight_decay", 0.0)
207206
curr = 0
208207
for p in params:
209208
g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
209+
if wd != 0.0:
210+
p.data.mul_(1.0 - lr * wd)
210211
p.add_(g, alpha=-lr)
211212
curr += p.numel()
212213

@@ -763,14 +764,6 @@ def forward(self, x: Tensor) -> Tensor:
763764
return F.linear(x, w, bias)
764765

765766

766-
def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
767-
# Keep small/control parameters in fp32 even when the model body runs in bf16.
768-
with torch.no_grad():
769-
for name, param in module.named_parameters():
770-
if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
771-
param.data = param.data.float()
772-
773-
774767
class Rotary(nn.Module):
775768
# Caches cos/sin tables per sequence length on the current device.
776769
def __init__(self, dim: int, base: float = 10000.0):
@@ -1124,7 +1117,10 @@ def log0(msg: str, console: bool = True) -> None:
11241117
for module in base_model.modules():
11251118
if isinstance(module, CastedLinear):
11261119
module.float()
1127-
restore_low_dim_params_to_fp32(base_model)
1120+
with torch.no_grad():
1121+
for name, param in base_model.named_parameters():
1122+
if (param.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
1123+
param.data = param.data.float()
11281124
compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
11291125
model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
11301126

@@ -1157,7 +1153,7 @@ def log0(msg: str, console: bool = True) -> None:
11571153
token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
11581154
adam_kw = dict(betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
11591155
optimizer_tok = torch.optim.Adam([{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], **adam_kw)
1160-
optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps)
1156+
optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, weight_decay=args.muon_weight_decay)
11611157
for group in optimizer_muon.param_groups:
11621158
group["base_lr"] = args.matrix_lr
11631159
optimizer_scalar = torch.optim.Adam([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], **adam_kw)

train_gpt.py

Lines changed: 11 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,7 @@ class Hyperparameters:
123123
matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
124124
scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
125125
muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
126+
muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.0)) # frontier records use ~0.09
126127
muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
127128
muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
128129
muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
@@ -155,11 +156,8 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -
155156

156157

157158
class Muon(torch.optim.Optimizer):
158-
def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
159-
super().__init__(
160-
params,
161-
dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
162-
)
159+
def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
160+
super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay))
163161

164162
@torch.no_grad()
165163
def step(self, closure=None):
@@ -204,9 +202,12 @@ def step(self, closure=None):
204202
if distributed:
205203
dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
206204

205+
wd = group.get("weight_decay", 0.0)
207206
curr = 0
208207
for p in params:
209208
g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
209+
if wd != 0.0:
210+
p.data.mul_(1.0 - lr * wd)
210211
p.add_(g, alpha=-lr)
211212
curr += p.numel()
212213

@@ -763,14 +764,6 @@ def forward(self, x: Tensor) -> Tensor:
763764
return F.linear(x, w, bias)
764765

765766

766-
def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
767-
# Keep small/control parameters in fp32 even when the model body runs in bf16.
768-
with torch.no_grad():
769-
for name, param in module.named_parameters():
770-
if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
771-
param.data = param.data.float()
772-
773-
774767
class Rotary(nn.Module):
775768
# Caches cos/sin tables per sequence length on the current device.
776769
def __init__(self, dim: int, base: float = 10000.0):
@@ -1124,7 +1117,10 @@ def log0(msg: str, console: bool = True) -> None:
11241117
for module in base_model.modules():
11251118
if isinstance(module, CastedLinear):
11261119
module.float()
1127-
restore_low_dim_params_to_fp32(base_model)
1120+
with torch.no_grad():
1121+
for name, param in base_model.named_parameters():
1122+
if (param.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
1123+
param.data = param.data.float()
11281124
compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
11291125
model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
11301126

@@ -1157,7 +1153,7 @@ def log0(msg: str, console: bool = True) -> None:
11571153
token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
11581154
adam_kw = dict(betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
11591155
optimizer_tok = torch.optim.Adam([{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], **adam_kw)
1160-
optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps)
1156+
optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, weight_decay=args.muon_weight_decay)
11611157
for group in optimizer_muon.param_groups:
11621158
group["base_lr"] = args.matrix_lr
11631159
optimizer_scalar = torch.optim.Adam([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], **adam_kw)

0 commit comments

Comments
 (0)