@@ -52,7 +52,7 @@ class Hyperparameters:
5252
5353 # Training length.
5454 iterations = int (os .environ .get ("ITERATIONS" , 20000 ))
55- warmdown_iters = int (os .environ .get ("WARMDOWN_ITERS" , 1200 ))
55+ warmdown_iters = int (os .environ .get ("WARMDOWN_ITERS" , 3500 ))
5656 warmup_steps = int (os .environ .get ("WARMUP_STEPS" , 20 ))
5757 train_batch_tokens = int (os .environ .get ("TRAIN_BATCH_TOKENS" , 524_288 ))
5858 train_seq_len = int (os .environ .get ("TRAIN_SEQ_LEN" , 1024 ))
@@ -84,6 +84,8 @@ class Hyperparameters:
8484 beta1 = float (os .environ .get ("BETA1" , 0.9 ))
8585 beta2 = float (os .environ .get ("BETA2" , 0.95 ))
8686 adam_eps = float (os .environ .get ("ADAM_EPS" , 1e-8 ))
87+ muon_wd = float (os .environ .get ("MUON_WD" , 0.04 ))
88+ adam_embed_wd = float (os .environ .get ("ADAM_EMBED_WD" , 0.01 ))
8789 grad_clip_norm = float (os .environ .get ("GRAD_CLIP_NORM" , 0.0 ))
8890
8991# -----------------------------
@@ -110,10 +112,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -
110112
111113
112114class Muon (torch .optim .Optimizer ):
113- def __init__ (self , params , lr : float , momentum : float , backend_steps : int , nesterov : bool = True ):
115+ def __init__ (self , params , lr : float , momentum : float , backend_steps : int , nesterov : bool = True , wd : float = 0.0 ):
114116 super ().__init__ (
115117 params ,
116- dict (lr = lr , momentum = momentum , backend_steps = backend_steps , nesterov = nesterov ),
118+ dict (lr = lr , momentum = momentum , backend_steps = backend_steps , nesterov = nesterov , wd = wd ),
117119 )
118120
119121 @torch .no_grad ()
@@ -135,6 +137,7 @@ def step(self, closure=None):
135137 momentum = group ["momentum" ]
136138 backend_steps = group ["backend_steps" ]
137139 nesterov = group ["nesterov" ]
140+ wd = group .get ("wd" , 0.0 )
138141
139142 total_params = sum (int (p .numel ()) for p in params )
140143 updates_flat = torch .zeros (total_params , device = params [0 ].device , dtype = torch .bfloat16 )
@@ -162,6 +165,8 @@ def step(self, closure=None):
162165 curr = 0
163166 for p in params :
164167 g = updates_flat [curr : curr + p .numel ()].view_as (p ).to (dtype = p .dtype )
168+ if wd > 0 :
169+ p .mul_ (1.0 - lr * wd )
165170 p .add_ (g , alpha = - lr )
166171 curr += p .numel ()
167172
@@ -613,8 +618,8 @@ def __init__(self, dim: int, mlp_mult: int):
613618 self .proj ._zero_init = True
614619
615620 def forward (self , x : Tensor ) -> Tensor :
616- x = torch . relu (self .fc (x ))
617- return self .proj (x . square () )
621+ x = F . leaky_relu (self .fc (x ), negative_slope = 0.5 )
622+ return self .proj (x * x )
618623
619624
620625class Block (nn .Module ):
@@ -873,6 +878,7 @@ def log0(msg: str, console: bool = True) -> None:
873878 lr = args .matrix_lr ,
874879 momentum = args .muon_momentum ,
875880 backend_steps = args .muon_backend_steps ,
881+ wd = args .muon_wd ,
876882 )
877883 for group in optimizer_muon .param_groups :
878884 group ["base_lr" ] = args .matrix_lr
@@ -1031,6 +1037,9 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
10311037 torch .nn .utils .clip_grad_norm_ (base_model .parameters (), args .grad_clip_norm )
10321038 for opt in optimizers :
10331039 opt .step ()
1040+ if args .adam_embed_wd > 0 :
1041+ with torch .no_grad ():
1042+ base_model .tok_emb .weight .mul_ (1.0 - token_lr * scale * args .adam_embed_wd )
10341043 zero_grad_all ()
10351044
10361045 step += 1
0 commit comments