diff --git a/records/track_10min_16mb/2026-04-03_V2_SP4096_DepthRecur/submission.json b/records/track_10min_16mb/2026-04-03_V2_SP4096_DepthRecur/submission.json new file mode 100644 index 0000000000..3bc0023e8c --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_V2_SP4096_DepthRecur/submission.json @@ -0,0 +1,10 @@ +{ + "name": "SP4096 + Polar Express + MuonEq-R + Depth Recurrence + WD=0.105", + "val_bpb": 1.0923, + "bytes_total": 15694101, + "blurb": "On clarkkev PR #1218 SP4096 base: Polar Express NS 4-step, MuonEq-R, depth recurrence layers 3,4,5 (shared MLP weights), WD=0.105, MLR=0.022. 3-seed mean 1.0923 (1337=1.0927, 42=1.0917, 2025=1.0925). Clean — no SLOT, no TTT.", + "author": "Omri Gotlieb", + "github_id": "Omrigotlieb", + "date": "2026-04-04", + "seeds": {"1337": 1.0927, "42": 1.0917, "2025": 1.0925} +} diff --git a/records/track_10min_16mb/2026-04-03_V2_SP4096_DepthRecur/train_gpt.py b/records/track_10min_16mb/2026-04-03_V2_SP4096_DepthRecur/train_gpt.py new file mode 100644 index 0000000000..aef795b4ac --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_V2_SP4096_DepthRecur/train_gpt.py @@ -0,0 +1,1243 @@ +import copy,glob,io,lzma,math,os,random,subprocess,sys,time,uuid +from pathlib import Path +import numpy as np +import sentencepiece as spm +import torch,torch.distributed as dist,torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor, nn +from flash_attn_interface import flash_attn_func as flash_attn_3_func +_E=os.environ.get +class Hyperparameters(): + data_dir=_E('DATA_DIR','./data/') + seed=int(_E('SEED',1337)) + run_id=_E("RUN_ID",str(uuid.uuid4())) + iterations=int(_E('ITERATIONS',20000)) + warmdown_frac=float(_E('WARMDOWN_FRAC',0.667)) + warmup_steps=int(_E('WARMUP_STEPS',20)) + train_batch_tokens=int(_E('TRAIN_BATCH_TOKENS',2048*48*8)) + train_seq_len=int(_E('TRAIN_SEQ_LEN',2048)) + eval_seq_len=int(_E('EVAL_SEQ_LEN',2048)) + max_wallclock_seconds=float(_E('MAX_WALLCLOCK_SECONDS',600.0)) + train_log_every=int(_E('TRAIN_LOG_EVERY',500)) + val_batch_tokens=int(_E('VAL_BATCH_TOKENS',2048*32*8)) + val_loss_every=int(_E('VAL_LOSS_EVERY',4000)) + sliding_window_enabled=bool(int(_E('SLIDING_WINDOW_ENABLED','1'))) + vocab_size=int(_E('VOCAB_SIZE',4096)) + num_layers=int(_E('NUM_LAYERS',11)) + xsa_last_n=int(_E('XSA_LAST_N',11)) + num_kv_heads=int(_E('NUM_KV_HEADS',4)) + model_dim=int(_E('MODEL_DIM',512)) + embedding_dim=int(_E('EMBEDDING_DIM',512)) + num_heads=int(_E('NUM_HEADS',8)) + mlp_mult=float(_E('MLP_MULT',4.0)) + skip_gates_enabled=bool(int(_E('SKIP_GATES_ENABLED','1'))) + tie_embeddings=bool(int(_E('TIE_EMBEDDINGS','1'))) + logit_softcap=float(_E('LOGIT_SOFTCAP',30.0)) + rope_base=float(_E('ROPE_BASE',10000.0)) + rope_dims=int(_E('ROPE_DIMS',16)) + rope_train_seq_len=int(_E('ROPE_TRAIN_SEQ_LEN',2048)) + ln_scale=bool(int(_E('LN_SCALE','1'))) + ve_enabled=bool(int(_E('VE_ENABLED','1'))) + ve_dim=int(_E('VE_DIM',128)) + ve_layers=_E('VE_LAYERS','9,10') + qk_gain_init=float(_E('QK_GAIN_INIT',4.0)) + min_lr=float(_E('MIN_LR',0.0)) + embed_lr=float(_E('EMBED_LR',0.6)) + head_lr=float(_E('HEAD_LR',0.008)) + tied_embed_lr=float(_E('TIED_EMBED_LR',0.03)) + tied_embed_init_std=float(_E('TIED_EMBED_INIT_STD',0.005)) + matrix_lr=float(_E('MATRIX_LR',0.02)) + scalar_lr=float(_E('SCALAR_LR',0.02)) + muon_momentum=float(_E('MUON_MOMENTUM',0.99)) + muon_backend_steps=int(_E('MUON_BACKEND_STEPS',4)) + muon_momentum_warmup_start=float(_E('MUON_MOMENTUM_WARMUP_START',0.92)) + muon_momentum_warmup_steps=int(_E('MUON_MOMENTUM_WARMUP_STEPS',1500)) + beta1=float(_E('BETA1',0.9)) + beta2=float(_E('BETA2',0.95)) + adam_eps=float(_E('ADAM_EPS',1e-8)) + grad_clip_norm=float(_E('GRAD_CLIP_NORM',0.3)) + eval_stride=int(_E('EVAL_STRIDE',64)) + muon_beta2=float(_E('MUON_BETA2',0.95)) + adam_wd=float(_E('ADAM_WD',0.02)) + muon_wd=float(_E('MUON_WD',0.090)) + embed_wd=float(_E('EMBED_WD',0.090)) + ema_decay=float(_E('EMA_DECAY',0.997)) + recur_layers=_E('RECUR_LAYERS','') + recur_start_step=int(_E('RECUR_START_STEP',3000)) + slot_enabled=bool(int(_E('SLOT_ENABLED','0'))) + slot_steps=int(_E('SLOT_STEPS',8)) + slot_lr=float(_E('SLOT_LR',0.005)) + compressor=_E('COMPRESSOR','brotli') + gptq_enabled=bool(int(_E('GPTQ_ENABLED','1'))) + gptq_calibration_batches=int(_E('GPTQ_CALIBRATION_BATCHES',64)) + gptq_reserve_seconds=float(_E('GPTQ_RESERVE_SECONDS',10.0)) + distributed="RANK" in os.environ and "WORLD_SIZE" in os.environ + rank=int(_E("RANK","0")) + world_size=int(_E("WORLD_SIZE","1")) + local_rank=int(_E("LOCAL_RANK","0")) + is_main_process=rank==0 + grad_accum_steps=8//world_size + datasets_dir=os.path.join(data_dir,'datasets',f'fineweb10B_sp{vocab_size}') + train_files=os.path.join(datasets_dir,'fineweb_train_*.bin') + val_files=os.path.join(datasets_dir,'fineweb_val_*.bin') + tokenizer_path=os.path.join(data_dir,'tokenizers',f'fineweb_{vocab_size}_bpe.model') + logfile=f"logs/{run_id}.txt" + model_path="final_model.pt" + quantized_model_path="final_model.int6.ptz" +_logger_hparams=None +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams=h +def log(msg,console=True): + if _logger_hparams is None: + print(msg) + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile,"a",encoding="utf-8") as f: + print(msg,file=f) +class ValidationData: + def __init__(self,h,device): + if not h.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {h.tokenizer_path}") + self.sp=spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size())!=h.vocab_size: + raise ValueError(f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}") + self.val_tokens=_load_val_tokens(h.val_files,h.eval_seq_len) + self.base_bytes_lut,self.has_leading_space_lut,self.is_boundary_token_lut=_build_sp_luts(self.sp,h.vocab_size,device) +def _build_sp_luts(sp,vocab_size,device): + sv=int(sp.vocab_size()) + assert sp.piece_to_id("\u2581")!=sp.unk_id(),"Tokenizer must have \u2581 as its own token" + ts=max(sv,vocab_size) + bb=np.zeros((ts,),dtype=np.int16) + hl=np.zeros((ts,),dtype=np.bool_) + ib=np.ones((ts,),dtype=np.bool_) + for tid in range(sv): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): + continue + ib[tid]=False + if sp.is_byte(tid): + bb[tid]=1 + continue + piece=sp.id_to_piece(tid) + if piece.startswith("\u2581"): + hl[tid]=True + piece=piece[1:] + bb[tid]=len(piece.encode("utf-8")) + return (torch.tensor(bb,dtype=torch.int16,device=device), + torch.tensor(hl,dtype=torch.bool,device=device), + torch.tensor(ib,dtype=torch.bool,device=device)) +def _load_val_tokens(pattern,seq_len): + files=[Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens=torch.cat([_load_shard(f) for f in files]).contiguous() + usable=((tokens.numel()-1)//seq_len)*seq_len + if usable<=0: + raise ValueError(f"Validation split too short for seq_len={seq_len}") + return tokens[:usable+1] +def _load_shard(file): + hb=256*np.dtype("0 else 0 + bc=(nt-1-ph)//sl + self._cp[si]=ph + self._cbc[si]=bc + self._cn[si]=0 + self._cs[si]=int(self._rng.integers(bc)) if bc>1 else 0 + self._cst[si]=self._coprime(bc) + self._ci[si]=True + def _ensure(self,si,sl): + if not self._ci[si] or self._cn[si]>=self._cbc[si]: + self._reset(si,sl) + def _take(self,si,sl,count,out): + rem=count + while rem>0: + self._ensure(si,sl) + bc=int(self._cbc[si]) + ni=int(self._cn[si]) + take=min(rem,bc-ni) + ph=int(self._cp[si]) + st=int(self._cs[si]) + stride=int(self._cst[si]) + for j in range(take): + bi=(st+(ni+j)*stride)%bc + out.append((si,ph+bi*sl)) + self._cn[si]=ni+take + rem-=take + def _init_pipe(self,gt,sl,gas): + lt=gt//(self.world_size*gas) + ns=lt//sl + gns=ns*self.world_size + self._cfg=(lt,sl,ns,gns) + bbc=(self._nt-1)//sl + el=bbc>0 + self._esh=np.nonzero(el)[0].astype(np.int64) + self._bbc=bbc[self._esh].astype(np.int64) + def _sample_gw(self): + assert self._cfg is not None and self._esh is not None + _,sl,_,gns=self._cfg + ec=int(self._esh.size) + prog=min(self._bb/1800.0,1.0) + rem=np.empty(ec,dtype=np.float64) + for i,si in enumerate(self._esh.tolist()): + if self._ci[si]: + r=int(self._cbc[si])-int(self._cn[si]) + rem[i]=float(max(r,1)) + else: + rem[i]=float(self._bbc[i]) + alpha=0.90-0.40*prog + w=np.power(rem,alpha) + ws=float(w.sum()) + if not np.isfinite(ws) or ws<=0.0: + w=np.ones(ec,dtype=np.float64) + ws=float(w.sum()) + pr=w/ws + lo=min(max(8,self.world_size),ec,gns) + hi=min(max(32,self.world_size*8),ec,gns) + mix=max(1,min(int(round(lo+prog*(hi-lo))),ec,gns)) + cp=self._rng.choice(ec,size=mix,replace=False,p=pr) + cs=self._esh[cp] + cpr=pr[cp].copy() + cpr/=cpr.sum() + counts=np.ones(mix,dtype=np.int64) + extra=gns-mix + if extra>0: + counts+=self._rng.multinomial(extra,cpr).astype(np.int64) + perm=self._rng.permutation(mix) + cs,counts=cs[perm],counts[perm] + bkts:list[list[tuple[int,int]]]=[] + for si,cnt in zip(cs.tolist(),counts.tolist()): + b:list[tuple[int,int]]=[] + self._take(int(si),sl,int(cnt),b) + if b: + if len(b)>1: + bp=self._rng.permutation(len(b)) + b=[b[int(k)] for k in bp.tolist()] + bkts.append(b) + wins:list[tuple[int,int]]=[] + active=[i for i,bk in enumerate(bkts) if bk] + while active: + order=self._rng.permutation(len(active)) + na:list[int]=[] + for oi in order.tolist(): + bi=active[oi] + if bkts[bi]: + wins.append(bkts[bi].pop()) + if bkts[bi]: + na.append(bi) + active=na + return wins + def next_batch(self,gt,sl,gas): + if self._cfg is None: + self._init_pipe(gt,sl,gas) + _,_,ns,_=self._cfg + gw=self._sample_gw() + lw=gw[self.rank::self.world_size] + x=torch.empty((ns,sl),dtype=torch.int64) + y=torch.empty((ns,sl),dtype=torch.int64) + for slot,(si,pos) in enumerate(lw): + mm=_get_mm(self.files[si]) + win=torch.as_tensor(np.array(mm[pos:pos+sl+1],dtype=np.int64)) + x[slot]=win[:-1] + y[slot]=win[1:] + self._bb+=1 + return x.to(self.device,non_blocking=True),y.to(self.device,non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self,eps=None): + super().__init__() + self.eps=eps + def forward(self,x): + return F.rms_norm(x,(x.size(-1),),eps=self.eps) +class CastedLinear(nn.Linear): + def forward(self,x): + w=self.weight.to(x.dtype) + b=self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x,w,b) +class Rotary(nn.Module): + def __init__(self,dim,base=10000.0,train_seq_len=1024,rope_dims=0): + super().__init__() + self.dim=dim + self.base=base + self.train_seq_len=train_seq_len + self.rope_dims=rope_dims if rope_dims>0 else dim + inv_freq=1.0/(base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims)) + self.register_buffer("inv_freq",inv_freq,persistent=False) + self._sl=0 + self._cc=None + self._sc=None + def forward(self,seq_len,device,dtype): + if self._cc is None or self._sc is None or self._sl!=seq_len or self._cc.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len: + scale=seq_len/self.train_seq_len + nb=self.base*(scale**(rd/(rd-2))) + inv_freq=1.0/(nb**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd)) + else: + inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype) + freqs=torch.outer(t,inv_freq) + self._cc=freqs.cos()[None,:,None,:] + self._sc=freqs.sin()[None,:,None,:] + self._sl=seq_len + return self._cc.to(dtype=dtype),self._sc.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims