From 2e1278eb8334283650aca26d460991fd8c87640a Mon Sep 17 00:00:00 2001 From: SPThole Date: Tue, 24 Mar 2026 18:09:48 +0530 Subject: [PATCH 01/10] updated sub --- .../.DS_Store | Bin 0 -> 8196 bytes .../README.md | 93 + .../logs/.DS_Store | Bin 0 -> 6148 bytes ...-cyclic-relusq-11Lshared_8xH100_seed42.txt | 1520 ++++++++++++++++ ...-cyclic-relusq-11Lshared_8xH100_seed43.txt | 1522 +++++++++++++++++ ...-cyclic-relusq-11Lshared_8xH100_seed44.txt | 1521 ++++++++++++++++ .../2025-03-24_AWQ_CyclMom_11L_shared/run.sh | 79 + .../setup_and_run.sh | 61 + .../submission.json | 11 + .../train_gpt.py | 1340 +++++++++++++++ 10 files changed, 6147 insertions(+) create mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/.DS_Store create mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/README.md create mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/.DS_Store create mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed42.txt create mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed43.txt create mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed44.txt create mode 100755 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/run.sh create mode 100755 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/setup_and_run.sh create mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/submission.json create mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/train_gpt.py diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/.DS_Store b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..eac9a8f0ea0e0cca587d8e90d848a98e92deabca GIT 
binary patch literal 8196 zcmeHMO-~a+7=EW#+zQAdAEMD@V`E|hB8nj)##jqR(HICN3Vzky?obw&nPzuOMMBcV znGn~3jCJ}i06ZYMPN>2U7>t*U?q4Qw5J#p1$tg6oQNSp$Qvq>yPeTYC$`Efm zzvr=U_@^{UZ032jxT4Sb@4QMF#L39WXOh~N-rsh>I$*V14+=}XR!H2W8t2{Uo@mW+ z9(oPm^==2WRkpJK}WI|1`~EzhgWk0rODI;@P9DcS?;>*vPK4>_ab zV;e)x`uO>gA?L#A*v3Z2>dT(JI9q%it}y;wEtTlU1$3%*vwneIb3H+6Rj-G9YOLq0 zZ8uhwIoRIOd8o_oKHSsY+uhgK*MDT-=&_^6i*}FamL8WQaitSHD>&-v93jwvU}wy@TV6SW17sj_zOk5D=tMm3}<*8 zBmrj&_hS@!DxoV=Y?0%Pt0?evMRditX%r*w;LJY=U5|#-EYFjKuPnGxN$FBmK-Aqjf&%i8P#aSoe7UD3hfWPhI{~ohOT0Z`55J43f=EktJHJ6;i z+%qk?Hs(-BFuwvtaFOLB?#fT~Aiwpi)`k2kcz9I-dNI~q&b*91dRu2tJ=@0az?mbQ zH9$@gzt|OQ4`aQym4|u*yl%N_Y_^(^$-yEpr?IY3ghX8K!iZ-77lIfw3hW^T_Ni+f z5a<7^yZ`>bhgM`R)hJ*T_)`T)da^JnUiRg#{u{Q~H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0 Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) 
for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, 
dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = 
val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for 
name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# 
----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + 
def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def 
forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + 
self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + unique_layers: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them + n_unique = unique_layers if unique_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(n_unique) + ] + ) + # Mapping from virtual layer index → physical 
block index + # Last virtual layer shares with the last physical block + self._layer_map = list(range(min(n_unique, num_layers))) + while len(self._layer_map) < num_layers: + self._layer_map.append(n_unique - 1) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self._layer_map) # virtual depth, not physical blocks + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / 
self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = 
torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + 
+ code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + 
np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = 
list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in 
base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, 
grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = 
torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + 
swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # 
AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if 
module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = 
inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 24 11:24:37 2026 
++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 42C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 41C P0 130W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 42C P0 124W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 34C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 41C P0 122W / 700W | 1521MiB / 
81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25517137 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9323 val_bpb:4.1057 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9334 train_time:148ms step_avg:148.19ms +step:2/20000 train_loss:8.7695 train_time:229ms step_avg:114.65ms 
+step:3/20000 train_loss:8.0097 train_time:326ms step_avg:108.71ms +step:4/20000 train_loss:7.2373 train_time:423ms step_avg:105.66ms +step:5/20000 train_loss:7.1032 train_time:521ms step_avg:104.13ms +step:6/20000 train_loss:6.9653 train_time:620ms step_avg:103.34ms +step:7/20000 train_loss:6.8161 train_time:718ms step_avg:102.55ms +step:8/20000 train_loss:6.8162 train_time:816ms step_avg:101.98ms +step:9/20000 train_loss:6.5146 train_time:913ms step_avg:101.41ms +step:10/20000 train_loss:6.1109 train_time:1011ms step_avg:101.06ms +step:100/20000 train_loss:3.1632 train_time:9982ms step_avg:99.82ms +step:200/20000 train_loss:2.3797 train_time:19987ms step_avg:99.94ms +step:300/20000 train_loss:2.5390 train_time:30012ms step_avg:100.04ms +step:400/20000 train_loss:2.4119 train_time:40058ms step_avg:100.15ms +step:500/20000 train_loss:2.3950 train_time:50068ms step_avg:100.14ms +step:500/20000 val_loss:2.3560 val_bpb:1.3954 train_time:50096ms step_avg:100.19ms +step:600/20000 train_loss:2.3334 train_time:60135ms step_avg:100.23ms +step:700/20000 train_loss:2.3465 train_time:70200ms step_avg:100.29ms +step:800/20000 train_loss:2.2394 train_time:80264ms step_avg:100.33ms +step:900/20000 train_loss:2.1327 train_time:90309ms step_avg:100.34ms +step:1000/20000 train_loss:2.2778 train_time:100312ms step_avg:100.31ms +step:1000/20000 val_loss:2.2303 val_bpb:1.3209 train_time:100339ms step_avg:100.34ms +step:1100/20000 train_loss:2.3252 train_time:110366ms step_avg:100.33ms +step:1200/20000 train_loss:2.3585 train_time:120420ms step_avg:100.35ms +step:1300/20000 train_loss:2.1034 train_time:130470ms step_avg:100.36ms +step:1400/20000 train_loss:2.1910 train_time:140517ms step_avg:100.37ms +step:1500/20000 train_loss:2.2249 train_time:150502ms step_avg:100.33ms +step:1500/20000 val_loss:2.1887 val_bpb:1.2962 train_time:150528ms step_avg:100.35ms +step:1600/20000 train_loss:2.0341 train_time:160544ms step_avg:100.34ms +step:1700/20000 train_loss:2.1021 train_time:170584ms 
step_avg:100.34ms +step:1800/20000 train_loss:2.1251 train_time:180628ms step_avg:100.35ms +step:1900/20000 train_loss:2.1003 train_time:190601ms step_avg:100.32ms +step:2000/20000 train_loss:2.0483 train_time:200615ms step_avg:100.31ms +step:2000/20000 val_loss:2.1119 val_bpb:1.2508 train_time:200641ms step_avg:100.32ms +step:2100/20000 train_loss:2.0367 train_time:210626ms step_avg:100.30ms +step:2200/20000 train_loss:2.1896 train_time:220638ms step_avg:100.29ms +step:2300/20000 train_loss:2.1112 train_time:230668ms step_avg:100.29ms +step:2400/20000 train_loss:2.0746 train_time:240605ms step_avg:100.25ms +step:2500/20000 train_loss:2.1813 train_time:250618ms step_avg:100.25ms +step:2500/20000 val_loss:2.1205 val_bpb:1.2559 train_time:250645ms step_avg:100.26ms +step:2600/20000 train_loss:2.1164 train_time:260617ms step_avg:100.24ms +step:2700/20000 train_loss:2.1148 train_time:270603ms step_avg:100.22ms +step:2800/20000 train_loss:2.1692 train_time:280673ms step_avg:100.24ms +step:2900/20000 train_loss:2.0438 train_time:290623ms step_avg:100.21ms +step:3000/20000 train_loss:2.1738 train_time:300599ms step_avg:100.20ms +step:3000/20000 val_loss:2.1053 val_bpb:1.2469 train_time:300625ms step_avg:100.21ms +step:3100/20000 train_loss:2.0536 train_time:310605ms step_avg:100.20ms +step:3200/20000 train_loss:2.1821 train_time:320597ms step_avg:100.19ms +step:3300/20000 train_loss:2.0807 train_time:330522ms step_avg:100.16ms +step:3400/20000 train_loss:2.0249 train_time:340504ms step_avg:100.15ms +step:3500/20000 train_loss:2.1864 train_time:350488ms step_avg:100.14ms +step:3500/20000 val_loss:2.0867 val_bpb:1.2358 train_time:350515ms step_avg:100.15ms +step:3600/20000 train_loss:2.0971 train_time:360481ms step_avg:100.13ms +step:3700/20000 train_loss:2.0972 train_time:370467ms step_avg:100.13ms +step:3800/20000 train_loss:2.0745 train_time:380402ms step_avg:100.11ms +step:3900/20000 train_loss:2.0771 train_time:390409ms step_avg:100.10ms +step:4000/20000 
train_loss:1.9767 train_time:400381ms step_avg:100.10ms +step:4000/20000 val_loss:2.0649 val_bpb:1.2230 train_time:400407ms step_avg:100.10ms +step:4100/20000 train_loss:2.0111 train_time:410379ms step_avg:100.09ms +step:4200/20000 train_loss:2.1512 train_time:420355ms step_avg:100.08ms +step:4300/20000 train_loss:2.0497 train_time:430284ms step_avg:100.07ms +step:4400/20000 train_loss:2.0278 train_time:440260ms step_avg:100.06ms +step:4500/20000 train_loss:2.1155 train_time:450253ms step_avg:100.06ms +step:4500/20000 val_loss:2.0376 val_bpb:1.2068 train_time:450280ms step_avg:100.06ms +step:4600/20000 train_loss:1.8352 train_time:460233ms step_avg:100.05ms +step:4700/20000 train_loss:2.2260 train_time:470153ms step_avg:100.03ms +step:4800/20000 train_loss:2.4265 train_time:480133ms step_avg:100.03ms +step:4900/20000 train_loss:2.0406 train_time:490112ms step_avg:100.02ms +step:5000/20000 train_loss:2.0948 train_time:500100ms step_avg:100.02ms +step:5000/20000 val_loss:2.0129 val_bpb:1.1922 train_time:500126ms step_avg:100.03ms +step:5100/20000 train_loss:2.1123 train_time:510082ms step_avg:100.02ms +step:5200/20000 train_loss:2.0305 train_time:520004ms step_avg:100.00ms +step:5300/20000 train_loss:1.9923 train_time:529971ms step_avg:99.99ms +swa:start step:5350 +step:5400/20000 train_loss:2.0323 train_time:540022ms step_avg:100.00ms +step:5500/20000 train_loss:2.0009 train_time:550034ms step_avg:100.01ms +step:5500/20000 val_loss:1.9842 val_bpb:1.1752 train_time:550089ms step_avg:100.02ms +step:5600/20000 train_loss:1.9365 train_time:560089ms step_avg:100.02ms +step:5700/20000 train_loss:1.9912 train_time:570062ms step_avg:100.01ms +step:5800/20000 train_loss:1.9716 train_time:580099ms step_avg:100.02ms +step:5900/20000 train_loss:1.8795 train_time:590129ms step_avg:100.02ms +step:5999/20000 val_loss:1.9594 val_bpb:1.1605 train_time:600082ms step_avg:100.03ms +stopping_early: wallclock_cap train_time:600082ms step:5999/20000 +peak memory allocated: 20841 MiB 
reserved: 21060 MiB +swa:applying averaged 13 checkpoints +Serialized model: 98437419 bytes +Code size: 58616 bytes +Total submission size: 98496035 bytes +awq:calibrating alpha=0.5 +awq:scaled 61 layers +Serialized model int6+zstd: 15394174 bytes +Total submission size int8+zlib: 15452790 bytes +awq:unscaled 61 layers after dequant +final_eval_mode:sliding_window stride:64 batch_seqs:64 +final_int8_zlib_roundtrip val_loss:1.9420 val_bpb:1.1502 eval_time:180059ms +final_int8_zlib_roundtrip_exact val_loss:1.94198177 val_bpb:1.15015403 diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed43.txt b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed43.txt new file mode 100644 index 0000000000..b52ee75a72 --- /dev/null +++ b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed43.txt @@ -0,0 +1,1522 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = 
int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + +# ----------------------------- 
# MUON OPTIMIZER
# -----------------------------

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximately orthogonalize G (the "zeroth matrix power") via a quintic
    Newton-Schulz iteration run in bfloat16.

    The coefficients (a, b, c) define a fast, non-monotone fixed-point
    iteration that converges to an orthogonal matrix sharing G's row/column
    spaces. Used by Muon to orthogonalize momentum-smoothed gradients.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps  # normalize so the iteration starts inside its basin of convergence
    transposed = G.size(0) > G.size(1)
    if transposed:
        # Iterate on the wide orientation so A = X @ X.T is the smaller Gram matrix.
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum SGD whose update direction is orthogonalized with
    Newton-Schulz ("Muon").

    Work is sharded across ranks: rank r orthogonalizes every
    world_size-th parameter; a SUM all-reduce over a shared flat buffer
    reassembles the full update on every rank.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay),
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns closure() result if given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # Flat bf16 buffer covering every param in the group; each rank
            # fills only its shard, zeros elsewhere, so the SUM all-reduce
            # assembles the complete update.
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Scale tall matrices so update RMS is shape-independent.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # NOTE(review): the offset must advance for EVERY param (not only
                # owned ones) so shard offsets line up with the unpacking loop
                # below — confirm against the original file's indentation.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            wd = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                if wd > 0:
                    p.data.mul_(1.0 - lr * wd)  # decoupled (AdamW-style) weight decay
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION
# -----------------------------

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables for bytes-per-token accounting.

    Returns (base_bytes, has_leading_space, is_boundary_token) tensors on
    *device*, each indexed by token id. base_bytes is the UTF-8 byte count
    of the piece (1 for byte-fallback tokens); has_leading_space marks
    pieces starting with U+2581 (the SentencePiece word-boundary marker);
    is_boundary_token marks control/unknown/unused ids.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("\u2581"):
            # Leading-space byte is only counted when the previous token is
            # not a boundary token (handled at eval time via these two LUTs).
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching *pattern* and trim to a
    whole number of seq_len sequences plus one extra token for targets."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Evaluate on non-overlapping seq_len windows of *val_tokens*.

    Windows are split contiguously across ranks; per-rank sums of loss,
    token count, and UTF-8 byte count are all-reduced. Returns
    (mean val loss in nats/token, bits-per-byte).
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Contiguous per-rank slice of the validation sequences.
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1  # +1 for the shifted targets
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting: base bytes per target, plus one byte when the
            # target carries a leading space and the previous token is not a
            # boundary token (see build_sentencepiece_luts).
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


# -----------------------------
# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed)
# -----------------------------

# Small "control" tensors (per-channel scales, gates, mixes) kept in float.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale",
    ).split(",")
    if pattern
)
# Tensors kept in fp16 rather than quantized (quantization-sensitive).
FP16_KEEP_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
# Symmetric clipping percentile for int8 quantization (per row for 2-D).
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Storage size of *t* in bytes (numel * element_size)."""
    return int(t.numel()) * int(t.element_size())

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Symmetric int8 quantization with percentile clipping.

    2-D tensors get a per-row fp16 scale; other shapes a single fp32 scale.
    Returns (int8 values, scale).
    """
    t32 = t.float()
    if t32.ndim == 2:
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale


def _classify_param(name: str) -> str:
    """Bucket a parameter name into embed / mlp / bigram / attn / other."""
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    if "bigram" in name:
        return "bigram"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"

def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
    """Symmetric low-bit quantization stored in int8.

    clip_range=31 gives int6 values in [-32, 31]; clip_range=15 gives int5.
    2-D tensors use per-row abs-max fp16 scales; others a single scale.
    """
    t32 = t.float()
    if t32.ndim == 2:
        row_max = t32.abs().amax(dim=1)
        scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16)
        # fp16 can flush 1e-12 to zero; clamp again at fp16's smallest normal.
        scale = scale.clamp_min(torch.finfo(torch.float16).tiny)
        q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8)
        return q, scale
    amax = t32.abs().max().item()
    scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8)
    return q, scale

def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
    """Quantize a state dict: categories in *int6_cats* get int5/int6
    per-row quantization, the rest int8; small or control tensors pass
    through. Returns (packed tensors, per-name metadata)."""
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = _classify_param(name)
        # Small tensors: store as fp16 (or unchanged if non-float).
        if not t.is_floating_point() or t.numel() <= 8192:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        # Control tensors stay in fp32 — they are tiny but sensitive.
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue
        if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS):
            result[name] = t.to(dtype=torch.float16).contiguous()
            meta[name] = "passthrough_fp16"
+ continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = 
world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return one (inputs, targets) micro-batch for this rank.

        Every rank reads the same global chunk from the stream and slices
        its own span, so streams stay in lockstep without communication.
        """
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1  # +1 so targets can be shifted by one
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)


# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""
    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    """Linear layer that casts its weight (and bias) to the input dtype,
    so fp32 master weights can be used under bf16 autocast."""
    def forward(self, x: Tensor) -> Tensor:
        w = self.weight.to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, w, bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    """Cast all vector/scalar params and control tensors back to fp32
    (after a module-wide .bfloat16() cast)."""
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    """Rotary position embeddings with a cached cos/sin table."""
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        # Rebuild the cache when seq_len or device changes; dtype is handled
        # by the cast at return time.
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Apply rotary embedding using the split-halves convention."""
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    """Causal self-attention with GQA, QK RMS-norm, RoPE, and a learned
    per-head query gain."""
    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # picked up by GPT._init_weights
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm: normalize queries/keys before RoPE; q_gain then sets the
        # effective attention temperature per head.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    """Feed-forward block with ReLU-squared activation."""
    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True  # picked up by GPT._init_weights

    def forward(self, x: Tensor) -> Tensor:
        x = torch.relu(self.fc(x))
        return self.proj(x.square())  # relu(x)^2 activation


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""
    def __init__(self, dim: int):
        super().__init__()
        # Zero init => sigmoid(0) = 0.5 blend at start of training.
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev


class BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table."""
    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)  # starts as a no-op contribution
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Map (prev, cur) token pairs to ids in [0, bigram_vocab_size);
        position 0 (no predecessor) maps to the reserved last slot."""
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod  # reserved id for the first position
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class Block(nn.Module):
    """Transformer block with learned residual mixing against the
    embedding stream (x0) and per-channel output scales."""
    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights the residual stream, resid_mix[1] the x0 stream.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    """U-Net-style GPT: encoder half saves activations, decoder half adds
    them back via learned skip weights; optional layer sharing maps
    num_layers virtual layers onto fewer physical blocks."""
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
                for _ in range(n_unique)
            ]
        )
        # Mapping from virtual layer index → physical block index
        # Last virtual layer shares with the last physical block
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero-init flagged projections, orthogonal-init other large
        matrices, and shrink .proj weights by 1/sqrt(2 * depth)."""
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                    if ".proj." in name or name.endswith(".proj"):
                        with torch.no_grad():
                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return mean cross-entropy loss over all target positions."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # embedding stream fed to every block via resid_mix
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                # U-Net skip: last encoder output pairs with first decoder layer.
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh soft-capping keeps logits in (-logit_softcap, logit_softcap).
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Same forward pass as forward() but returns soft-capped logits
        (used by the sliding-window evaluation)."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)


def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window evaluation: windows of seq_len tokens advance by
    *stride*; only the last *stride* tokens of each window are scored (the
    rest serve as context), so every token is scored exactly once with
    near-full context. Returns (mean loss in nats/token, bits-per-byte).
    """
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    # Keep windows that contribute at least `stride` scored tokens; the
    # first window (ws == 0) is always kept and scored in full.
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # Zero-padded batches; wlens records each window's true length.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score only the final `stride` tokens (full window when ws == 0).
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
            if rank == 0 and (bi // batch_seqs) % 50 == 0:
                done = min(bi + batch_seqs, len(my_windows))
                pct = done / len(my_windows) * 100
                running_bpb = 0.0
                if token_count.item() > 0:
                    rl = (loss_sum / token_count).item()
                    running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item())
                print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True)

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte


# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    global zeropower_via_newtonschulz5

    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank ==
0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) 
+ + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: 
float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == 
args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + 
(args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + 
if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + 
base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob 
= zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 24 11:39:20 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 47C P0 129W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 36C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 47C P0 134W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 48C P0 131W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 38C P0 128W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 46C P0 127W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 36C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + 
++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25517137 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:43 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9285 val_bpb:4.1035 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9296 train_time:149ms step_avg:148.74ms +step:2/20000 train_loss:8.4791 train_time:221ms step_avg:110.74ms +step:3/20000 train_loss:7.8364 train_time:318ms step_avg:105.92ms +step:4/20000 train_loss:7.2856 train_time:415ms step_avg:103.85ms +step:5/20000 train_loss:7.0369 train_time:514ms step_avg:102.84ms +step:6/20000 train_loss:6.8438 train_time:612ms step_avg:102.03ms +step:7/20000 train_loss:6.7871 train_time:711ms step_avg:101.54ms +step:8/20000 train_loss:6.8742 
train_time:810ms step_avg:101.19ms +step:9/20000 train_loss:6.5695 train_time:908ms step_avg:100.86ms +step:10/20000 train_loss:6.2076 train_time:1005ms step_avg:100.52ms +step:100/20000 train_loss:3.1589 train_time:10032ms step_avg:100.32ms +step:200/20000 train_loss:2.3736 train_time:20030ms step_avg:100.15ms +step:300/20000 train_loss:2.5403 train_time:30058ms step_avg:100.19ms +step:400/20000 train_loss:2.4110 train_time:40100ms step_avg:100.25ms +step:500/20000 train_loss:2.3962 train_time:50106ms step_avg:100.21ms +step:500/20000 val_loss:2.3542 val_bpb:1.3943 train_time:50133ms step_avg:100.27ms +step:600/20000 train_loss:2.3361 train_time:60176ms step_avg:100.29ms +step:700/20000 train_loss:2.3427 train_time:70239ms step_avg:100.34ms +step:800/20000 train_loss:2.2381 train_time:80296ms step_avg:100.37ms +step:900/20000 train_loss:2.1292 train_time:90346ms step_avg:100.38ms +step:1000/20000 train_loss:2.2797 train_time:100338ms step_avg:100.34ms +step:1000/20000 val_loss:2.2304 val_bpb:1.3210 train_time:100365ms step_avg:100.36ms +step:1100/20000 train_loss:2.3257 train_time:110393ms step_avg:100.36ms +step:1200/20000 train_loss:2.3582 train_time:120427ms step_avg:100.36ms +step:1300/20000 train_loss:2.1011 train_time:130471ms step_avg:100.36ms +step:1400/20000 train_loss:2.1876 train_time:140507ms step_avg:100.36ms +step:1500/20000 train_loss:2.2261 train_time:150479ms step_avg:100.32ms +step:1500/20000 val_loss:2.1915 val_bpb:1.2979 train_time:150505ms step_avg:100.34ms +step:1600/20000 train_loss:2.0294 train_time:160501ms step_avg:100.31ms +step:1700/20000 train_loss:2.1036 train_time:170525ms step_avg:100.31ms +step:1800/20000 train_loss:2.1260 train_time:180535ms step_avg:100.30ms +step:1900/20000 train_loss:2.0983 train_time:190498ms step_avg:100.26ms +step:2000/20000 train_loss:2.0471 train_time:200512ms step_avg:100.26ms +step:2000/20000 val_loss:2.1136 val_bpb:1.2518 train_time:200539ms step_avg:100.27ms +step:2100/20000 train_loss:2.0326 
train_time:210518ms step_avg:100.25ms +step:2200/20000 train_loss:2.1266 train_time:220516ms step_avg:100.23ms +step:2300/20000 train_loss:2.1105 train_time:230534ms step_avg:100.23ms +step:2400/20000 train_loss:2.0723 train_time:240465ms step_avg:100.19ms +step:2500/20000 train_loss:2.1760 train_time:250447ms step_avg:100.18ms +step:2500/20000 val_loss:2.1204 val_bpb:1.2558 train_time:250473ms step_avg:100.19ms +step:2600/20000 train_loss:2.1216 train_time:260431ms step_avg:100.17ms +step:2700/20000 train_loss:2.1178 train_time:270423ms step_avg:100.16ms +step:2800/20000 train_loss:2.1704 train_time:280417ms step_avg:100.15ms +step:2900/20000 train_loss:2.0422 train_time:290367ms step_avg:100.13ms +step:3000/20000 train_loss:2.1734 train_time:300362ms step_avg:100.12ms +step:3000/20000 val_loss:2.1066 val_bpb:1.2476 train_time:300388ms step_avg:100.13ms +step:3100/20000 train_loss:2.0504 train_time:310354ms step_avg:100.11ms +step:3200/20000 train_loss:2.1826 train_time:320337ms step_avg:100.11ms +step:3300/20000 train_loss:2.0797 train_time:330258ms step_avg:100.08ms +step:3400/20000 train_loss:2.0276 train_time:340246ms step_avg:100.07ms +step:3500/20000 train_loss:2.1849 train_time:350219ms step_avg:100.06ms +step:3500/20000 val_loss:2.0866 val_bpb:1.2358 train_time:350244ms step_avg:100.07ms +step:3600/20000 train_loss:2.0964 train_time:360200ms step_avg:100.06ms +step:3700/20000 train_loss:2.0972 train_time:370186ms step_avg:100.05ms +step:3800/20000 train_loss:2.0760 train_time:380100ms step_avg:100.03ms +step:3900/20000 train_loss:2.0754 train_time:390084ms step_avg:100.02ms +step:4000/20000 train_loss:1.9785 train_time:400040ms step_avg:100.01ms +step:4000/20000 val_loss:2.0639 val_bpb:1.2224 train_time:400067ms step_avg:100.02ms +step:4100/20000 train_loss:2.0115 train_time:410015ms step_avg:100.00ms +step:4200/20000 train_loss:2.1511 train_time:419979ms step_avg:100.00ms +step:4300/20000 train_loss:2.0566 train_time:429895ms step_avg:99.98ms 
+step:4400/20000 train_loss:2.0267 train_time:439861ms step_avg:99.97ms +step:4500/20000 train_loss:2.1164 train_time:449828ms step_avg:99.96ms +step:4500/20000 val_loss:2.0388 val_bpb:1.2075 train_time:449853ms step_avg:99.97ms +step:4600/20000 train_loss:1.8345 train_time:459799ms step_avg:99.96ms +step:4700/20000 train_loss:2.2178 train_time:469706ms step_avg:99.94ms +step:4800/20000 train_loss:2.4194 train_time:479665ms step_avg:99.93ms +step:4900/20000 train_loss:2.0402 train_time:489639ms step_avg:99.93ms +step:5000/20000 train_loss:2.0891 train_time:499610ms step_avg:99.92ms +step:5000/20000 val_loss:2.0117 val_bpb:1.1915 train_time:499637ms step_avg:99.93ms +step:5100/20000 train_loss:2.1118 train_time:509574ms step_avg:99.92ms +step:5200/20000 train_loss:2.0281 train_time:519474ms step_avg:99.90ms +step:5300/20000 train_loss:1.9924 train_time:529438ms step_avg:99.89ms +swa:start step:5350 +step:5400/20000 train_loss:2.0311 train_time:539475ms step_avg:99.90ms +step:5500/20000 train_loss:2.0004 train_time:549484ms step_avg:99.91ms +step:5500/20000 val_loss:1.9834 val_bpb:1.1747 train_time:549536ms step_avg:99.92ms +step:5600/20000 train_loss:1.9374 train_time:559498ms step_avg:99.91ms +step:5700/20000 train_loss:1.9898 train_time:569462ms step_avg:99.91ms +step:5800/20000 train_loss:1.9719 train_time:579486ms step_avg:99.91ms +step:5900/20000 train_loss:1.8830 train_time:589542ms step_avg:99.92ms +step:6000/20000 train_loss:1.9237 train_time:599555ms step_avg:99.93ms +step:6000/20000 val_loss:1.9583 val_bpb:1.1598 train_time:599622ms step_avg:99.94ms +step:6005/20000 val_loss:1.9583 val_bpb:1.1598 train_time:600115ms step_avg:99.94ms +stopping_early: wallclock_cap train_time:600115ms step:6005/20000 +peak memory allocated: 20841 MiB reserved: 21060 MiB +swa:applying averaged 14 checkpoints +Serialized model: 98437419 bytes +Code size: 58616 bytes +Total submission size: 98496035 bytes +awq:calibrating alpha=0.5 +awq:scaled 61 layers +Serialized model 
int6+zstd: 15458563 bytes +Total submission size int8+zlib: 15517179 bytes +awq:unscaled 61 layers after dequant +final_eval_mode:sliding_window stride:64 batch_seqs:64 +final_int8_zlib_roundtrip val_loss:1.9408 val_bpb:1.1494 eval_time:180159ms +final_int8_zlib_roundtrip_exact val_loss:1.94077623 val_bpb:1.14944004 diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed44.txt b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed44.txt new file mode 100644 index 0000000000..69e4a3c2e0 --- /dev/null +++ b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed44.txt @@ -0,0 +1,1521 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = 
int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + +# ----------------------------- 
+# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 
0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + 
val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and 
dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, 
scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" 
+ continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = 
world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or 
self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, 
q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + 
return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + unique_layers: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + 
class GPT(nn.Module):
    """U-net style GPT with optional tied embeddings, bigram hash embeddings,
    virtual-to-physical layer sharing, and logit soft-capping.

    ``num_layers`` is the virtual depth; when ``unique_layers`` > 0 and smaller,
    the trailing virtual layers all reuse the last physical block.  The encoder
    half's activations are stored and added back with learned per-channel skip
    weights in the decoder half.
    """

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them.
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
                for _ in range(n_unique)
            ]
        )
        # Mapping from virtual layer index -> physical block index;
        # extra virtual layers share the last physical block.
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        """Orthogonal init for large matrices, zeros for flagged projections,
        normal init for tied embeddings; projections scaled by 1/sqrt(2L)."""
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                if ".proj." in name or name.endswith(".proj"):
                    # Scale residual projections down with depth (no-op for the
                    # zero-initialized ones).
                    with torch.no_grad():
                        module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def _features(self, input_ids: Tensor) -> Tensor:
        """Shared trunk (extracted to deduplicate forward / forward_logits):
        embeddings -> smear -> U-net block stack with skips -> final norm."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        return self.final_norm(x)

    def _project(self, x: Tensor) -> Tensor:
        """Project features to soft-capped logits via the (possibly tied) head."""
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return the mean cross-entropy loss over all target positions."""
        feats = self._features(input_ids)
        flat = feats.reshape(-1, feats.size(-1))
        logits = self._project(flat)
        return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Return soft-capped logits of shape (batch, seq, vocab)."""
        return self._project(self._features(input_ids))
def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window validation. Returns (mean token NLL, bits-per-byte).

    Windows of length TRAIN_SEQ_LEN advance by `stride`; only the last `stride`
    tokens of each window are scored (the first window scores everything), so
    each token is scored once with maximal left context.
    """
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    # Keep a window only if it contributes at least `stride` fresh tokens;
    # the ws == 0 window is always kept.
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
    total_windows = len(window_starts)
    # Contiguous per-rank shard of the window list.
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # Zero-padded batch; wlens records each window's true length so
            # padding positions are never scored.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                # +1 extra token for the shifted targets.
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
                nll = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)).float(),
                    y_batch.reshape(-1),
                    reduction="none",
                ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score only the fresh suffix of this window (everything for ws == 0).
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                # Byte accounting mirrors eval_val: UTF-8 piece length plus one
                # byte for a leading space unless preceded by a boundary token.
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
            # Periodic progress log on rank 0 with a running bpb estimate.
            if rank == 0 and (bi // batch_seqs) % 50 == 0:
                done = min(bi + batch_seqs, len(my_windows))
                pct = done / len(my_windows) * 100
                running_bpb = 0.0
                if token_count.item() > 0:
                    rl = (loss_sum / token_count).item()
                    running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item())
                print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True)

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    # bits/byte = (nats/token / ln 2) * (tokens / bytes).
    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte
0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) 
+ + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: 
float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == 
args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + 
(args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + 
if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + 
base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob 
= zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 24 11:54:04 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 47C P0 128W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 47C P0 134W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 48C P0 131W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 38C P0 128W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 47C P0 128W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 36C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + 
++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25517137 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:44 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9280 val_bpb:4.1031 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9287 train_time:149ms step_avg:149.48ms +step:2/20000 train_loss:8.4639 train_time:222ms step_avg:111.22ms +step:3/20000 train_loss:7.8194 train_time:320ms step_avg:106.51ms +step:4/20000 train_loss:7.2740 train_time:418ms step_avg:104.46ms +step:5/20000 train_loss:6.9708 train_time:515ms step_avg:102.91ms +step:6/20000 train_loss:6.8442 train_time:613ms step_avg:102.11ms +step:7/20000 train_loss:6.7332 train_time:710ms step_avg:101.45ms +step:8/20000 train_loss:6.6899 
train_time:808ms step_avg:100.97ms +step:9/20000 train_loss:6.4768 train_time:904ms step_avg:100.45ms +step:10/20000 train_loss:6.1476 train_time:1001ms step_avg:100.06ms +step:100/20000 train_loss:3.1518 train_time:9962ms step_avg:99.62ms +step:200/20000 train_loss:2.3792 train_time:19971ms step_avg:99.86ms +step:300/20000 train_loss:2.5375 train_time:30002ms step_avg:100.01ms +step:400/20000 train_loss:2.4066 train_time:40063ms step_avg:100.16ms +step:500/20000 train_loss:2.3937 train_time:50067ms step_avg:100.13ms +step:500/20000 val_loss:2.3569 val_bpb:1.3959 train_time:50091ms step_avg:100.18ms +step:600/20000 train_loss:2.3337 train_time:60134ms step_avg:100.22ms +step:700/20000 train_loss:2.3503 train_time:70193ms step_avg:100.28ms +step:800/20000 train_loss:2.2417 train_time:80256ms step_avg:100.32ms +step:900/20000 train_loss:2.1287 train_time:90313ms step_avg:100.35ms +step:1000/20000 train_loss:2.2758 train_time:100310ms step_avg:100.31ms +step:1000/20000 val_loss:2.2311 val_bpb:1.3214 train_time:100336ms step_avg:100.34ms +step:1100/20000 train_loss:2.3322 train_time:110366ms step_avg:100.33ms +step:1200/20000 train_loss:2.3606 train_time:120403ms step_avg:100.34ms +step:1300/20000 train_loss:2.1112 train_time:130456ms step_avg:100.35ms +step:1400/20000 train_loss:2.1882 train_time:140478ms step_avg:100.34ms +step:1500/20000 train_loss:2.2263 train_time:150441ms step_avg:100.29ms +step:1500/20000 val_loss:2.1920 val_bpb:1.2982 train_time:150469ms step_avg:100.31ms +step:1600/20000 train_loss:2.0377 train_time:160472ms step_avg:100.30ms +step:1700/20000 train_loss:2.1060 train_time:170501ms step_avg:100.29ms +step:1800/20000 train_loss:2.1298 train_time:180522ms step_avg:100.29ms +step:1900/20000 train_loss:2.1008 train_time:190493ms step_avg:100.26ms +step:2000/20000 train_loss:2.0515 train_time:200504ms step_avg:100.25ms +step:2000/20000 val_loss:2.1173 val_bpb:1.2540 train_time:200531ms step_avg:100.27ms +step:2100/20000 train_loss:2.0373 
train_time:210546ms step_avg:100.26ms +step:2200/20000 train_loss:2.1344 train_time:220560ms step_avg:100.25ms +step:2300/20000 train_loss:2.1168 train_time:230569ms step_avg:100.25ms +step:2400/20000 train_loss:2.0723 train_time:240545ms step_avg:100.23ms +step:2500/20000 train_loss:2.1878 train_time:250547ms step_avg:100.22ms +step:2500/20000 val_loss:2.1235 val_bpb:1.2577 train_time:250575ms step_avg:100.23ms +step:2600/20000 train_loss:2.1224 train_time:260544ms step_avg:100.21ms +step:2700/20000 train_loss:2.1155 train_time:270517ms step_avg:100.19ms +step:2800/20000 train_loss:2.1707 train_time:280521ms step_avg:100.19ms +step:2900/20000 train_loss:2.0416 train_time:290459ms step_avg:100.16ms +step:3000/20000 train_loss:2.1755 train_time:300453ms step_avg:100.15ms +step:3000/20000 val_loss:2.1071 val_bpb:1.2479 train_time:300479ms step_avg:100.16ms +step:3100/20000 train_loss:2.0540 train_time:310500ms step_avg:100.16ms +step:3200/20000 train_loss:2.1885 train_time:320480ms step_avg:100.15ms +step:3300/20000 train_loss:2.0829 train_time:330417ms step_avg:100.13ms +step:3400/20000 train_loss:2.0306 train_time:340422ms step_avg:100.12ms +step:3500/20000 train_loss:2.1900 train_time:350415ms step_avg:100.12ms +step:3500/20000 val_loss:2.0899 val_bpb:1.2377 train_time:350441ms step_avg:100.13ms +step:3600/20000 train_loss:2.0978 train_time:360415ms step_avg:100.12ms +step:3700/20000 train_loss:2.1007 train_time:370408ms step_avg:100.11ms +step:3800/20000 train_loss:2.0762 train_time:380332ms step_avg:100.09ms +step:3900/20000 train_loss:2.0818 train_time:390308ms step_avg:100.08ms +step:4000/20000 train_loss:1.9747 train_time:400290ms step_avg:100.07ms +step:4000/20000 val_loss:2.0678 val_bpb:1.2247 train_time:400315ms step_avg:100.08ms +step:4100/20000 train_loss:2.0182 train_time:410281ms step_avg:100.07ms +step:4200/20000 train_loss:2.1542 train_time:420258ms step_avg:100.06ms +step:4300/20000 train_loss:2.0556 train_time:430192ms step_avg:100.04ms 
+step:4400/20000 train_loss:2.0332 train_time:440177ms step_avg:100.04ms +step:4500/20000 train_loss:2.1193 train_time:450138ms step_avg:100.03ms +step:4500/20000 val_loss:2.0431 val_bpb:1.2101 train_time:450164ms step_avg:100.04ms +step:4600/20000 train_loss:1.8407 train_time:460123ms step_avg:100.03ms +step:4700/20000 train_loss:2.2264 train_time:470035ms step_avg:100.01ms +step:4800/20000 train_loss:2.4262 train_time:480013ms step_avg:100.00ms +step:4900/20000 train_loss:2.0455 train_time:489990ms step_avg:100.00ms +step:5000/20000 train_loss:2.0949 train_time:499981ms step_avg:100.00ms +step:5000/20000 val_loss:2.0153 val_bpb:1.1936 train_time:500007ms step_avg:100.00ms +step:5100/20000 train_loss:2.1149 train_time:509962ms step_avg:99.99ms +step:5200/20000 train_loss:2.0314 train_time:519881ms step_avg:99.98ms +step:5300/20000 train_loss:1.9922 train_time:529849ms step_avg:99.97ms +swa:start step:5350 +step:5400/20000 train_loss:2.0348 train_time:539907ms step_avg:99.98ms +step:5500/20000 train_loss:2.0003 train_time:549943ms step_avg:99.99ms +step:5500/20000 val_loss:1.9868 val_bpb:1.1767 train_time:550013ms step_avg:100.00ms +step:5600/20000 train_loss:1.9387 train_time:559996ms step_avg:100.00ms +step:5700/20000 train_loss:1.9929 train_time:569977ms step_avg:100.00ms +step:5800/20000 train_loss:1.9749 train_time:580010ms step_avg:100.00ms +step:5900/20000 train_loss:1.8845 train_time:590024ms step_avg:100.00ms +step:6000/20000 train_loss:1.9269 train_time:600029ms step_avg:100.00ms +step:6000/20000 val_loss:1.9619 val_bpb:1.1619 train_time:600099ms step_avg:100.02ms +stopping_early: wallclock_cap train_time:600099ms step:6000/20000 +peak memory allocated: 20841 MiB reserved: 21060 MiB +swa:applying averaged 14 checkpoints +Serialized model: 98437419 bytes +Code size: 58616 bytes +Total submission size: 98496035 bytes +awq:calibrating alpha=0.5 +awq:scaled 61 layers +Serialized model int6+zstd: 15367341 bytes +Total submission size int8+zlib: 15425957 bytes 
+awq:unscaled 61 layers after dequant +final_eval_mode:sliding_window stride:64 batch_seqs:64 +final_int8_zlib_roundtrip val_loss:1.9460 val_bpb:1.1526 eval_time:179947ms +final_int8_zlib_roundtrip_exact val_loss:1.94603275 val_bpb:1.15255326 diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/run.sh b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/run.sh new file mode 100755 index 0000000000..8f44293924 --- /dev/null +++ b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/run.sh @@ -0,0 +1,79 @@ +#!/bin/bash +# ============================================================ +# SUBMISSION: exp18 AWQ + cyclic momentum + relu_sq + 11L shared +# 8×H100 SXM, 3 seeds, full sliding window eval +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="submission_exp18_awq-cyclic-relusq-11Lshared_8xH100" +LOG_DIR="records/h100_experiments/${EXP_NAME}/logs" + +cd /workspace/parameter-golf +mkdir -p "${LOG_DIR}" + +# --- Architecture --- +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3.0 +export MLP_ACTIVATION=relu_sq +export VOCAB_SIZE=1024 +export TIE_EMBEDDINGS=1 +export LOGIT_SOFTCAP=30.0 +export BIGRAM_VOCAB_SIZE=10240 +export BIGRAM_DIM=128 + +# --- Training (8×H100 scale) --- +export ITERATIONS=20000 +export WARMUP_STEPS=20 +export WARMDOWN_ITERS=3500 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 +export MAX_WALLCLOCK_SECONDS=600 + +# --- Optimizer --- +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_MOMENTUM_WARMUP_START=0.92 +export MUON_MOMENTUM_WARMUP_STEPS=1500 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export GRAD_CLIP_NORM=0.3 +export WEIGHT_DECAY=0.04 + +# --- SWA --- +export SWA_ENABLED=1 +export 
SWA_START_FRAC=0.2 +export SWA_EVERY=50 + +# --- Validation & Eval --- +export VAL_LOSS_EVERY=500 +export VAL_BATCH_SIZE=524288 +export TRAIN_LOG_EVERY=100 +export EVAL_STRIDE=64 +export EVAL_BATCH_SEQS=64 + +# --- AWQ --- +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 + +# --- Run 3 seeds --- +for SEED in 42 43 44; do + export SEED + export RUN_ID="${EXP_NAME}_seed${SEED}" + echo "============================================" + echo "=== SEED ${SEED} ===" + echo "============================================" + torchrun --nproc_per_node=8 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${LOG_DIR}/${EXP_NAME}_seed${SEED}.log" + echo "=== SEED ${SEED} COMPLETE ===" + echo "" +done + +echo "=== ALL 3 SEEDS COMPLETE ===" diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/setup_and_run.sh b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/setup_and_run.sh new file mode 100755 index 0000000000..3a7fa0e4e1 --- /dev/null +++ b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/setup_and_run.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# ============================================================ +# FULL SETUP + RUN for 8×H100 submission +# Usage: Pass SSH host and port as arguments +# ./setup_and_run.sh +# ============================================================ +set -euo pipefail + +HOST="${1:?Usage: $0 }" +PORT="${2:?Usage: $0 }" +SSH="ssh -o StrictHostKeyChecking=no -p ${PORT} root@${HOST}" +SCP="scp -P ${PORT}" +LOCAL_BASE="/Users/sidhantthole/Documents/llama_index_exp/openaigolf/parameter-golf" +REMOTE_BASE="/workspace/parameter-golf" +EXP="submission_exp18_awq-cyclic-relusq-11Lshared_8xH100" + +echo "=== Step 1: Test connection ===" +$SSH "echo 'Connected!' 
&& nvidia-smi --query-gpu=name,memory.total --format=csv,noheader" + +echo "=== Step 2: Clone repo + install deps ===" +$SSH "cd /workspace && rm -rf parameter-golf && git clone https://github.com/openai/parameter-golf.git && pip install --break-system-packages -q zstandard" + +echo "=== Step 3: Start data download in background ===" +$SSH "cd ${REMOTE_BASE} && nohup python3 data/cached_challenge_fineweb.py --variant sp1024 > /tmp/data_download.out 2>&1 &" + +echo "=== Step 4: Copy experiment files while data downloads ===" +$SSH "mkdir -p ${REMOTE_BASE}/records/h100_experiments" +$SCP -r "${LOCAL_BASE}/records/h100_experiments/${EXP}" "root@${HOST}:${REMOTE_BASE}/records/h100_experiments/" +$SSH "chmod +x ${REMOTE_BASE}/records/h100_experiments/${EXP}/run.sh" + +echo "=== Step 5: Wait for data download ===" +while true; do + COUNT=$($SSH "ls ${REMOTE_BASE}/data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l" 2>/dev/null || echo "0") + echo " Data shards: ${COUNT}/80" + if [ "$COUNT" -ge 80 ]; then + break + fi + sleep 10 +done +echo "Data download complete!" + +echo "=== Step 6: Launch submission (3 seeds) ===" +$SSH "nohup bash ${REMOTE_BASE}/records/h100_experiments/${EXP}/run.sh > /tmp/submission.out 2>&1 &" +echo "Submission launched!" 
+ +echo "=== Step 7: Syncing records every 10 seconds ===" +while true; do + rsync -az -e "ssh -p ${PORT}" "root@${HOST}:${REMOTE_BASE}/records/h100_experiments/${EXP}/" "${LOCAL_BASE}/records/h100_experiments/${EXP}/" 2>/dev/null + rsync -az -e "ssh -p ${PORT}" "root@${HOST}:${REMOTE_BASE}/logs/" "${LOCAL_BASE}/logs/" 2>/dev/null + + # Check if still running + RUNNING=$($SSH "ps aux | grep train_gpt | grep -v grep | wc -l" 2>/dev/null || echo "0") + if [ "$RUNNING" -eq 0 ]; then + # Final sync + rsync -az -e "ssh -p ${PORT}" "root@${HOST}:${REMOTE_BASE}/records/h100_experiments/${EXP}/" "${LOCAL_BASE}/records/h100_experiments/${EXP}/" 2>/dev/null + rsync -az -e "ssh -p ${PORT}" "root@${HOST}:${REMOTE_BASE}/logs/" "${LOCAL_BASE}/logs/" 2>/dev/null + echo "=== ALL SEEDS COMPLETE ===" + break + fi + sleep 10 +done diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/submission.json b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/submission.json new file mode 100644 index 0000000000..08b8e425cc --- /dev/null +++ b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/submission.json @@ -0,0 +1,11 @@ +{ + "author": "SPThole", + "github_id": "SPThole", + "name": "AWQ + Cyclic Momentum + ReLU² + 11L Shared", + "blurb": "Activation-Aware Weight Quantization (AWQ) closes the int5/int6 quant gap from 0.027 to 0.010 bpb. Cyclic Muon momentum (0.85-0.95 triangle wave) escapes sharp minima. ReLU² for sparser MLPs. 11 layers with 10 unique weights. 
Emerged from 21+ systematic experiments on 1×H100/A40 before scaling to 8×H100.", + "date": "2026-03-25T00:00:00Z", + "val_loss": 1.94293025, + "val_bpb": 1.15071578, + "bytes_total": 15465308, + "bytes_code": 58616 +} diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/train_gpt.py b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/train_gpt.py new file mode 100644 index 0000000000..5645ab6700 --- /dev/null +++ b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/train_gpt.py @@ -0,0 +1,1340 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = 
int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = 
bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 
0

        # One flat bf16 buffer per param group: each rank computes the Muon
        # update only for the params it owns (round-robin by index), writes
        # them into its slots, and an all_reduce(SUM) over the zero-elsewhere
        # buffers shares the results with every rank.
        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                # Ownership test: param i belongs to rank (i % world_size).
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        # Nesterov lookahead: blend grad with the updated buffer.
                        g = g.add(buf, alpha=momentum)
                    # Orthogonalize the update via Newton-Schulz iteration.
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Scale tall matrices up so update RMS is shape-independent.
                    # NOTE(review): assumes p is 2-D (g.size(0)/g.size(1)) — holds
                    # for the matrix-only param selection done in main().
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # Offset advances for EVERY param so slots line up across ranks.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            # Apply gathered updates (decoupled weight decay before the step).
            wd = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                if wd > 0:
                    p.data.mul_(1.0 - lr * wd)
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION
# -----------------------------

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables used for bits-per-byte accounting.

    Returns (base_bytes, has_leading_space, is_boundary_token), each indexed
    by token id:
      - base_bytes: UTF-8 byte length of the piece (without the "▁" marker);
        byte-fallback tokens count as exactly 1 byte.
      - has_leading_space: piece started with U+2581 ("▁"), i.e. a word start.
      - is_boundary_token: control/unknown/unused ids (suppress the leading
        space byte after them during scoring).
    """
    sp_vocab_size = int(sp.vocab_size())
    # Table sized to the larger of model vocab and tokenizer vocab so lookups
    # never index out of range.
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("\u2581"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all shards matching `pattern`, trimmed to full sequences.

    Keeps usable+1 tokens so every position has a next-token target.
    Raises FileNotFoundError if no shards match, ValueError if too short.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Non-overlapping validation pass; returns (mean NLL, bits-per-byte).

    Sequences are partitioned contiguously across ranks; per-rank sums are
    all_reduced. bpb = (loss / ln 2) * tokens / bytes, with bytes computed
    from the SentencePiece LUTs so the metric is tokenizer-agnostic.
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Contiguous slice of sequence indices for this rank.
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # float64 accumulators: exact token/byte counting, stable loss sums.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            # +1 so the last target token is available.
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            # model returns mean NLL -> multiply back to a sum.
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            # A word-initial piece contributes its leading space byte only
            # when the previous token is not a boundary (control/unk) token.
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


# -----------------------------
# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed)
# -----------------------------

# Small "control" tensors (per-channel scales/gates) are kept in float during
# quantization; both lists are overridable via environment variables.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale",
    ).split(",")
    if pattern
)
FP16_KEEP_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",")
    if pattern
)
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for 
name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# 
----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + 
def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def 
forward(self, x: Tensor) -> Tensor:
        # ReLU-squared activation: relu(fc(x)) ** 2.
        x = torch.relu(self.fc(x))
        return self.proj(x.square())


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""
    def __init__(self, dim: int):
        super().__init__()
        # Zero init -> sigmoid(0)=0.5 per channel at start of training.
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        # Shift right by one position; first position blends with zeros.
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev


class BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table."""
    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        # Zero init so the bigram path starts as a no-op contribution.
        nn.init.zeros_(self.embed.weight)
        # Optional projection only when bigram_dim differs from model_dim.
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        # Multiplicative-XOR hash of (prev, curr) pairs into [0, mod).
        # int32 overflow wraps, which is fine for hashing; % follows Python
        # sign semantics so results are non-negative.
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod  # reserved id for "no previous token"
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class Block(nn.Module):
    """Pre-norm transformer block with learned per-channel residual scales
    and a learned mix of the running residual with the embedding stream x0."""
    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        # Per-channel gains on each residual branch (fp32 control tensors).
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights x, resid_mix[1] weights x0; init = identity.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    """U-Net-style GPT: encoder half pushes activations onto a skip stack,
    decoder half pops them back in with learned per-channel skip weights.
    Supports tied embeddings, logit soft-capping, a smear gate, an optional
    bigram hash embedding, and virtual->physical layer sharing."""
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
                for _ in range(n_unique)
            ]
        )
        # Mapping from virtual layer index → physical block index
        # Last virtual layer shares with the last physical block
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        # Tied mode reuses tok_emb.weight as the output projection.
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    # Orthogonal init for the large projection matrices;
                    # residual-output projections are downscaled by depth.
                    nn.init.orthogonal_(module.weight, gain=1.0)
                    if ".proj." in name or name.endswith(".proj"):
                        with torch.no_grad():
                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return mean cross-entropy of next-token prediction (training path)."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # embedding stream, re-injected by every block via resid_mix
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                # LIFO pop pairs encoder layer j with decoder layer (n-1-j).
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # Soft-cap logits to +/- logit_softcap via tanh.
        logits = self.logit_softcap * torch.tanh(logits_proj /
self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = 
torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + 
+ code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + 
np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = 
list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in 
base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, 
grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = 
torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + 
swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # 
AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if 
module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = 
inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned From db9ea395e16ccc078847e64204c640a631d0925c Mon Sep 17 00:00:00 2001 From: SPThole Date: Tue, 24 Mar 2026 19:47:27 +0530 Subject: [PATCH 02/10] updt readme --- .../2025-03-24_AWQ_CyclMom_11L_shared/.DS_Store | Bin 8196 -> 0 bytes .../logs/.DS_Store | Bin 6148 -> 0 bytes 2 files changed, 0 
insertions(+), 0 deletions(-) delete mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/.DS_Store delete mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/.DS_Store diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/.DS_Store b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/.DS_Store deleted file mode 100644 index eac9a8f0ea0e0cca587d8e90d848a98e92deabca..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHMO-~a+7=EW#+zQAdAEMD@V`E|hB8nj)##jqR(HICN3Vzky?obw&nPzuOMMBcV znGn~3jCJ}i06ZYMPN>2U7>t*U?q4Qw5J#p1$tg6oQNSp$Qvq>yPeTYC$`Efm zzvr=U_@^{UZ032jxT4Sb@4QMF#L39WXOh~N-rsh>I$*V14+=}XR!H2W8t2{Uo@mW+ z9(oPm^==2WRkpJK}WI|1`~EzhgWk0rODI;@P9DcS?;>*vPK4>_ab zV;e)x`uO>gA?L#A*v3Z2>dT(JI9q%it}y;wEtTlU1$3%*vwneIb3H+6Rj-G9YOLq0 zZ8uhwIoRIOd8o_oKHSsY+uhgK*MDT-=&_^6i*}FamL8WQaitSHD>&-v93jwvU}wy@TV6SW17sj_zOk5D=tMm3}<*8 zBmrj&_hS@!DxoV=Y?0%Pt0?evMRditX%r*w;LJY=U5|#-EYFjKuPnGxN$FBmK-Aqjf&%i8P#aSoe7UD3hfWPhI{~ohOT0Z`55J43f=EktJHJ6;i z+%qk?Hs(-BFuwvtaFOLB?#fT~Aiwpi)`k2kcz9I-dNI~q&b*91dRu2tJ=@0az?mbQ zH9$@gzt|OQ4`aQym4|u*yl%N_Y_^(^$-yEpr?IY3ghX8K!iZ-77lIfw3hW^T_Ni+f z5a<7^yZ`>bhgM`R)hJ*T_)`T)da^JnUiRg#{u{Q~H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0 Date: Tue, 24 Mar 2026 20:24:36 +0530 Subject: [PATCH 03/10] Update README.md --- .../2025-03-24_AWQ_CyclMom_11L_shared/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/README.md b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/README.md index e4b07a2e5f..b0f2cff824 100644 --- a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/README.md +++ b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/README.md @@ -13,8 +13,6 @@ Starting from the community SOTA baseline (thwu1), we introduce four techniques: 
|-----------|-------------|--------| | **AWQ (Activation-Aware Weight Quantization)** | Scale weight columns by activation importance (alpha=0.5) before quantization. Folds compensation into preceding LayerNorm. Reduces quantization error on high-activation channels. | Quant gap 0.027 → 0.010 bpb | | **Cyclic Muon Momentum** | Triangle wave between 0.85–0.95 with period=50 steps, replacing fixed 0.99 after warmup. Prevents optimizer from settling into sharp minima. | −0.0045 bpb on 1×H100 | -| **ReLU² Activation** | Squared ReLU in MLP layers instead of GELU. Sparser activations, better gradient flow for small models. | −0.005 bpb | -| **11L Shared (10 unique)** | 11 virtual layers using 10 physical weight sets. Last layer reuses block 9. Free depth at zero parameter cost. | −0.003 bpb | ## Results (8×H100) From d7ec40c18f1ec8f95d1411a5e2d24689172b8da8 Mon Sep 17 00:00:00 2001 From: SPThole Date: Wed, 25 Mar 2026 09:55:40 +0530 Subject: [PATCH 04/10] added qk init non record one H100 --- .../README.md | 70 + .../logs.txt | 1572 +++++++++++++++++ .../run.sh | 41 + .../submission.json | 17 + .../train_gpt.py | 1435 +++++++++++++++ 5 files changed, 3135 insertions(+) create mode 100644 records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/README.md create mode 100644 records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/logs.txt create mode 100755 records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/run.sh create mode 100644 records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/submission.json create mode 100644 records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/train_gpt.py diff --git a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/README.md b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/README.md new file mode 100644 index 0000000000..5cd5744bb7 --- /dev/null +++ 
b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/README.md @@ -0,0 +1,70 @@ +# Co-occurrence QK Initialization + +## Score: val_bpb = 1.3525 (1×H100, single seed) + +Trained on 1×H100 80GB in 600 seconds. 15.55MB artifact (int6+zstd). Run on 1×H100 due to compute constraints. Built upon [PR #623](https://github.com/openai/parameter-golf/pull/623). + +## Approach + +Initializes W_Q and W_K in layer 0 from bigram co-occurrence statistics so that the initial attention pattern reflects real token relationships rather than random noise. + +### Mathematical formulation + +**Goal**: At step 0, hidden states h ≈ E[token_id] (just the embedding). We want the attention logit between tokens i and j to approximate their co-occurrence: + +``` +q_i · k_j = (h_i W_Q^T) · (W_K h_j) ≈ C[tok_i, tok_j] +``` + +where C is the 1024×1024 bigram co-occurrence matrix. + +**Step 1 — Build co-occurrence matrix**: Scan 2M training tokens. For each consecutive pair (t_i, t_{i+1}), increment C[t_i, t_{i+1}]. Apply log-transform: C ← log(C + 1) then center (subtract row/column means). This gives a PMI-like matrix. + +**Step 2 — Project into model dimension**: Since W_Q and W_K operate on model_dim (512), not vocab_size (1024), we project C into model space via a fixed random matrix P ∈ R^{1024×512}: + +``` +C_proj = P^T C P ∈ R^{512×512} +``` + +**Step 3 — SVD factorization**: Decompose C_proj = U S V^T. Take top d_head (64) components: + +``` +W_Q ← (U[:, :q_dim] · diag(√S[:q_dim]))^T ∈ R^{q_dim × 512} +W_K ← (V[:k_dim, :] · diag(√S[:k_dim]))^T ∈ R^{k_dim × 512} +``` + +This ensures W_Q^T W_K ≈ C_proj (scaled), so Q·K^T at step 0 reflects co-occurrence. 
+ +**Step 4 — Scale normalization**: Rescale W_Q and W_K to match the norm of the default orthogonal initialization, preventing gradient scale mismatch: + +``` +W_Q ← W_Q · (‖W_Q_orig‖ / ‖W_Q‖) +W_K ← W_K · (‖W_K_orig‖ / ‖W_K‖) +``` + +**Head diversity**: With 8 heads (head_dim=64), SVD components 1–64 go to head 0, 65–128 to head 1, etc. Each head captures a different slice of co-occurrence structure. + +Zero extra parameters — only changes initialization. Co-occurrence computation takes <3s. + +## Hyperparameters + +| Parameter | Value | +|-----------|-------| +| num_layers | 11 (10 unique) | +| model_dim | 512 | +| mlp_activation | ReLU² | +| cooc_init_tokens | 2,000,000 | +| cooc_init_layer | 0 only | +| train_batch_tokens | 524,288 | +| matrix_lr / scalar_lr | 0.025 | +| swa_every | 50, start_frac=0.2 | + +## Key Metrics + +- **val_bpb: 1.3525** (post int6+zstd roundtrip) +- Pre-quant val_bpb: 1.3245 +- Quantization penalty: 0.0280 bpb +- Training: 1,099 steps in 600s (546 ms/step) +- Artifact size: 15,545,987 bytes (15.55MB) +- SWA: averaged 12 checkpoints +- Peak memory: 14,656 MiB diff --git a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/logs.txt b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/logs.txt new file mode 100644 index 0000000000..bc8128fdd1 --- /dev/null +++ b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/logs.txt @@ -0,0 +1,1572 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = 
int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + +# ----------------------------- 
+# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 
0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + 
val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and 
dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, 
scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" 
+ continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = 
world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or 
class Rotary(nn.Module):
    """Rotary position embedding with a simple (seq_len, device) cache."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        # Standard RoPE inverse frequencies over the even channel indices.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        cache_stale = (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        )
        if cache_stale:
            positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            angles = torch.outer(positions, self.inv_freq.to(device))
            # Shape [1, 1, seq, dim/2] so the tables broadcast over (batch, heads).
            self._cos_cached = angles.cos()[None, None, :, :]
            self._sin_cached = angles.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate the two halves of x's last dimension by the given angle tables."""
    half = x.size(-1) // 2
    lo, hi = x[..., :half], x[..., half:]
    rotated_lo = lo * cos + hi * sin
    rotated_hi = hi * cos - lo * sin
    return torch.cat((rotated_lo, rotated_hi), dim=-1)
class CausalSelfAttention(nn.Module):
    """Causal attention with grouped KV heads, QK RMS-norm, RoPE, and a
    learned per-head gain applied to the queries."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # residual branch starts at zero
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape

        def split_heads(t: Tensor, n_heads: int) -> Tensor:
            # [b, s, h*d] -> [b, h, s, d]
            return t.reshape(bsz, seqlen, n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.c_q(x), self.num_heads)
        k = split_heads(self.c_k(x), self.num_kv_heads)
        v = split_heads(self.c_v(x), self.num_kv_heads)
        # QK-norm: normalize queries and keys before rotation.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        out = out.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(out)


class MLP(nn.Module):
    """Feed-forward block with the ReLU-squared activation."""

    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True  # residual branch starts at zero

    def forward(self, x: Tensor) -> Tensor:
        h = self.fc(x)
        return self.proj(torch.relu(h).square())


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        mix = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        # Shift the sequence right by one position, zero-padding the front.
        shifted = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - mix) * x + mix * shifted
class BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table."""

    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)  # contributes nothing until trained
        # Optional projection when the bigram width differs from the model width.
        if bigram_dim != model_dim:
            self.proj = CastedLinear(bigram_dim, model_dim, bias=False)
            nn.init.zeros_(self.proj.weight)
        else:
            self.proj = None
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Map each (prev, cur) token pair to a table index in [0, mod);
        position 0 (no predecessor) gets the reserved last slot ``mod``."""
        t32 = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        hashed = torch.empty_like(t32)
        hashed[..., 0] = mod
        hashed[..., 1:] = torch.bitwise_xor(36313 * t32[..., 1:], 27191 * t32[..., :-1]) % mod
        return hashed.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class Block(nn.Module):
    """Pre-norm transformer block with learned residual scales and a learned
    per-channel mix between the running stream and the embedding stream x0."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights x, resid_mix[1] weights x0; starts as identity mix.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        w = self.resid_mix.to(dtype=x.dtype)
        x = w[0][None, None, :] * x + w[1][None, None, :] * x0
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x))
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x
class GPT(nn.Module):
    """Encoder/decoder-style GPT with U-Net skip connections, optional layer
    sharing, bigram hash embeddings, and soft-capped logits."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: build n_unique physical blocks; virtual layers beyond
        # that count all reuse the last physical block.
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
                for _ in range(n_unique)
            ]
        )
        # Virtual layer index -> physical block index.
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        virtual_depth = len(self._layer_map)  # virtual depth, not physical blocks
        proj_shrink = 1.0 / math.sqrt(2 * virtual_depth)
        for name, module in self.named_modules():
            if not isinstance(module, nn.Linear):
                continue
            if getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)
            elif module.weight.ndim == 2 and min(module.weight.shape) >= 64:
                nn.init.orthogonal_(module.weight, gain=1.0)
                # Shrink residual projections by 1/sqrt(2*depth) (GPT-2-style).
                if ".proj." in name or name.endswith(".proj"):
                    with torch.no_grad():
                        module.weight.mul_(proj_shrink)

    def _encode(self, input_ids: Tensor) -> Tensor:
        """Shared backbone: embeddings -> smear -> U-Net block stack -> final norm."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # embedding stream, fed to every block via resid_mix
        saved: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            saved.append(x)
        for i in range(self.num_decoder_layers):
            if saved:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * saved.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        return self.final_norm(x)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return the mean cross-entropy of soft-capped logits against target_ids."""
        h = self._encode(input_ids)
        h = h.reshape(-1, h.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            raw = F.linear(h, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            raw = self.lm_head(h)
        logits = self.logit_softcap * torch.tanh(raw / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Return soft-capped logits of shape [batch, seq, vocab] (no loss)."""
        h = self._encode(input_ids)
        if self.tie_embeddings:
            raw = F.linear(h, self.tok_emb.weight)
        else:
            raw = self.lm_head(h)
        return self.logit_softcap * torch.tanh(raw / self.logit_softcap)
def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window validation over ``val_tokens``.

    Windows of ``args.train_seq_len`` tokens start every ``stride`` tokens;
    only the last ``stride`` targets of each window are scored (the first
    window scores everything), so every target token is scored exactly once
    with maximal left context. Work is split across ranks and reduced at the
    end. Returns ``(mean NLL per token, bits per byte)``. Byte counts come
    from the tokenizer LUTs: base bytes per target token, plus one byte when
    the target carries a leading space and the previous token is not a
    boundary token — presumably mirroring the detokenizer's spacing rule
    (NOTE(review): confirm against build_sentencepiece_luts).
    """
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    # Keep only windows that contribute at least `stride` scored tokens
    # (the first window is always kept).
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
    total_windows = len(window_starts)
    # Contiguous, near-equal split of windows across ranks.
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    # float64 accumulators to avoid drift over many windows.
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # Short final windows are zero-padded; only [:wlen] is scored below.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                # wlen+1 tokens give wlen (input, target) pairs shifted by one.
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
                # Per-token NLL in float32, reshaped back to [bsz, seq_len].
                nll = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)).float(),
                    y_batch.reshape(-1),
                    reduction="none",
                ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score only the trailing `stride` tokens of each window so each
                # token is counted once; the very first window scores everything.
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
            # Rank-0 progress line every 50 batches.
            if rank == 0 and (bi // batch_seqs) % 50 == 0:
                done = min(bi + batch_seqs, len(my_windows))
                pct = done / len(my_windows) * 100
                running_bpb = 0.0
                if token_count.item() > 0:
                    rl = (loss_sum / token_count).item()
                    running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item())
                print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True)

    # Combine partial sums from all ranks before computing the final ratios.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    # bits/byte = bits/token * tokens/byte
    return val_loss, bits_per_token * tokens_per_byte
def build_cooccurrence_matrix(data_pattern: str, vocab_size: int, max_tokens: int = 2_000_000) -> Tensor:
    """Build a vocab x vocab bigram co-occurrence matrix from training data.

    The bigram counts are accumulated with a single vectorized ``np.add.at``
    pass over the token array instead of a Python loop over every bigram,
    which is dramatically faster for millions of tokens and produces
    identical counts (``np.add.at`` handles repeated index pairs).
    """
    stream = TokenStream(data_pattern)
    tokens = stream.take(min(max_tokens, stream.tokens.numel()))
    tokens_np = tokens.numpy().astype(np.int64)
    C = np.zeros((vocab_size, vocab_size), dtype=np.float64)
    if len(tokens_np) > 1:
        # Count every (prev, next) pair in one shot.
        np.add.at(C, (tokens_np[:-1], tokens_np[1:]), 1.0)
    # Symmetrize: both (i,j) and (j,i) indicate co-occurrence
    C = C + C.T
    # Log-PMI-like transform: log(C + 1) then center rows and columns
    C = np.log1p(C)
    C -= C.mean(axis=1, keepdims=True)
    C -= C.mean(axis=0, keepdims=True)
    return torch.from_numpy(C).float()


def init_qk_from_cooccurrence(model: nn.Module, cooc: Tensor, num_heads: int, num_kv_heads: int, head_dim: int) -> None:
    """Initialize W_Q and W_K of the first block so Q·K^T reflects co-occurrence.

    The vocab-sized co-occurrence matrix is pushed through a fixed-seed random
    projection into model_dim space; the SVD of that projected matrix supplies
    the Q/K weight rows, which are then rescaled to match the norm of the
    existing initialization.

    Fix: the previous version also computed ``torch.linalg.svd`` of the full
    vocab x vocab matrix up front, but none of its outputs (U, S, Vt and the
    derived Q_init/K_init) were ever used — that dead, very expensive
    computation has been removed. The effective behavior is unchanged.
    """
    q_dim = num_heads * head_dim
    k_dim = num_kv_heads * head_dim
    n_components = max(q_dim, k_dim)
    # Locate the first physical block; bail out quietly if the model has none.
    block0 = list(model.blocks.children())[0] if hasattr(model, 'blocks') else None
    if block0 is None:
        return
    model_dim = block0.attn.c_q.weight.shape[1]
    # Fixed-seed random projection from vocab space to model_dim space
    # (reproducible across runs regardless of the global RNG state).
    rng = torch.Generator()
    rng.manual_seed(12345)
    P = torch.randn(cooc.shape[0], model_dim, generator=rng) / math.sqrt(model_dim)
    # Project co-occurrence into model_dim space: [model_dim, model_dim].
    C_proj = P.T @ cooc @ P
    U2, S2, Vt2 = torch.linalg.svd(C_proj, full_matrices=False)
    S2_sqrt = S2[:n_components].clamp_min(0).sqrt()
    with torch.no_grad():
        # W_Q rows from left singular vectors, W_K rows from right singular
        # vectors, each scaled by sqrt(S) so Q·K^T ≈ C_proj.
        q_weight = (U2[:, :q_dim] * S2_sqrt[:q_dim].unsqueeze(0)).T  # [q_dim, model_dim]
        k_weight = (Vt2[:k_dim, :].T * S2_sqrt[:k_dim].unsqueeze(0)).T  # [k_dim, model_dim]
        # Normalize to match the scale of the original (orthogonal) init.
        q_scale = block0.attn.c_q.weight.data.norm() / q_weight.norm().clamp_min(1e-8)
        k_scale = block0.attn.c_k.weight.data.norm() / k_weight.norm().clamp_min(1e-8)
        block0.attn.c_q.weight.data.copy_(q_weight.to(device=block0.attn.c_q.weight.device, dtype=block0.attn.c_q.weight.dtype) * q_scale)
        block0.attn.c_k.weight.data.copy_(k_weight.to(device=block0.attn.c_k.weight.device, dtype=block0.attn.c_k.weight.dtype) * k_scale)
encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + 
unique_layers=args.unique_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + + # Co-occurrence QK init for layer 0 + cooc_init_enabled = bool(int(os.environ.get("COOC_QK_INIT", "1"))) + if cooc_init_enabled and master_process: + log0("cooc_qk_init: building co-occurrence matrix from training data...") + if cooc_init_enabled: + cooc_matrix = build_cooccurrence_matrix(args.train_files, args.vocab_size, max_tokens=2_000_000) + init_qk_from_cooccurrence( + base_model, cooc_matrix, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, + head_dim=args.model_dim // args.num_heads, + ) + if master_process: + log0("cooc_qk_init: done, W_Q and W_K in layer 0 initialized from co-occurrence SVD") + + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + 
matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 
else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + 
+ step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + 
muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" 
+ ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use 
val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = 
zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = 
eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 24 18:22:08 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | +| N/A 28C P0 86W / 700W | 527MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 21068 C python3 518MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +cooc_qk_init: building co-occurrence matrix from training data... 
+cooc_qk_init: done, W_Q and W_K in layer 0 initialized from co-occurrence SVD +model_params:25517137 +world_size:1 grad_accum_steps:8 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9323 val_bpb:4.1057 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9335 train_time:548ms step_avg:547.60ms +step:2/20000 train_loss:8.8478 train_time:1101ms step_avg:550.55ms +step:3/20000 train_loss:8.4486 train_time:1632ms step_avg:543.84ms +step:4/20000 train_loss:7.7646 train_time:2198ms step_avg:549.51ms +step:5/20000 train_loss:7.1397 train_time:2740ms step_avg:547.97ms +step:6/20000 train_loss:6.7020 train_time:3293ms step_avg:548.79ms +step:7/20000 train_loss:6.2936 train_time:3844ms step_avg:549.21ms +step:8/20000 train_loss:6.1048 train_time:4399ms step_avg:549.83ms +step:9/20000 train_loss:5.9824 train_time:4944ms step_avg:549.28ms +step:10/20000 train_loss:5.8981 train_time:5479ms step_avg:547.89ms +step:25/20000 train_loss:5.1677 train_time:13467ms step_avg:538.66ms +step:50/20000 train_loss:3.8390 train_time:26718ms step_avg:534.35ms +step:75/20000 train_loss:3.4511 train_time:40014ms step_avg:533.52ms +step:100/20000 train_loss:3.2106 train_time:53026ms step_avg:530.26ms +step:100/20000 val_loss:3.2205 val_bpb:1.9073 train_time:53043ms step_avg:530.43ms +step:125/20000 train_loss:3.0733 train_time:66134ms step_avg:529.07ms +step:150/20000 train_loss:2.9838 train_time:79308ms 
step_avg:528.72ms +step:175/20000 train_loss:2.8149 train_time:92288ms step_avg:527.36ms +step:200/20000 train_loss:2.7380 train_time:106550ms step_avg:532.75ms +step:200/20000 val_loss:2.7826 val_bpb:1.6480 train_time:106574ms step_avg:532.87ms +step:225/20000 train_loss:2.7320 train_time:119604ms step_avg:531.57ms +step:250/20000 train_loss:2.6843 train_time:132736ms step_avg:530.95ms +step:275/20000 train_loss:2.5843 train_time:145991ms step_avg:530.88ms +step:300/20000 train_loss:2.5525 train_time:159423ms step_avg:531.41ms +step:300/20000 val_loss:2.5822 val_bpb:1.5293 train_time:159451ms step_avg:531.50ms +step:325/20000 train_loss:2.5580 train_time:172610ms step_avg:531.11ms +step:350/20000 train_loss:2.5409 train_time:186013ms step_avg:531.46ms +step:375/20000 train_loss:2.5416 train_time:199583ms step_avg:532.22ms +step:400/20000 train_loss:2.3160 train_time:218680ms step_avg:546.70ms +step:400/20000 val_loss:2.4794 val_bpb:1.4685 train_time:218716ms step_avg:546.79ms +step:425/20000 train_loss:2.4497 train_time:232102ms step_avg:546.12ms +step:450/20000 train_loss:2.4266 train_time:246170ms step_avg:547.04ms +step:475/20000 train_loss:2.4498 train_time:259795ms step_avg:546.94ms +swa:start step:500 +step:500/20000 train_loss:2.4200 train_time:273698ms step_avg:547.40ms +step:500/20000 val_loss:2.4112 val_bpb:1.4281 train_time:274052ms step_avg:548.10ms +step:525/20000 train_loss:2.4750 train_time:287610ms step_avg:547.83ms +step:550/20000 train_loss:2.2941 train_time:301497ms step_avg:548.18ms +step:575/20000 train_loss:2.3513 train_time:316217ms step_avg:549.94ms +step:600/20000 train_loss:2.4148 train_time:329542ms step_avg:549.24ms +step:600/20000 val_loss:2.3653 val_bpb:1.4008 train_time:329773ms step_avg:549.62ms +step:625/20000 train_loss:2.4218 train_time:343007ms step_avg:548.81ms +step:650/20000 train_loss:2.2302 train_time:356389ms step_avg:548.29ms +step:675/20000 train_loss:2.2239 train_time:370042ms step_avg:548.21ms +step:700/20000 
train_loss:2.3745 train_time:383937ms step_avg:548.48ms +step:700/20000 val_loss:2.3266 val_bpb:1.3780 train_time:384341ms step_avg:549.06ms +step:725/20000 train_loss:2.3738 train_time:397694ms step_avg:548.54ms +step:750/20000 train_loss:2.2705 train_time:411295ms step_avg:548.39ms +step:775/20000 train_loss:2.3580 train_time:426595ms step_avg:550.44ms +step:800/20000 train_loss:2.2639 train_time:439822ms step_avg:549.78ms +step:800/20000 val_loss:2.2967 val_bpb:1.3603 train_time:440124ms step_avg:550.16ms +step:825/20000 train_loss:2.3659 train_time:453451ms step_avg:549.64ms +step:850/20000 train_loss:2.3646 train_time:466828ms step_avg:549.21ms +step:875/20000 train_loss:2.2751 train_time:480541ms step_avg:549.19ms +step:900/20000 train_loss:2.2894 train_time:494273ms step_avg:549.19ms +step:900/20000 val_loss:2.2713 val_bpb:1.3452 train_time:494381ms step_avg:549.31ms +step:925/20000 train_loss:2.2073 train_time:507739ms step_avg:548.91ms +step:950/20000 train_loss:2.2970 train_time:520999ms step_avg:548.42ms +step:975/20000 train_loss:2.2168 train_time:534601ms step_avg:548.31ms +step:1000/20000 train_loss:2.2837 train_time:547685ms step_avg:547.68ms +step:1000/20000 val_loss:2.2493 val_bpb:1.3322 train_time:547893ms step_avg:547.89ms +step:1025/20000 train_loss:2.3489 train_time:561110ms step_avg:547.42ms +step:1050/20000 train_loss:2.3059 train_time:574396ms step_avg:547.04ms +step:1075/20000 train_loss:2.1857 train_time:587500ms step_avg:546.51ms +step:1099/20000 val_loss:2.2364 val_bpb:1.3245 train_time:600138ms step_avg:546.08ms +stopping_early: wallclock_cap train_time:600138ms step:1099/20000 +peak memory allocated: 14656 MiB reserved: 14872 MiB +swa:applying averaged 12 checkpoints +Serialized model: 98437419 bytes +Code size: 64061 bytes +Total submission size: 98501480 bytes +awq:calibrating alpha=0.5 +awq:scaled 61 layers +Serialized model int6+zstd: 15481926 bytes +Total submission size int8+zlib: 15545987 bytes +awq:unscaled 61 layers after 
dequant +final_eval_mode:standard +final_int8_zlib_roundtrip val_loss:2.2837 val_bpb:1.3525 eval_time:18099ms +final_int8_zlib_roundtrip_exact val_loss:2.28368322 val_bpb:1.35252584 diff --git a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/run.sh b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/run.sh new file mode 100755 index 0000000000..c07b68d98c --- /dev/null +++ b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/run.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# ============================================================ +# Exp29: Co-occurrence QK Init — from Exp18 +# Builds bigram co-occurrence matrix from first 2M training tokens (~1s), +# SVDs it, initializes W_Q and W_K in layer 0 so initial attention +# patterns reflect token co-occurrence structure. Zero extra params. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp29_cooccurrence-qk-init_from-exp18" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export SWA_START_FRAC=0.2 +export SWA_EVERY=50 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export COOC_QK_INIT=1 +export RUN_ID="${EXP_NAME}" +echo "=== ${EXP_NAME} ===" +echo "Co-occurrence QK init: 
layer 0, 2M tokens" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/submission.json b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/submission.json new file mode 100644 index 0000000000..da3ea0fdd7 --- /dev/null +++ b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/submission.json @@ -0,0 +1,17 @@ +{ + "author": "Sidhant Thole", + "github_id": "SPThole", + "name": "Co-occurrence QK Initialization", + "blurb": "Initializes W_Q and W_K in layer 0 from bigram co-occurrence SVD, giving attention a distributional head start. Scans 2M tokens at startup (<3s), applies SVD, projects into model space. Zero extra parameters. Built on exp18 baseline. Run on 1×H100. Based on PR #623.", + "date": "2026-03-24T19:15:00Z", + "val_loss": 2.28368322, + "val_bpb": 1.35252584, + "pre_quant_val_loss": 2.2364, + "pre_quant_val_bpb": 1.3245, + "step_stop": 1099, + "wallclock_seconds": 600.138, + "bytes_total": 15545987, + "bytes_code": 64061, + "gpu_config": "1×H100 80GB", + "base_pr": "https://github.com/openai/parameter-golf/pull/623" +} diff --git a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/train_gpt.py b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/train_gpt.py new file mode 100644 index 0000000000..d60a73c02d --- /dev/null +++ b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/train_gpt.py @@ -0,0 +1,1435 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. 
+ +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + 
num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = 
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximately orthogonalize ``G`` with a quintic Newton-Schulz iteration.

    The fixed coefficients drive the singular values of the (normalized) input
    toward 1 in a few iterations. Work happens in bfloat16; the wider side of
    the matrix is kept on the right by transposing when rows > cols, and the
    transpose is undone before returning.
    """
    coef_a, coef_b, coef_c = 3.4445, -4.7750, 2.0315
    # Normalize so the spectral norm is (roughly) <= 1, which the iteration needs.
    work = G.bfloat16()
    work = work / (work.norm() + eps)
    tall = G.size(0) > G.size(1)
    if tall:
        work = work.T
    for _ in range(steps):
        gram = work @ work.T
        poly = coef_b * gram + coef_c * gram @ gram
        work = coef_a * work + poly @ work
    if tall:
        work = work.T
    return work
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables used by the bits-per-byte evaluation.

    Returns three tensors indexed by token id:
      * base bytes  — UTF-8 byte length of the token's text (1 for byte tokens),
      * has leading space — True when the piece begins with the SentencePiece
        whitespace marker U+2581 (the marker is stripped before byte counting),
      * is boundary — True for control/unknown/unused ids, which carry no text.
    Tables are sized max(sp.vocab_size(), vocab_size) so model ids beyond the
    tokenizer vocabulary still index safely.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        # Control/unknown/unused ids stay flagged as boundaries with 0 bytes.
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            # Raw-byte fallback tokens contribute exactly one byte each.
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("\u2581"):
            # Strip the meta space marker and record it as a flag instead.
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching *pattern* and trim the result
    to a whole number of seq_len sequences plus one trailing token (the final
    next-token target).

    Raises FileNotFoundError when no shard matches the glob pattern, and
    ValueError when the split is shorter than a single sequence.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    # Largest multiple of seq_len that still leaves one extra target token.
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]
def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Run non-overlapping validation across ranks.

    Returns ``(mean_val_loss, bits_per_byte)``. The sequence range is split
    evenly over ranks; each rank accumulates token-weighted loss and a
    byte count derived from the tokenizer LUTs, then the sums are all-reduced.
    The byte count credits a leading space only when the previous token is not
    a boundary token, matching the LUTs built by build_sentencepiece_luts.
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # One trailing token is reserved as the final target, hence numel() - 1.
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # float64 accumulators so long eval runs do not lose precision.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1  # +1 for targets
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            # batch_loss is weighted by token count below, so model(x, y) is
            # treated as a mean per-token loss — assumption from usage; the
            # model definition is elsewhere in this file.
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            # Add one byte for a leading space unless the preceding token is a
            # boundary (control/unknown/unused) token.
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
    val_loss = val_loss_sum / val_token_count
    # nats -> bits, then rescale from per-token to per-byte.
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern 
def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """Invert mixed_quantize_int6: rebuild a full state dict from the stored
    quantized/passthrough tensors.

    ``template_sd`` supplies the target dtype for every entry; ``meta`` says
    whether an entry was stored as-is ("passthrough*") or as an int-code
    tensor ``name + ".q"`` with a scale ``name + ".scale"``.
    """
    out: dict[str, Tensor] = {}
    for name, orig in template_sd.items():
        info = meta[name]
        orig_dtype = orig.dtype
        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
            t = result[name]
            # Passthrough tensors may have been stored fp16; restore the
            # template's wider dtype when needed.
            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
                t = t.to(orig_dtype)
            out[name] = t
            continue
        q, s = result[name + ".q"], result[name + ".scale"]
        if s.ndim > 0:
            # Per-row scales: broadcast over all trailing dimensions.
            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
        else:
            # Single scalar scale for the whole tensor.
            out[name] = (q.float() * float(s.item())).to(orig_dtype)
    return out
DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: 
def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate feature pairs (first half vs. second half of the last dim).

    Each pair (x1[i], x2[i]) is rotated by the angle encoded in (cos, sin);
    cos/sin must broadcast against the two halves of x.
    """
    split = x.size(-1) // 2
    lo = x[..., :split]
    hi = x[..., split:]
    rotated_lo = lo * cos + hi * sin
    rotated_hi = hi * cos - lo * sin
    return torch.cat((rotated_lo, rotated_hi), dim=-1)
class SmearGate(nn.Module):
    """Per-channel learned blend of each token with its left neighbor.

    Position 0 has no predecessor and is blended with zeros.  The gate
    parameter starts at zero, i.e. sigmoid(0) = 0.5: an even mix at init.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        mix = torch.sigmoid(self.gate.to(dtype=x.dtype))
        mix = mix[None, None, :]
        pad = torch.zeros_like(x[:, :1])
        shifted = torch.cat([pad, x[:, :-1]], dim=1)
        return (1 - mix) * x + mix * shifted
class Block(nn.Module):
    """Pre-norm transformer block with learned per-channel residual gains.

    ``resid_mix`` re-mixes the running stream with the embedding stream x0 on
    entry (row 0 weights the stream, row 1 weights x0); ``attn_scale`` and
    ``mlp_scale`` gate the two residual branches channel-wise.
    """

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        # Branch gains start at 1 so the block behaves like a plain
        # residual block at init.
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # Row 0 = stream weight (ones), row 1 = x0 weight (zeros) at init.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        blend = self.resid_mix.to(dtype=x.dtype)
        stream = blend[0][None, None, :] * x + blend[1][None, None, :] * x0
        attn_gain = self.attn_scale.to(dtype=stream.dtype)[None, None, :]
        stream = stream + attn_gain * self.attn(self.attn_norm(stream))
        mlp_gain = self.mlp_scale.to(dtype=stream.dtype)[None, None, :]
        stream = stream + mlp_gain * self.mlp(self.mlp_norm(stream))
        return stream
self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them + n_unique = unique_layers if unique_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(n_unique) + ] + ) + # Mapping from virtual layer index → physical block index + # Last virtual layer shares with the last physical block + self._layer_map = list(range(min(n_unique, num_layers))) + while len(self._layer_map) < num_layers: + self._layer_map.append(n_unique - 1) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self._layer_map) # virtual depth, not physical blocks + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window validation over `val_tokens`.

    Windows of `train_seq_len` tokens start every `stride` positions; each
    window scores only its last `stride` tokens (its "fresh" region), except
    window 0 which scores everything it covers.  Windows are sharded
    contiguously across ranks and the sums are all-reduced at the end.

    Returns:
        (mean per-token NLL, bits-per-byte), where bits-per-byte is computed
        from the SentencePiece byte-count LUTs.
    """
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    # Keep a window only if it would score at least `stride` fresh tokens
    # (the trailing partial window is dropped unless it is window 0).
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
    total_windows = len(window_starts)
    # Contiguous shard of windows for this rank.
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    # float64 accumulators to avoid drift over many windows.
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # Short windows are zero-padded on the right; the scored region
            # below never reaches into the padding.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
                nll = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)).float(),
                    y_batch.reshape(-1),
                    reduction="none",
                ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score only the last `stride` tokens of the window; window 0
                # scores its full length.
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                # Bytes contributed by each target token, plus one extra byte
                # for an implied leading space when the previous token is not
                # a boundary token.
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
            if rank == 0 and (bi // batch_seqs) % 50 == 0:
                done = min(bi + batch_seqs, len(my_windows))
                pct = done / len(my_windows) * 100
                running_bpb = 0.0
                if token_count.item() > 0:
                    rl = (loss_sum / token_count).item()
                    running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item())
                print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True)

    # Aggregate the partial sums across ranks before computing the ratios.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte
def init_qk_from_cooccurrence(model: nn.Module, cooc: Tensor, num_heads: int, num_kv_heads: int, head_dim: int) -> None:
    """Initialize layer-0 W_Q / W_K from the SVD of a projected co-occurrence matrix.

    The vocab-level co-occurrence matrix `cooc` is projected into model_dim
    space with a fixed-seed random projection (hidden states at init are
    approximately the token embeddings, which are themselves random), and the
    SVD factors of that projected matrix seed W_Q and W_K so that Q·K^T is
    biased toward co-occurrence structure.  The new weights are rescaled to
    match the Frobenius norm of the existing init so training dynamics are
    unchanged.  No-op if the model has no `blocks` or they are empty.

    NOTE(fix): the previous version first ran torch.linalg.svd on the full
    vocab x vocab matrix and built Q_init/K_init that were never used — an
    O(V^3) dead computation.  That has been removed; results are identical.
    """
    q_dim = num_heads * head_dim
    k_dim = num_kv_heads * head_dim
    n_components = max(q_dim, k_dim)
    # Guard both a missing `blocks` attribute and an empty ModuleList
    # (the old code indexed [0] unconditionally and could raise IndexError).
    blocks = list(model.blocks.children()) if hasattr(model, 'blocks') else []
    if not blocks:
        return
    block0 = blocks[0]
    model_dim = block0.attn.c_q.weight.shape[1]
    # Fixed-seed random projection vocab -> model_dim for reproducibility.
    rng = torch.Generator()
    rng.manual_seed(12345)
    P = torch.randn(cooc.shape[0], model_dim, generator=rng) / math.sqrt(model_dim)
    # Project co-occurrence into model_dim space: [model_dim, model_dim].
    C_proj = P.T @ cooc @ P
    U2, S2, Vt2 = torch.linalg.svd(C_proj, full_matrices=False)
    # clamp_min(0) guards tiny negative values from numerical noise.
    S2_sqrt = S2[:n_components].clamp_min(0).sqrt()
    with torch.no_grad():
        # sqrt(S) on each side so q·k ≈ the entries of U·S·V^T.
        q_weight = (U2[:, :q_dim] * S2_sqrt[:q_dim].unsqueeze(0)).T  # [q_dim, model_dim]
        k_weight = (Vt2[:k_dim, :].T * S2_sqrt[:k_dim].unsqueeze(0)).T  # [k_dim, model_dim]
        # Normalize to the scale of the existing (orthogonal) init.
        q_scale = block0.attn.c_q.weight.data.norm() / q_weight.norm().clamp_min(1e-8)
        k_scale = block0.attn.c_k.weight.data.norm() / k_weight.norm().clamp_min(1e-8)
        block0.attn.c_q.weight.data.copy_(q_weight.to(device=block0.attn.c_q.weight.device, dtype=block0.attn.c_q.weight.dtype) * q_scale)
        block0.attn.c_k.weight.data.copy_(k_weight.to(device=block0.attn.c_k.weight.device, dtype=block0.attn.c_k.weight.dtype) * k_scale)
encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + 
unique_layers=args.unique_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + + # Co-occurrence QK init for layer 0 + cooc_init_enabled = bool(int(os.environ.get("COOC_QK_INIT", "1"))) + if cooc_init_enabled and master_process: + log0("cooc_qk_init: building co-occurrence matrix from training data...") + if cooc_init_enabled: + cooc_matrix = build_cooccurrence_matrix(args.train_files, args.vocab_size, max_tokens=2_000_000) + init_qk_from_cooccurrence( + base_model, cooc_matrix, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, + head_dim=args.model_dim // args.num_heads, + ) + if master_process: + log0("cooc_qk_init: done, W_Q and W_K in layer 0 initialized from co-occurrence SVD") + + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + 
matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 
else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + 
+ step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + 
muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" 
+ ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use 
val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = 
zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = 
eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned From 3019166619ecf05286820a336917f94da324ca40 Mon Sep 17 00:00:00 2001 From: SPThole Date: Wed, 25 Mar 2026 09:59:41 +0530 Subject: [PATCH 05/10] removing non record --- .../README.md | 70 - .../logs.txt | 1572 ----------------- .../run.sh | 41 - .../submission.json | 17 - .../train_gpt.py | 1435 --------------- 5 files changed, 3135 deletions(-) delete mode 100644 records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/README.md delete mode 100644 records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/logs.txt delete mode 100755 records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/run.sh delete mode 100644 records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/submission.json delete mode 100644 records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/train_gpt.py diff --git a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/README.md b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/README.md deleted file mode 100644 index 5cd5744bb7..0000000000 --- 
a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/README.md +++ /dev/null @@ -1,70 +0,0 @@ -# Co-occurrence QK Initialization - -## Score: val_bpb = 1.3525 (1×H100, single seed) - -Trained on 1×H100 80GB in 600 seconds. 15.55MB artifact (int6+zstd). Run on 1×H100 due to compute constraints. Built upon [PR #623](https://github.com/openai/parameter-golf/pull/623). - -## Approach - -Initializes W_Q and W_K in layer 0 from bigram co-occurrence statistics so that the initial attention pattern reflects real token relationships rather than random noise. - -### Mathematical formulation - -**Goal**: At step 0, hidden states h ≈ E[token_id] (just the embedding). We want the attention logit between tokens i and j to approximate their co-occurrence: - -``` -q_i · k_j = (h_i W_Q^T) · (W_K h_j) ≈ C[tok_i, tok_j] -``` - -where C is the 1024×1024 bigram co-occurrence matrix. - -**Step 1 — Build co-occurrence matrix**: Scan 2M training tokens. For each consecutive pair (t_i, t_{i+1}), increment C[t_i, t_{i+1}]. Apply log-transform: C ← log(C + 1) then center (subtract row/column means). This gives a PMI-like matrix. - -**Step 2 — Project into model dimension**: Since W_Q and W_K operate on model_dim (512), not vocab_size (1024), we project C into model space via a fixed random matrix P ∈ R^{1024×512}: - -``` -C_proj = P^T C P ∈ R^{512×512} -``` - -**Step 3 — SVD factorization**: Decompose C_proj = U S V^T. Take top d_head (64) components: - -``` -W_Q ← (U[:, :q_dim] · diag(√S[:q_dim]))^T ∈ R^{q_dim × 512} -W_K ← (V[:k_dim, :] · diag(√S[:k_dim]))^T ∈ R^{k_dim × 512} -``` - -This ensures W_Q^T W_K ≈ C_proj (scaled), so Q·K^T at step 0 reflects co-occurrence. 
- -**Step 4 — Scale normalization**: Rescale W_Q and W_K to match the norm of the default orthogonal initialization, preventing gradient scale mismatch: - -``` -W_Q ← W_Q · (‖W_Q_orig‖ / ‖W_Q‖) -W_K ← W_K · (‖W_K_orig‖ / ‖W_K‖) -``` - -**Head diversity**: With 8 heads (head_dim=64), SVD components 1–64 go to head 0, 65–128 to head 1, etc. Each head captures a different slice of co-occurrence structure. - -Zero extra parameters — only changes initialization. Co-occurrence computation takes <3s. - -## Hyperparameters - -| Parameter | Value | -|-----------|-------| -| num_layers | 11 (10 unique) | -| model_dim | 512 | -| mlp_activation | ReLU² | -| cooc_init_tokens | 2,000,000 | -| cooc_init_layer | 0 only | -| train_batch_tokens | 524,288 | -| matrix_lr / scalar_lr | 0.025 | -| swa_every | 50, start_frac=0.2 | - -## Key Metrics - -- **val_bpb: 1.3525** (post int6+zstd roundtrip) -- Pre-quant val_bpb: 1.3245 -- Quantization penalty: 0.0280 bpb -- Training: 1,099 steps in 600s (546 ms/step) -- Artifact size: 15,545,987 bytes (15.55MB) -- SWA: averaged 12 checkpoints -- Peak memory: 14,656 MiB diff --git a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/logs.txt b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/logs.txt deleted file mode 100644 index bc8128fdd1..0000000000 --- a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/logs.txt +++ /dev/null @@ -1,1572 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
-""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = 
int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- 
-# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 
0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - 
val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and 
dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, 
scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" 
- continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = 
world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or 
self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, 
q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - 
return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - 
self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." 
in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - 
device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# CO-OCCURRENCE QK INIT -# ----------------------------- - -def build_cooccurrence_matrix(data_pattern: str, vocab_size: int, max_tokens: int = 2_000_000) -> Tensor: - """Build a vocab×vocab bigram co-occurrence matrix from training data. 
Fast: ~1s for 2M tokens.""" - stream = TokenStream(data_pattern) - tokens = stream.take(min(max_tokens, stream.tokens.numel())) - tokens_np = tokens.numpy().astype(np.int64) - C = np.zeros((vocab_size, vocab_size), dtype=np.float64) - for i in range(len(tokens_np) - 1): - C[tokens_np[i], tokens_np[i + 1]] += 1.0 - # Symmetrize: both (i,j) and (j,i) indicate co-occurrence - C = C + C.T - # Log-PMI-like transform: log(C + 1) then center rows and columns - C = np.log1p(C) - C -= C.mean(axis=1, keepdims=True) - C -= C.mean(axis=0, keepdims=True) - return torch.from_numpy(C).float() - - -def init_qk_from_cooccurrence(model: nn.Module, cooc: Tensor, num_heads: int, num_kv_heads: int, head_dim: int) -> None: - """Initialize W_Q and W_K in layer 0 so that Q·K^T approximates the co-occurrence structure. - SVD the co-occurrence matrix, distribute top components across heads.""" - U, S, Vt = torch.linalg.svd(cooc, full_matrices=False) - # We need num_heads * head_dim components for Q, num_kv_heads * head_dim for K - q_dim = num_heads * head_dim - k_dim = num_kv_heads * head_dim - n_components = max(q_dim, k_dim) - # Scale by sqrt(S) so that Q·K^T ≈ U·S·V^T - S_sqrt = S[:n_components].sqrt() - Q_init = U[:, :q_dim] * S_sqrt[:q_dim].unsqueeze(0) # [vocab, q_dim] - K_init = Vt[:k_dim, :].T * S_sqrt[:k_dim].unsqueeze(0) # [vocab, k_dim] - # But W_Q maps from model_dim (=embedding dim at layer 0) to q_dim - # At init, hidden states ≈ embeddings. We need W_Q such that E @ W_Q ≈ Q_init - # Since E is ~orthogonal at init, W_Q ≈ E^T @ Q_init (pseudo-inverse) - # However, E is vocab×dim, not square. Use E's pseudoinverse. - # Simpler approach: directly set W_Q weight rows from SVD components - # W_Q: [q_dim, model_dim]. At step 0: q = x @ W_Q^T = E[tok] @ W_Q^T - # We want the dot product q_i · k_j to reflect co-occurrence. - # Set W_Q = Q_init^T (projected to model_dim) — but Q_init is [vocab, q_dim] - # This doesn't directly work because W_Q operates on dim, not vocab. 
- # - # Correct approach: W_Q and W_K are [out_dim, model_dim]. - # q = h @ W_Q^T where h is the hidden state (≈ embedding at layer 0) - # We want: (h_i @ W_Q^T) @ (W_K @ h_j^T) = h_i^T @ W_Q^T @ W_K @ h_j ≈ C[tok_i, tok_j] - # So W_Q^T @ W_K ≈ E_pinv @ C @ E_pinv^T where E_pinv is pseudoinverse of E - # But E is not available at model construction time (it's a parameter). - # Instead, use a simpler signal: set W_Q and W_K from SVD of C projected through - # a random projection (since E is random at init). - # Actually simplest: just scale the existing orthogonal init by the SVD singular values - # to bias different heads toward different co-occurrence patterns. - - # Practical approach: for layer 0, replace W_Q and W_K with SVD-derived weights - # We project the SVD components into model_dim space via random projection - block0 = list(model.blocks.children())[0] if hasattr(model, 'blocks') else None - if block0 is None: - return - model_dim = block0.attn.c_q.weight.shape[1] - # Random projection from vocab_size to model_dim (fixed seed for reproducibility) - rng = torch.Generator() - rng.manual_seed(12345) - P = torch.randn(cooc.shape[0], model_dim, generator=rng) / math.sqrt(model_dim) - # Project co-occurrence into model_dim space: C_proj = P^T @ C @ P [model_dim, model_dim] - C_proj = P.T @ cooc @ P - # SVD of projected matrix - U2, S2, Vt2 = torch.linalg.svd(C_proj, full_matrices=False) - S2_sqrt = S2[:n_components].clamp_min(0).sqrt() - # W_Q rows = U2 columns scaled by sqrt(S), reshaped to [q_dim, model_dim] - # W_K rows = V2 columns scaled by sqrt(S), reshaped to [k_dim, model_dim] - with torch.no_grad(): - q_weight = (U2[:, :q_dim] * S2_sqrt[:q_dim].unsqueeze(0)).T # [q_dim, model_dim] - k_weight = (Vt2[:k_dim, :].T * S2_sqrt[:k_dim].unsqueeze(0)).T # [k_dim, model_dim] - # Normalize to match the scale of orthogonal init - q_scale = block0.attn.c_q.weight.data.norm() / q_weight.norm().clamp_min(1e-8) - k_scale = block0.attn.c_k.weight.data.norm() / 
k_weight.norm().clamp_min(1e-8) - block0.attn.c_q.weight.data.copy_(q_weight.to(device=block0.attn.c_q.weight.device, dtype=block0.attn.c_q.weight.dtype) * q_scale) - block0.attn.c_k.weight.data.copy_(k_weight.to(device=block0.attn.c_k.weight.device, dtype=block0.attn.c_k.weight.dtype) * k_scale) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", 
encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - 
unique_layers=args.unique_layers, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - - # Co-occurrence QK init for layer 0 - cooc_init_enabled = bool(int(os.environ.get("COOC_QK_INIT", "1"))) - if cooc_init_enabled and master_process: - log0("cooc_qk_init: building co-occurrence matrix from training data...") - if cooc_init_enabled: - cooc_matrix = build_cooccurrence_matrix(args.train_files, args.vocab_size, max_tokens=2_000_000) - init_qk_from_cooccurrence( - base_model, cooc_matrix, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, - head_dim=args.model_dim // args.num_heads, - ) - if master_process: - log0("cooc_qk_init: done, W_Q and W_K in layer 0 initialized from co-occurrence SVD") - - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - 
matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 
else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - 
- step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - 
muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" 
- ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use 
val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = 
zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = 
eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Tue Mar 24 18:22:08 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 28C P0 86W / 700W | 527MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 21068 C python3 518MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -cooc_qk_init: building co-occurrence matrix from training data... 
-cooc_qk_init: done, W_Q and W_K in layer 0 initialized from co-occurrence SVD -model_params:25517137 -world_size:1 grad_accum_steps:8 -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9323 val_bpb:4.1057 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9335 train_time:548ms step_avg:547.60ms -step:2/20000 train_loss:8.8478 train_time:1101ms step_avg:550.55ms -step:3/20000 train_loss:8.4486 train_time:1632ms step_avg:543.84ms -step:4/20000 train_loss:7.7646 train_time:2198ms step_avg:549.51ms -step:5/20000 train_loss:7.1397 train_time:2740ms step_avg:547.97ms -step:6/20000 train_loss:6.7020 train_time:3293ms step_avg:548.79ms -step:7/20000 train_loss:6.2936 train_time:3844ms step_avg:549.21ms -step:8/20000 train_loss:6.1048 train_time:4399ms step_avg:549.83ms -step:9/20000 train_loss:5.9824 train_time:4944ms step_avg:549.28ms -step:10/20000 train_loss:5.8981 train_time:5479ms step_avg:547.89ms -step:25/20000 train_loss:5.1677 train_time:13467ms step_avg:538.66ms -step:50/20000 train_loss:3.8390 train_time:26718ms step_avg:534.35ms -step:75/20000 train_loss:3.4511 train_time:40014ms step_avg:533.52ms -step:100/20000 train_loss:3.2106 train_time:53026ms step_avg:530.26ms -step:100/20000 val_loss:3.2205 val_bpb:1.9073 train_time:53043ms step_avg:530.43ms -step:125/20000 train_loss:3.0733 train_time:66134ms step_avg:529.07ms -step:150/20000 train_loss:2.9838 train_time:79308ms 
step_avg:528.72ms -step:175/20000 train_loss:2.8149 train_time:92288ms step_avg:527.36ms -step:200/20000 train_loss:2.7380 train_time:106550ms step_avg:532.75ms -step:200/20000 val_loss:2.7826 val_bpb:1.6480 train_time:106574ms step_avg:532.87ms -step:225/20000 train_loss:2.7320 train_time:119604ms step_avg:531.57ms -step:250/20000 train_loss:2.6843 train_time:132736ms step_avg:530.95ms -step:275/20000 train_loss:2.5843 train_time:145991ms step_avg:530.88ms -step:300/20000 train_loss:2.5525 train_time:159423ms step_avg:531.41ms -step:300/20000 val_loss:2.5822 val_bpb:1.5293 train_time:159451ms step_avg:531.50ms -step:325/20000 train_loss:2.5580 train_time:172610ms step_avg:531.11ms -step:350/20000 train_loss:2.5409 train_time:186013ms step_avg:531.46ms -step:375/20000 train_loss:2.5416 train_time:199583ms step_avg:532.22ms -step:400/20000 train_loss:2.3160 train_time:218680ms step_avg:546.70ms -step:400/20000 val_loss:2.4794 val_bpb:1.4685 train_time:218716ms step_avg:546.79ms -step:425/20000 train_loss:2.4497 train_time:232102ms step_avg:546.12ms -step:450/20000 train_loss:2.4266 train_time:246170ms step_avg:547.04ms -step:475/20000 train_loss:2.4498 train_time:259795ms step_avg:546.94ms -swa:start step:500 -step:500/20000 train_loss:2.4200 train_time:273698ms step_avg:547.40ms -step:500/20000 val_loss:2.4112 val_bpb:1.4281 train_time:274052ms step_avg:548.10ms -step:525/20000 train_loss:2.4750 train_time:287610ms step_avg:547.83ms -step:550/20000 train_loss:2.2941 train_time:301497ms step_avg:548.18ms -step:575/20000 train_loss:2.3513 train_time:316217ms step_avg:549.94ms -step:600/20000 train_loss:2.4148 train_time:329542ms step_avg:549.24ms -step:600/20000 val_loss:2.3653 val_bpb:1.4008 train_time:329773ms step_avg:549.62ms -step:625/20000 train_loss:2.4218 train_time:343007ms step_avg:548.81ms -step:650/20000 train_loss:2.2302 train_time:356389ms step_avg:548.29ms -step:675/20000 train_loss:2.2239 train_time:370042ms step_avg:548.21ms -step:700/20000 
train_loss:2.3745 train_time:383937ms step_avg:548.48ms -step:700/20000 val_loss:2.3266 val_bpb:1.3780 train_time:384341ms step_avg:549.06ms -step:725/20000 train_loss:2.3738 train_time:397694ms step_avg:548.54ms -step:750/20000 train_loss:2.2705 train_time:411295ms step_avg:548.39ms -step:775/20000 train_loss:2.3580 train_time:426595ms step_avg:550.44ms -step:800/20000 train_loss:2.2639 train_time:439822ms step_avg:549.78ms -step:800/20000 val_loss:2.2967 val_bpb:1.3603 train_time:440124ms step_avg:550.16ms -step:825/20000 train_loss:2.3659 train_time:453451ms step_avg:549.64ms -step:850/20000 train_loss:2.3646 train_time:466828ms step_avg:549.21ms -step:875/20000 train_loss:2.2751 train_time:480541ms step_avg:549.19ms -step:900/20000 train_loss:2.2894 train_time:494273ms step_avg:549.19ms -step:900/20000 val_loss:2.2713 val_bpb:1.3452 train_time:494381ms step_avg:549.31ms -step:925/20000 train_loss:2.2073 train_time:507739ms step_avg:548.91ms -step:950/20000 train_loss:2.2970 train_time:520999ms step_avg:548.42ms -step:975/20000 train_loss:2.2168 train_time:534601ms step_avg:548.31ms -step:1000/20000 train_loss:2.2837 train_time:547685ms step_avg:547.68ms -step:1000/20000 val_loss:2.2493 val_bpb:1.3322 train_time:547893ms step_avg:547.89ms -step:1025/20000 train_loss:2.3489 train_time:561110ms step_avg:547.42ms -step:1050/20000 train_loss:2.3059 train_time:574396ms step_avg:547.04ms -step:1075/20000 train_loss:2.1857 train_time:587500ms step_avg:546.51ms -step:1099/20000 val_loss:2.2364 val_bpb:1.3245 train_time:600138ms step_avg:546.08ms -stopping_early: wallclock_cap train_time:600138ms step:1099/20000 -peak memory allocated: 14656 MiB reserved: 14872 MiB -swa:applying averaged 12 checkpoints -Serialized model: 98437419 bytes -Code size: 64061 bytes -Total submission size: 98501480 bytes -awq:calibrating alpha=0.5 -awq:scaled 61 layers -Serialized model int6+zstd: 15481926 bytes -Total submission size int8+zlib: 15545987 bytes -awq:unscaled 61 layers after 
dequant -final_eval_mode:standard -final_int8_zlib_roundtrip val_loss:2.2837 val_bpb:1.3525 eval_time:18099ms -final_int8_zlib_roundtrip_exact val_loss:2.28368322 val_bpb:1.35252584 diff --git a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/run.sh b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/run.sh deleted file mode 100755 index c07b68d98c..0000000000 --- a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/run.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp29: Co-occurrence QK Init — from Exp18 -# Builds bigram co-occurrence matrix from first 2M training tokens (~1s), -# SVDs it, initializes W_Q and W_K in layer 0 so initial attention -# patterns reflect token co-occurrence structure. Zero extra params. -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp29_cooccurrence-qk-init_from-exp18" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export SWA_START_FRAC=0.2 -export SWA_EVERY=50 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export COOC_QK_INIT=1 -export RUN_ID="${EXP_NAME}" -echo "=== ${EXP_NAME} ===" -echo "Co-occurrence QK 
init: layer 0, 2M tokens" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/submission.json b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/submission.json deleted file mode 100644 index da3ea0fdd7..0000000000 --- a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/submission.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "author": "Sidhant Thole", - "github_id": "SPThole", - "name": "Co-occurrence QK Initialization", - "blurb": "Initializes W_Q and W_K in layer 0 from bigram co-occurrence SVD, giving attention a distributional head start. Scans 2M tokens at startup (<3s), applies SVD, projects into model space. Zero extra parameters. Built on exp18 baseline. Run on 1×H100. Based on PR #623.", - "date": "2026-03-24T19:15:00Z", - "val_loss": 2.28368322, - "val_bpb": 1.35252584, - "pre_quant_val_loss": 2.2364, - "pre_quant_val_bpb": 1.3245, - "step_stop": 1099, - "wallclock_seconds": 600.138, - "bytes_total": 15545987, - "bytes_code": 64061, - "gpu_config": "1×H100 80GB", - "base_pr": "https://github.com/openai/parameter-golf/pull/623" -} diff --git a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/train_gpt.py b/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/train_gpt.py deleted file mode 100644 index d60a73c02d..0000000000 --- a/records/track_non_record_16mb/2026_03_24_cooccurrence-qk-init_from-exp18/train_gpt.py +++ /dev/null @@ -1,1435 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. 
- -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - 
num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = 
float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = 
g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def 
eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & 
~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 
127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern 
in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class 
DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: 
torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, 
self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = 
self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - 
self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." 
in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - 
device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# CO-OCCURRENCE QK INIT -# ----------------------------- - -def build_cooccurrence_matrix(data_pattern: str, vocab_size: int, max_tokens: int = 2_000_000) -> Tensor: - """Build a vocab×vocab bigram co-occurrence matrix from training data. 
Fast: ~1s for 2M tokens.""" - stream = TokenStream(data_pattern) - tokens = stream.take(min(max_tokens, stream.tokens.numel())) - tokens_np = tokens.numpy().astype(np.int64) - C = np.zeros((vocab_size, vocab_size), dtype=np.float64) - for i in range(len(tokens_np) - 1): - C[tokens_np[i], tokens_np[i + 1]] += 1.0 - # Symmetrize: both (i,j) and (j,i) indicate co-occurrence - C = C + C.T - # Log-PMI-like transform: log(C + 1) then center rows and columns - C = np.log1p(C) - C -= C.mean(axis=1, keepdims=True) - C -= C.mean(axis=0, keepdims=True) - return torch.from_numpy(C).float() - - -def init_qk_from_cooccurrence(model: nn.Module, cooc: Tensor, num_heads: int, num_kv_heads: int, head_dim: int) -> None: - """Initialize W_Q and W_K in layer 0 so that Q·K^T approximates the co-occurrence structure. - SVD the co-occurrence matrix, distribute top components across heads.""" - U, S, Vt = torch.linalg.svd(cooc, full_matrices=False) - # We need num_heads * head_dim components for Q, num_kv_heads * head_dim for K - q_dim = num_heads * head_dim - k_dim = num_kv_heads * head_dim - n_components = max(q_dim, k_dim) - # Scale by sqrt(S) so that Q·K^T ≈ U·S·V^T - S_sqrt = S[:n_components].sqrt() - Q_init = U[:, :q_dim] * S_sqrt[:q_dim].unsqueeze(0) # [vocab, q_dim] - K_init = Vt[:k_dim, :].T * S_sqrt[:k_dim].unsqueeze(0) # [vocab, k_dim] - # But W_Q maps from model_dim (=embedding dim at layer 0) to q_dim - # At init, hidden states ≈ embeddings. We need W_Q such that E @ W_Q ≈ Q_init - # Since E is ~orthogonal at init, W_Q ≈ E^T @ Q_init (pseudo-inverse) - # However, E is vocab×dim, not square. Use E's pseudoinverse. - # Simpler approach: directly set W_Q weight rows from SVD components - # W_Q: [q_dim, model_dim]. At step 0: q = x @ W_Q^T = E[tok] @ W_Q^T - # We want the dot product q_i · k_j to reflect co-occurrence. - # Set W_Q = Q_init^T (projected to model_dim) — but Q_init is [vocab, q_dim] - # This doesn't directly work because W_Q operates on dim, not vocab. 
- # - # Correct approach: W_Q and W_K are [out_dim, model_dim]. - # q = h @ W_Q^T where h is the hidden state (≈ embedding at layer 0) - # We want: (h_i @ W_Q^T) @ (W_K @ h_j^T) = h_i^T @ W_Q^T @ W_K @ h_j ≈ C[tok_i, tok_j] - # So W_Q^T @ W_K ≈ E_pinv @ C @ E_pinv^T where E_pinv is pseudoinverse of E - # But E is not available at model construction time (it's a parameter). - # Instead, use a simpler signal: set W_Q and W_K from SVD of C projected through - # a random projection (since E is random at init). - # Actually simplest: just scale the existing orthogonal init by the SVD singular values - # to bias different heads toward different co-occurrence patterns. - - # Practical approach: for layer 0, replace W_Q and W_K with SVD-derived weights - # We project the SVD components into model_dim space via random projection - block0 = list(model.blocks.children())[0] if hasattr(model, 'blocks') else None - if block0 is None: - return - model_dim = block0.attn.c_q.weight.shape[1] - # Random projection from vocab_size to model_dim (fixed seed for reproducibility) - rng = torch.Generator() - rng.manual_seed(12345) - P = torch.randn(cooc.shape[0], model_dim, generator=rng) / math.sqrt(model_dim) - # Project co-occurrence into model_dim space: C_proj = P^T @ C @ P [model_dim, model_dim] - C_proj = P.T @ cooc @ P - # SVD of projected matrix - U2, S2, Vt2 = torch.linalg.svd(C_proj, full_matrices=False) - S2_sqrt = S2[:n_components].clamp_min(0).sqrt() - # W_Q rows = U2 columns scaled by sqrt(S), reshaped to [q_dim, model_dim] - # W_K rows = V2 columns scaled by sqrt(S), reshaped to [k_dim, model_dim] - with torch.no_grad(): - q_weight = (U2[:, :q_dim] * S2_sqrt[:q_dim].unsqueeze(0)).T # [q_dim, model_dim] - k_weight = (Vt2[:k_dim, :].T * S2_sqrt[:k_dim].unsqueeze(0)).T # [k_dim, model_dim] - # Normalize to match the scale of orthogonal init - q_scale = block0.attn.c_q.weight.data.norm() / q_weight.norm().clamp_min(1e-8) - k_scale = block0.attn.c_k.weight.data.norm() / 
k_weight.norm().clamp_min(1e-8) - block0.attn.c_q.weight.data.copy_(q_weight.to(device=block0.attn.c_q.weight.device, dtype=block0.attn.c_q.weight.dtype) * q_scale) - block0.attn.c_k.weight.data.copy_(k_weight.to(device=block0.attn.c_k.weight.device, dtype=block0.attn.c_k.weight.dtype) * k_scale) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", 
encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - 
unique_layers=args.unique_layers, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - - # Co-occurrence QK init for layer 0 - cooc_init_enabled = bool(int(os.environ.get("COOC_QK_INIT", "1"))) - if cooc_init_enabled and master_process: - log0("cooc_qk_init: building co-occurrence matrix from training data...") - if cooc_init_enabled: - cooc_matrix = build_cooccurrence_matrix(args.train_files, args.vocab_size, max_tokens=2_000_000) - init_qk_from_cooccurrence( - base_model, cooc_matrix, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, - head_dim=args.model_dim // args.num_heads, - ) - if master_process: - log0("cooc_qk_init: done, W_Q and W_K in layer 0 initialized from co-occurrence SVD") - - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - 
matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 
else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - 
- step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - 
muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" 
- ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use 
val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = 
zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = 
eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned From de42f2ba75466c68831b408c9d686b256b74a792 Mon Sep 17 00:00:00 2001 From: SPThole Date: Thu, 26 Mar 2026 12:30:14 +0530 Subject: [PATCH 06/10] adding exps --- .../Inference.ipynb | 281 +++ .../exp00_baseline-rerun_exp27/README.md | 22 + .../exp00_baseline-rerun_exp27/logs.txt | 0 .../phase3/exp00_baseline-rerun_exp27/run.sh | 38 + .../exp00_baseline-rerun_exp27/save_model.py | 52 + .../exp00_baseline-rerun_exp27/train_gpt.py | 1340 ++++++++++++++ .../Inference.ipynb | 238 +++ .../exp01_partial-rope_from-exp27/README.md | 33 + .../exp01_partial-rope_from-exp27/logs.txt | 0 .../exp01_partial-rope_from-exp27/run.sh | 40 + .../save_model.py | 68 + .../train_gpt.py | 1349 ++++++++++++++ .../Inference.ipynb | 281 +++ .../exp01b_ln-scale-only_from-exp27/README.md | 11 + .../exp01b_ln-scale-only_from-exp27/logs.txt | 0 .../exp01b_ln-scale-only_from-exp27/run.sh | 37 + .../save_model.py | 25 + .../train_gpt.py | 1341 ++++++++++++++ .../Inference.ipynb | 281 +++ .../exp01c_ema-only_from-exp27/README.md | 11 + .../exp01c_ema-only_from-exp27/logs.txt | 0 .../phase3/exp01c_ema-only_from-exp27/run.sh | 37 + .../exp01c_ema-only_from-exp27/save_model.py | 26 + 
.../exp01c_ema-only_from-exp27/train_gpt.py | 1335 ++++++++++++++ .../Inference.ipynb | 281 +++ .../exp01d_xsa-only_from-exp27/README.md | 11 + .../exp01d_xsa-only_from-exp27/logs.txt | 0 .../phase3/exp01d_xsa-only_from-exp27/run.sh | 38 + .../exp01d_xsa-only_from-exp27/save_model.py | 26 + .../exp01d_xsa-only_from-exp27/train_gpt.py | 1352 +++++++++++++++ .../exp02_ln-scale_from-exp01/Inference.ipynb | 238 +++ .../exp02_ln-scale_from-exp01/README.md | 24 + .../phase3/exp02_ln-scale_from-exp01/logs.txt | 0 .../phase3/exp02_ln-scale_from-exp01/run.sh | 39 + .../exp02_ln-scale_from-exp01/save_model.py | 53 + .../exp02_ln-scale_from-exp01/train_gpt.py | 1350 +++++++++++++++ .../exp03_ema_from-exp02/Inference.ipynb | 238 +++ records/phase3/exp03_ema_from-exp02/README.md | 26 + records/phase3/exp03_ema_from-exp02/logs.txt | 0 records/phase3/exp03_ema_from-exp02/run.sh | 39 + .../phase3/exp03_ema_from-exp02/save_model.py | 54 + .../phase3/exp03_ema_from-exp02/train_gpt.py | 1345 ++++++++++++++ .../exp04_xsa4_from-exp03/Inference.ipynb | 238 +++ .../phase3/exp04_xsa4_from-exp03/README.md | 28 + records/phase3/exp04_xsa4_from-exp03/logs.txt | 0 records/phase3/exp04_xsa4_from-exp03/run.sh | 41 + .../exp04_xsa4_from-exp03/save_model.py | 55 + .../phase3/exp04_xsa4_from-exp03/train_gpt.py | 1362 +++++++++++++++ .../exp05_late-qat_from-exp04/Inference.ipynb | 238 +++ .../exp05_late-qat_from-exp04/README.md | 26 + .../phase3/exp05_late-qat_from-exp04/logs.txt | 0 .../phase3/exp05_late-qat_from-exp04/run.sh | 42 + .../exp05_late-qat_from-exp04/save_model.py | 56 + .../exp05_late-qat_from-exp04/train_gpt.py | 1380 +++++++++++++++ .../exp06_gptq_from-exp05/Inference.ipynb | 238 +++ .../phase3/exp06_gptq_from-exp05/README.md | 26 + records/phase3/exp06_gptq_from-exp05/logs.txt | 0 records/phase3/exp06_gptq_from-exp05/run.sh | 45 + .../exp06_gptq_from-exp05/save_model.py | 57 + .../phase3/exp06_gptq_from-exp05/train_gpt.py | 1485 ++++++++++++++++ .../Inference.ipynb | 238 +++ 
.../exp07_parallel-muon_from-exp06/README.md | 26 + .../exp07_parallel-muon_from-exp06/logs.txt | 0 .../exp07_parallel-muon_from-exp06/run.sh | 45 + .../save_model.py | 57 + .../train_gpt.py | 1542 +++++++++++++++++ 66 files changed, 19185 insertions(+) create mode 100644 records/phase3/exp00_baseline-rerun_exp27/Inference.ipynb create mode 100644 records/phase3/exp00_baseline-rerun_exp27/README.md create mode 100644 records/phase3/exp00_baseline-rerun_exp27/logs.txt create mode 100755 records/phase3/exp00_baseline-rerun_exp27/run.sh create mode 100644 records/phase3/exp00_baseline-rerun_exp27/save_model.py create mode 100644 records/phase3/exp00_baseline-rerun_exp27/train_gpt.py create mode 100644 records/phase3/exp01_partial-rope_from-exp27/Inference.ipynb create mode 100644 records/phase3/exp01_partial-rope_from-exp27/README.md create mode 100644 records/phase3/exp01_partial-rope_from-exp27/logs.txt create mode 100755 records/phase3/exp01_partial-rope_from-exp27/run.sh create mode 100644 records/phase3/exp01_partial-rope_from-exp27/save_model.py create mode 100644 records/phase3/exp01_partial-rope_from-exp27/train_gpt.py create mode 100644 records/phase3/exp01b_ln-scale-only_from-exp27/Inference.ipynb create mode 100644 records/phase3/exp01b_ln-scale-only_from-exp27/README.md create mode 100644 records/phase3/exp01b_ln-scale-only_from-exp27/logs.txt create mode 100755 records/phase3/exp01b_ln-scale-only_from-exp27/run.sh create mode 100644 records/phase3/exp01b_ln-scale-only_from-exp27/save_model.py create mode 100644 records/phase3/exp01b_ln-scale-only_from-exp27/train_gpt.py create mode 100644 records/phase3/exp01c_ema-only_from-exp27/Inference.ipynb create mode 100644 records/phase3/exp01c_ema-only_from-exp27/README.md create mode 100644 records/phase3/exp01c_ema-only_from-exp27/logs.txt create mode 100755 records/phase3/exp01c_ema-only_from-exp27/run.sh create mode 100644 records/phase3/exp01c_ema-only_from-exp27/save_model.py create mode 100644 
records/phase3/exp01c_ema-only_from-exp27/train_gpt.py create mode 100644 records/phase3/exp01d_xsa-only_from-exp27/Inference.ipynb create mode 100644 records/phase3/exp01d_xsa-only_from-exp27/README.md create mode 100644 records/phase3/exp01d_xsa-only_from-exp27/logs.txt create mode 100755 records/phase3/exp01d_xsa-only_from-exp27/run.sh create mode 100644 records/phase3/exp01d_xsa-only_from-exp27/save_model.py create mode 100644 records/phase3/exp01d_xsa-only_from-exp27/train_gpt.py create mode 100644 records/phase3/exp02_ln-scale_from-exp01/Inference.ipynb create mode 100644 records/phase3/exp02_ln-scale_from-exp01/README.md create mode 100644 records/phase3/exp02_ln-scale_from-exp01/logs.txt create mode 100755 records/phase3/exp02_ln-scale_from-exp01/run.sh create mode 100644 records/phase3/exp02_ln-scale_from-exp01/save_model.py create mode 100644 records/phase3/exp02_ln-scale_from-exp01/train_gpt.py create mode 100644 records/phase3/exp03_ema_from-exp02/Inference.ipynb create mode 100644 records/phase3/exp03_ema_from-exp02/README.md create mode 100644 records/phase3/exp03_ema_from-exp02/logs.txt create mode 100755 records/phase3/exp03_ema_from-exp02/run.sh create mode 100644 records/phase3/exp03_ema_from-exp02/save_model.py create mode 100644 records/phase3/exp03_ema_from-exp02/train_gpt.py create mode 100644 records/phase3/exp04_xsa4_from-exp03/Inference.ipynb create mode 100644 records/phase3/exp04_xsa4_from-exp03/README.md create mode 100644 records/phase3/exp04_xsa4_from-exp03/logs.txt create mode 100755 records/phase3/exp04_xsa4_from-exp03/run.sh create mode 100644 records/phase3/exp04_xsa4_from-exp03/save_model.py create mode 100644 records/phase3/exp04_xsa4_from-exp03/train_gpt.py create mode 100644 records/phase3/exp05_late-qat_from-exp04/Inference.ipynb create mode 100644 records/phase3/exp05_late-qat_from-exp04/README.md create mode 100644 records/phase3/exp05_late-qat_from-exp04/logs.txt create mode 100755 
records/phase3/exp05_late-qat_from-exp04/run.sh create mode 100644 records/phase3/exp05_late-qat_from-exp04/save_model.py create mode 100644 records/phase3/exp05_late-qat_from-exp04/train_gpt.py create mode 100644 records/phase3/exp06_gptq_from-exp05/Inference.ipynb create mode 100644 records/phase3/exp06_gptq_from-exp05/README.md create mode 100644 records/phase3/exp06_gptq_from-exp05/logs.txt create mode 100755 records/phase3/exp06_gptq_from-exp05/run.sh create mode 100644 records/phase3/exp06_gptq_from-exp05/save_model.py create mode 100644 records/phase3/exp06_gptq_from-exp05/train_gpt.py create mode 100644 records/phase3/exp07_parallel-muon_from-exp06/Inference.ipynb create mode 100644 records/phase3/exp07_parallel-muon_from-exp06/README.md create mode 100644 records/phase3/exp07_parallel-muon_from-exp06/logs.txt create mode 100755 records/phase3/exp07_parallel-muon_from-exp06/run.sh create mode 100644 records/phase3/exp07_parallel-muon_from-exp06/save_model.py create mode 100644 records/phase3/exp07_parallel-muon_from-exp06/train_gpt.py diff --git a/records/phase3/exp00_baseline-rerun_exp27/Inference.ipynb b/records/phase3/exp00_baseline-rerun_exp27/Inference.ipynb new file mode 100644 index 0000000000..8a406b185c --- /dev/null +++ b/records/phase3/exp00_baseline-rerun_exp27/Inference.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import json\n", + "import io\n", + "import math\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import 
numpy as np\n", + "\n", + "# Config — change these paths as needed\n", + "EXPERIMENT_DIR = \"records/phase2/exp27_leaky-relu-sq_from-exp18\"\n", + "MODEL_PATH = \"final_model.pt\" # or path to saved checkpoint\n", + "TOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", + "\n", + "print(f\"Device: {DEVICE}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build model\n", + "model = tg.GPT(\n", + " vocab_size=args.vocab_size,\n", + " num_layers=args.num_layers,\n", + " model_dim=args.model_dim,\n", + " num_heads=args.num_heads,\n", + " num_kv_heads=args.num_kv_heads,\n", + " mlp_mult=args.mlp_mult,\n", + " tie_embeddings=args.tie_embeddings,\n", + " tied_embed_init_std=args.tied_embed_init_std,\n", + " logit_softcap=args.logit_softcap,\n", + " rope_base=args.rope_base,\n", + " qk_gain_init=args.qk_gain_init,\n", + " bigram_vocab_size=args.bigram_vocab_size,\n", + " bigram_dim=args.bigram_dim,\n", + " unique_layers=args.unique_layers,\n", + ")\n", + "\n", + "# Load state dict\n", + "state_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\n", + "model.load_state_dict(state_dict, strict=True)\n", + "model = model.to(DEVICE).eval()\n", + "\n", + "n_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. (Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp00_baseline-rerun_exp27/README.md b/records/phase3/exp00_baseline-rerun_exp27/README.md new file mode 100644 index 0000000000..b100a72f00 --- /dev/null +++ b/records/phase3/exp00_baseline-rerun_exp27/README.md @@ -0,0 +1,22 @@ +# Baseline Rerun (Exp27 on A100) + +## Purpose + +Control experiment. Establishes the exact val_bpb of exp27 on A100 hardware with 2 seeds to measure variance. All phase3 experiments are compared against this. 
+ +## Protocol + +Run twice: +```bash +SEED=42 bash run.sh +SEED=1337 bash run.sh +``` + +## Results + +| Seed | val_bpb | train_time | steps | +|------|---------|------------|-------| +| 42 | TBD | TBD | TBD | +| 1337 | TBD | TBD | TBD | +| **mean** | TBD | | | +| **std** | TBD | | | diff --git a/records/phase3/exp00_baseline-rerun_exp27/logs.txt b/records/phase3/exp00_baseline-rerun_exp27/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp00_baseline-rerun_exp27/run.sh b/records/phase3/exp00_baseline-rerun_exp27/run.sh new file mode 100755 index 0000000000..9ce33c0470 --- /dev/null +++ b/records/phase3/exp00_baseline-rerun_exp27/run.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# ============================================================ +# Exp00: Baseline rerun of Exp27 on A100 +# Control experiment. Run with SEED=42 and SEED=1337 to +# establish variance baseline for all phase3 comparisons. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp00_baseline-rerun_exp27" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export SWA_START_FRAC=0.2 +export SWA_EVERY=50 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export 
RUN_ID="${EXP_NAME}_seed${SEED}" +echo "=== ${EXP_NAME} seed=${SEED} ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs_seed${SEED}.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp00_baseline-rerun_exp27/save_model.py b/records/phase3/exp00_baseline-rerun_exp27/save_model.py new file mode 100644 index 0000000000..0188ff381c --- /dev/null +++ b/records/phase3/exp00_baseline-rerun_exp27/save_model.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp00_baseline-rerun_exp27.""" + +import argparse +import json +import os +import sys +import shutil + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-pt", type=str, default="final_model.pt") + parser.add_argument("--output-dir", type=str, default="model_checkpoint") + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + import train_gpt as tg + + hp = tg.Hyperparameters() + config = { + "vocab_size": hp.vocab_size, + "num_layers": hp.num_layers, + "model_dim": hp.model_dim, + "num_heads": hp.num_heads, + "num_kv_heads": hp.num_kv_heads, + "mlp_mult": hp.mlp_mult, + "tie_embeddings": hp.tie_embeddings, + "tied_embed_init_std": hp.tied_embed_init_std, + "logit_softcap": hp.logit_softcap, + "rope_base": hp.rope_base, + "qk_gain_init": hp.qk_gain_init, + "bigram_vocab_size": hp.bigram_vocab_size, + "bigram_dim": hp.bigram_dim, + "unique_layers": hp.unique_layers, + "train_seq_len": hp.train_seq_len, + } + + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + if os.path.exists(args.model_pt): + shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) + quant_path = args.model_pt.replace(".pt", ".int8.ptz") + if os.path.exists(quant_path): + shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) + + print(f"Checkpoint saved to 
{args.output_dir}/") + +if __name__ == "__main__": + main() diff --git a/records/phase3/exp00_baseline-rerun_exp27/train_gpt.py b/records/phase3/exp00_baseline-rerun_exp27/train_gpt.py new file mode 100644 index 0000000000..1b3a9d0087 --- /dev/null +++ b/records/phase3/exp00_baseline-rerun_exp27/train_gpt.py @@ -0,0 +1,1340 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = 
int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + 
momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = 
group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + 
torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len 
+ raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if 
pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for 
name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# 
class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""

    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    """nn.Linear that casts its weight (and bias) to the activation dtype."""

    def forward(self, x: Tensor) -> Tensor:
        w = self.weight.to(x.dtype)
        b = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, w, b)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    """Promote vectors/scalars and named control tensors back to float32."""
    with torch.no_grad():
        for name, param in module.named_parameters():
            is_control = param.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)
            if is_control and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    """Rotary position embedding with a lazily rebuilt cos/sin cache."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        stale = (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        )
        if stale:
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate the two halves of the last dimension by (cos, sin)."""
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    """Causal GQA attention with QK RMS-norm, RoPE, and per-head query gains."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # residual branch starts as identity
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm bounds the logits; the learned per-head gain restores scale.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    """Feed-forward block: leaky-ReLU (slope 0.5) then squaring, then project."""

    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True  # zeroed later by GPT._init_weights

    def forward(self, x: Tensor) -> Tensor:
        h = F.leaky_relu(self.fc(x), negative_slope=0.5)
        return self.proj(h.square())


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        # Shift right by one position; the first position blends with zeros.
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev


class BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table."""

    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod  # first position has no predecessor: reserved bucket
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class Block(nn.Module):
    """Pre-norm transformer block with learned residual/branch scaling."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights the running stream, resid_mix[1] the initial x0.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    """Decoder-only transformer with U-Net skip connections, optional tied
    embeddings, a hashed-bigram embedding branch, and virtual→physical layer
    sharing (num_layers virtual layers mapped onto unique_layers blocks)."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: create unique_layers physical blocks, map num_layers
        # virtual layers onto them.
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
                for _ in range(n_unique)
            ]
        )
        # Virtual layer index -> physical block index; extra virtual layers all
        # share the last physical block.
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                if ".proj." in name or name.endswith(".proj"):
                    # Depth-scaled residual init (GPT-2 style), by virtual depth.
                    with torch.no_grad():
                        module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def _backbone(self, input_ids: Tensor) -> Tensor:
        """Shared trunk: embeddings -> smear -> encoder/decoder blocks with
        weighted skip connections -> final norm. Returns (B, T, D)."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        return self.final_norm(x)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return the mean cross-entropy over all target tokens."""
        h = self._backbone(input_ids)
        h = h.reshape(-1, h.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(h, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(h)
        # Soft-cap logits via tanh to bound their magnitude.
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Return soft-capped logits of shape (B, T, vocab)."""
        h = self._backbone(input_ids)
        if self.tie_embeddings:
            logits_proj = F.linear(h, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(h)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)


def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window validation: each window scores only its last `stride`
    tokens (the first window scores everything), so every scored token gets
    near-full left context. Returns (mean NLL, bits-per-byte)."""
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score only the fresh tail of each window (all of window 0).
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
            if rank == 0 and (bi // batch_seqs) % 50 == 0:
                done = min(bi + batch_seqs, len(my_windows))
                pct = done / len(my_windows) * 100
                running_bpb = 0.0
                if token_count.item() > 0:
                    rl = (loss_sum / token_count).item()
                    running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item())
                print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True)

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte


# -----------------------------
# TRAINING
# -----------------------------
zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + 
random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = 
list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in 
base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, 
grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = 
torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + 
swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # 
def make_hook(name):
    """Build a forward hook that accumulates per-input-channel activation
    magnitudes for the module registered under `name`.

    The hook averages |x| over every dimension except the last (the channel
    dimension) and sums the result into the enclosing `act_stats` dict, so
    repeated calibration batches accumulate.
    """
    def accumulate(module, input, output):
        tensor = input[0] if isinstance(input, tuple) else input
        if tensor.ndim >= 2:
            # Mean absolute activation per channel (last dim).
            channel_mag = tensor.detach().float().abs().mean(dim=tuple(range(tensor.ndim - 1)))
            prev = act_stats.get(name)
            act_stats[name] = channel_mag if prev is None else prev + channel_mag
    return accumulate
module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = 
inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/phase3/exp01_partial-rope_from-exp27/Inference.ipynb b/records/phase3/exp01_partial-rope_from-exp27/Inference.ipynb new file mode 100644 index 0000000000..cdb5ae3d87 --- /dev/null +++ b/records/phase3/exp01_partial-rope_from-exp27/Inference.ipynb @@ -0,0 +1,238 @@ +{ + "cells": 
[ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\n# Config — change these paths as needed\nEXPERIMENT_DIR = \"records/phase3/exp01_partial-rope_from-exp27\"\nMODEL_PATH = \"final_model.pt\" # or path to saved checkpoint\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. (Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp01_partial-rope_from-exp27/README.md b/records/phase3/exp01_partial-rope_from-exp27/README.md new file mode 100644 index 0000000000..edd4cb6204 --- /dev/null +++ b/records/phase3/exp01_partial-rope_from-exp27/README.md @@ -0,0 +1,33 @@ +# Partial RoPE (25% of Head Dimensions) + +## Score: val_bpb = TBD + +## Hypothesis + +Applying RoPE to only 16/64 head dimensions allows the remaining 48 dimensions to learn position-independent semantic similarity. Community data shows ~0.005 BPB improvement. Zero additional parameters, zero compute overhead. 
 + +## Change from exp27 + +Single architectural change in `CausalSelfAttention`: +- `rope_frac=0.25`: RoPE applied to first 16 dims, remaining 48 pass through unchanged +- `Rotary` module initialized with `dim=16` instead of `dim=64` + +## Architecture + +| Parameter | Value | +|-----------|-------| +| num_layers | 11 (10 unique) | +| model_dim | 512 | +| num_heads, num_kv_heads | 8, 4 | +| head_dim | 64 (16 RoPE + 48 pass-through) | +| mlp_mult | 3.0 (hidden=1536) | +| mlp_activation | LeakyReLU(0.5)² | +| rope_frac | 0.25 | + +## Expected Impact + +~0.005 BPB improvement over exp27 baseline (1.3345 → ~1.330) + +## Results + +TBD — awaiting A100 run. diff --git a/records/phase3/exp01_partial-rope_from-exp27/logs.txt b/records/phase3/exp01_partial-rope_from-exp27/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp01_partial-rope_from-exp27/run.sh b/records/phase3/exp01_partial-rope_from-exp27/run.sh new file mode 100755 index 0000000000..0e3274fbc3 --- /dev/null +++ b/records/phase3/exp01_partial-rope_from-exp27/run.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# ============================================================ +# Exp01: Partial RoPE (25% of head dims) — from Exp27 +# Apply RoPE to only 16/64 head dimensions. Remaining 48 dims +# attend without position encoding, learning semantic similarity +# independent of distance. Zero extra parameters. 
+# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp01_partial-rope_from-exp27" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export SWA_START_FRAC=0.2 +export SWA_EVERY=50 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export ROPE_FRAC=0.25 +export RUN_ID="${EXP_NAME}" +echo "=== ${EXP_NAME} ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp01_partial-rope_from-exp27/save_model.py b/records/phase3/exp01_partial-rope_from-exp27/save_model.py new file mode 100644 index 0000000000..a978e39ce5 --- /dev/null +++ b/records/phase3/exp01_partial-rope_from-exp27/save_model.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp01_partial-rope_from-exp27. 
def main():
    """Bundle the trained checkpoint into a self-contained directory.

    Writes `config.json` (the architecture hyperparameters needed to rebuild
    the model), then copies the raw checkpoint and — when present — the
    quantized artifact next to it.  Paths come from the CLI:
    `--model-pt` (default `final_model.pt`) and `--output-dir`
    (default `model_checkpoint`).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-pt", type=str, default="final_model.pt")
    parser.add_argument("--output-dir", type=str, default="model_checkpoint")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Import training script for hyperparameters (it lives next to this file).
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    import train_gpt as tg

    hp = tg.Hyperparameters()
    config = {
        "vocab_size": hp.vocab_size,
        "num_layers": hp.num_layers,
        "model_dim": hp.model_dim,
        "num_heads": hp.num_heads,
        "num_kv_heads": hp.num_kv_heads,
        "mlp_mult": hp.mlp_mult,
        "tie_embeddings": hp.tie_embeddings,
        "tied_embed_init_std": hp.tied_embed_init_std,
        "logit_softcap": hp.logit_softcap,
        "rope_base": hp.rope_base,
        "qk_gain_init": hp.qk_gain_init,
        "bigram_vocab_size": hp.bigram_vocab_size,
        "bigram_dim": hp.bigram_dim,
        "unique_layers": hp.unique_layers,
        "rope_frac": hp.rope_frac,
        "train_seq_len": hp.train_seq_len,
    }

    config_path = os.path.join(args.output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    print(f"Saved config to {config_path}")

    if os.path.exists(args.model_pt):
        dst = os.path.join(args.output_dir, "model.pt")
        shutil.copy2(args.model_pt, dst)
        print(f"Copied model to {dst}")
    else:
        print(f"Warning: {args.model_pt} not found. Run training first.")

    # Derive the quantized-artifact path from the checkpoint path.
    # BUGFIX: str.replace(".pt", ...) rewrote EVERY ".pt" occurrence in the
    # path (e.g. inside a directory name such as "exp.pt/"); removesuffix
    # only touches the trailing extension.
    quant_path = args.model_pt.removesuffix(".pt") + ".int8.ptz"
    if os.path.exists(quant_path):
        dst = os.path.join(args.output_dir, "model_quant.ptz")
        shutil.copy2(quant_path, dst)
        print(f"Copied quantized model to {dst}")

    print(f"\nCheckpoint saved to {args.output_dir}/")


if __name__ == "__main__":
    main()
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = 
int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = 
int(os.environ.get("SWA_EVERY", 50)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + 
dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + 
rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += 
token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 
127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = 
t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: 
int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or 
self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2) # even, at least 2 + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.rope_dims, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = 
self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + # Partial RoPE: apply rotary only to first rope_dims dimensions + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] + k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, 
bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + 
qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + unique_layers: int = 0, + rope_frac: float = 0.25, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them + n_unique = unique_layers if unique_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_frac=rope_frac) + for _ in range(n_unique) + ] + ) + # Mapping from virtual layer index → physical block index + # Last virtual layer shares with the last physical block + self._layer_map = list(range(min(n_unique, num_layers))) + while len(self._layer_map) < num_layers: + self._layer_map.append(n_unique - 1) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self._layer_map) # virtual depth, not physical blocks + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, 
"_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + 
logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += 
scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device 
= torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, 
args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + rope_frac=args.rope_frac, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: 
+ tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for 
opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + 
stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic 
momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( 
+ f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + 
if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": 
quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + 
log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/phase3/exp01b_ln-scale-only_from-exp27/Inference.ipynb b/records/phase3/exp01b_ln-scale-only_from-exp27/Inference.ipynb new file mode 100644 index 0000000000..8a406b185c --- /dev/null +++ b/records/phase3/exp01b_ln-scale-only_from-exp27/Inference.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import json\n", + "import io\n", + "import math\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "\n", + "# Config — change these paths as needed\n", + "EXPERIMENT_DIR = 
\"records/phase2/exp27_leaky-relu-sq_from-exp18\"\n", + "MODEL_PATH = \"final_model.pt\" # or path to saved checkpoint\n", + "TOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", + "\n", + "print(f\"Device: {DEVICE}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build model\n", + "model = tg.GPT(\n", + " vocab_size=args.vocab_size,\n", + " num_layers=args.num_layers,\n", + " model_dim=args.model_dim,\n", + " num_heads=args.num_heads,\n", + " num_kv_heads=args.num_kv_heads,\n", + " mlp_mult=args.mlp_mult,\n", + " tie_embeddings=args.tie_embeddings,\n", + " tied_embed_init_std=args.tied_embed_init_std,\n", + " logit_softcap=args.logit_softcap,\n", + " rope_base=args.rope_base,\n", + " qk_gain_init=args.qk_gain_init,\n", + " bigram_vocab_size=args.bigram_vocab_size,\n", + " bigram_dim=args.bigram_dim,\n", + " unique_layers=args.unique_layers,\n", + ")\n", + "\n", + "# Load state dict\n", + "state_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\n", + "model.load_state_dict(state_dict, strict=True)\n", + "model = model.to(DEVICE).eval()\n", + "\n", + "n_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. (Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp01b_ln-scale-only_from-exp27/README.md b/records/phase3/exp01b_ln-scale-only_from-exp27/README.md new file mode 100644 index 0000000000..ecd371d8b8 --- /dev/null +++ b/records/phase3/exp01b_ln-scale-only_from-exp27/README.md @@ -0,0 +1,11 @@ +# LN Scale Only (Ablation) + +## Purpose + +Isolated ablation: measures the marginal effect of LN Scale (`1/√(layer+1)`) from exp27 baseline, without partial RoPE or any other changes. 
+ +## Results + +| Seed | val_bpb | delta vs exp00 | +|------|---------|----------------| +| 42 | TBD | TBD | diff --git a/records/phase3/exp01b_ln-scale-only_from-exp27/logs.txt b/records/phase3/exp01b_ln-scale-only_from-exp27/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp01b_ln-scale-only_from-exp27/run.sh b/records/phase3/exp01b_ln-scale-only_from-exp27/run.sh new file mode 100755 index 0000000000..ca163a3bd4 --- /dev/null +++ b/records/phase3/exp01b_ln-scale-only_from-exp27/run.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# ============================================================ +# Exp01b: LN Scale ONLY (ablation) — from Exp27 +# Isolated test of 1/√(layer+1) damping without partial RoPE. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp01b_ln-scale-only_from-exp27" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export SWA_START_FRAC=0.2 +export SWA_EVERY=50 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export RUN_ID="${EXP_NAME}_seed${SEED}" +echo "=== ${EXP_NAME} seed=${SEED} ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs_seed${SEED}.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git 
a/records/phase3/exp01b_ln-scale-only_from-exp27/save_model.py b/records/phase3/exp01b_ln-scale-only_from-exp27/save_model.py new file mode 100644 index 0000000000..c3b4fb7561 --- /dev/null +++ b/records/phase3/exp01b_ln-scale-only_from-exp27/save_model.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp01b_ln-scale-only_from-exp27.""" +import argparse, json, os, sys, shutil + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-pt", type=str, default="final_model.pt") + parser.add_argument("--output-dir", type=str, default="model_checkpoint") + args = parser.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + import train_gpt as tg + hp = tg.Hyperparameters() + config = {k: getattr(hp, k) for k in ["vocab_size","num_layers","model_dim","num_heads","num_kv_heads","mlp_mult","tie_embeddings","tied_embed_init_std","logit_softcap","rope_base","qk_gain_init","bigram_vocab_size","bigram_dim","unique_layers","train_seq_len"]} + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + if os.path.exists(args.model_pt): + shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) + quant_path = args.model_pt.replace(".pt", ".int8.ptz") + if os.path.exists(quant_path): + shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) + print(f"Checkpoint saved to {args.output_dir}/") + +if __name__ == "__main__": + main() diff --git a/records/phase3/exp01b_ln-scale-only_from-exp27/train_gpt.py b/records/phase3/exp01b_ln-scale-only_from-exp27/train_gpt.py new file mode 100644 index 0000000000..ec0da3a940 --- /dev/null +++ b/records/phase3/exp01b_ln-scale-only_from-exp27/train_gpt.py @@ -0,0 +1,1341 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + 
vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = 
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximately orthogonalize ``G`` with a quintic Newton-Schulz iteration.

    Runs entirely in bfloat16 for speed. The tuned coefficients (a, b, c) drive
    the iterate toward the orthogonal factor of G (U @ V^T of its SVD). Tall
    matrices are transposed first so the matmuls operate on the smaller Gram
    matrix, then transposed back on return.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Normalize so the spectral norm is <= 1 — required for convergence.
    X /= X.norm() + eps
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum + orthogonalized update (Muon) for 2-D weight matrices.

    Work is split round-robin across ranks: rank r orthogonalizes params whose
    index satisfies i % world_size == r, writes the result into its disjoint
    slice of a shared flat bf16 buffer, and a SUM all_reduce reassembles the
    full update on every rank before it is applied.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # Flat buffer with room for every param in the group; non-owned
            # slices stay zero so the cross-rank SUM recombines cleanly.
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                # Only the owning rank computes this param's update.
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Rescale tall matrices so update RMS is shape-independent.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # Offset advances for every param so slices agree across ranks.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            # Apply decoupled weight decay, then the orthogonalized update.
            wd = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                if wd > 0:
                    p.data.mul_(1.0 - lr * wd)
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss


def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables for converting token NLL to bits-per-byte.

    Returns three tensors indexed by token id:
      * base_bytes: UTF-8 byte length of the piece text (leading-space marker
        stripped first; byte-fallback tokens count as 1 byte).
      * has_leading_space: piece begins with the SentencePiece word-boundary
        marker U+2581.
      * is_boundary_token: control/unknown/unused ids, which never contribute a
        preceding-space byte when they appear as the previous token.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("\u2581"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )
def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching ``pattern`` into one tensor.

    The result is trimmed to a whole number of ``seq_len`` sequences plus one
    extra token, so inputs and shifted-by-one targets can both be sliced out.
    Raises FileNotFoundError if no shards match, ValueError if too short.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Distributed non-overlapping validation pass.

    Each rank scores a contiguous slice of the validation sequences; loss,
    token and byte counts are all-reduced so every rank returns the same
    (mean token NLL, bits-per-byte) pair. Restores model.train() on exit.
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Contiguous per-rank shard of the sequence index space.
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            # +1 so the shifted target slice is available.
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            # Model returns mean NLL; re-weight by token count before summing.
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            # Byte count per target token; a leading-space byte only counts when
            # the previous token is not a boundary (control/unknown) token.
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


def tensor_nbytes(t: Tensor) -> int:
    """Storage size of ``t`` in bytes (numel * element size)."""
    return int(t.numel()) * int(t.element_size())


def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Int8-quantize ``t`` with high-percentile clipping.

    2-D tensors get one fp16 scale per row; other shapes share a single fp32
    scalar scale. Returns (int8 values, scale) such that values * scale
    approximates ``t``.
    """
    t32 = t.float()
    if t32.ndim == 2:
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            # NOTE(review): torch.empty leaves clip_abs uninitialized when
            # shape[0] > 0 but numel() == 0 — torch.zeros would be safer here.
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        # Floor the scale so tiny rows don't blow up the division.
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale
torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" 
+ continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + 
chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 
def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Apply rotary position embeddings (half-split layout) to ``x``.

    The last dimension is split into two halves (a, b); each pair is rotated
    by the per-position angle: (a*cos + b*sin, b*cos - a*sin).
    """
    split = x.size(-1) // 2
    lo = x[..., :split]
    hi = x[..., split:]
    rotated_lo = lo * cos + hi * sin
    rotated_hi = hi * cos - lo * sin
    return torch.cat((rotated_lo, rotated_hi), dim=-1)
class MLP(nn.Module):
    """Position-wise feed-forward block: leaky-ReLU followed by squaring."""

    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        # Zero-initialized output projection (picked up by GPT._init_weights).
        self.proj._zero_init = True

    def forward(self, x: Tensor) -> Tensor:
        h = self.fc(x)
        h = F.leaky_relu(h, negative_slope=0.5)
        return self.proj(h * h)


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""

    def __init__(self, dim: int):
        super().__init__()
        # Zero init => sigmoid(0) = 0.5: equal blend at the start of training.
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        mix = torch.sigmoid(self.gate.to(dtype=x.dtype)).view(1, 1, -1)
        # Previous-token tensor: zero vector at position 0, x shifted right by one.
        prev = F.pad(x, (0, 0, 1, 0))[:, :-1]
        return (1 - mix) * x + mix * prev
class Block(nn.Module):
    """One transformer layer: attention + MLP, each gated by learned per-channel scales.

    ``resid_mix`` re-mixes the running residual with the post-embedding stream
    ``x0`` at the top of every layer; ``ln_scale`` damps both branch
    contributions by 1/sqrt(layer_idx + 1).
    """

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, layer_idx: int = 0):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        # Per-channel residual-branch gains; kept fp32 and excluded from
        # quantization via CONTROL_TENSOR_NAME_PATTERNS.
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # Row 0 weights the running residual x, row 1 the embedding stream x0.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
        self.ln_scale = 1.0 / math.sqrt(layer_idx + 1)

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x
    def _init_weights(self) -> None:
        """Initialize weights: normal tied embeddings, orthogonal large matrices,
        zeros for modules flagged ``_zero_init``, with projections scaled down
        by 1/sqrt(2 * virtual depth)."""
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                    if ".proj." in name or name.endswith(".proj"):
                        with torch.no_grad():
                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Training forward pass: returns mean cross-entropy over all targets.

        U-Net style wiring: the first half of the (virtual) layers push their
        outputs onto a skip stack; the second half pop and add them back,
        weighted per-channel by ``skip_weights``. Logits are tanh-softcapped.
        """
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # embedding stream fed to every Block via resid_mix
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            # Tied head: reuse the token embedding matrix as the output projection.
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # Soft cap keeps logits in (-logit_softcap, logit_softcap).
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Inference pass mirroring ``forward`` but returning softcapped logits
        of shape (batch, seq, vocab) instead of a loss."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window validation: score only the last ``stride`` tokens of each
    window (full window for the first), so every token gets near-full context.

    Windows are sharded across ranks and the loss/token/byte sums all-reduced;
    returns (mean token NLL, bits-per-byte). Restores model.train() on exit.
    """
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    # Keep windows long enough to contribute stride fresh tokens (ws == 0 is
    # always kept since it is scored in full).
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # Zero-padded batch; per-window lengths tracked in wlens.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score the whole first window, otherwise only the fresh tail.
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
            # Rank-0 progress logging every 50 batches.
            if rank == 0 and (bi // batch_seqs) % 50 == 0:
                done = min(bi + batch_seqs, len(my_windows))
                pct = done / len(my_windows) * 100
                running_bpb = 0.0
                if token_count.item() > 0:
                    rl = (loss_sum / token_count).item()
                    running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item())
                print(f"  sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True)

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte
0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) 
+ + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: 
float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == 
args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + 
(args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + 
if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + 
base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob 
= zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/phase3/exp01c_ema-only_from-exp27/Inference.ipynb b/records/phase3/exp01c_ema-only_from-exp27/Inference.ipynb new file mode 100644 index 0000000000..8a406b185c --- /dev/null +++ b/records/phase3/exp01c_ema-only_from-exp27/Inference.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import json\n", + "import io\n", + "import math\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "\n", + "# Config — change these paths as needed\n", + "EXPERIMENT_DIR = \"records/phase2/exp27_leaky-relu-sq_from-exp18\"\n", + "MODEL_PATH = \"final_model.pt\" # or path to saved checkpoint\n", + "TOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\n", + "DEVICE = \"cuda\" if 
torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", + "\n", + "print(f\"Device: {DEVICE}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build model\n", + "model = tg.GPT(\n", + " vocab_size=args.vocab_size,\n", + " num_layers=args.num_layers,\n", + " model_dim=args.model_dim,\n", + " num_heads=args.num_heads,\n", + " num_kv_heads=args.num_kv_heads,\n", + " mlp_mult=args.mlp_mult,\n", + " tie_embeddings=args.tie_embeddings,\n", + " tied_embed_init_std=args.tied_embed_init_std,\n", + " logit_softcap=args.logit_softcap,\n", + " rope_base=args.rope_base,\n", + " qk_gain_init=args.qk_gain_init,\n", + " bigram_vocab_size=args.bigram_vocab_size,\n", + " bigram_dim=args.bigram_dim,\n", + " unique_layers=args.unique_layers,\n", + ")\n", + "\n", + "# Load state dict\n", + "state_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\n", + "model.load_state_dict(state_dict, strict=True)\n", + "model = model.to(DEVICE).eval()\n", + "\n", + "n_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. (Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp01c_ema-only_from-exp27/README.md b/records/phase3/exp01c_ema-only_from-exp27/README.md new file mode 100644 index 0000000000..47f860aad8 --- /dev/null +++ b/records/phase3/exp01c_ema-only_from-exp27/README.md @@ -0,0 +1,11 @@ +# EMA Only (Ablation) + +## Purpose + +Isolated ablation: measures the marginal effect of EMA (decay=0.997) replacing SWA, from exp27 baseline, without partial RoPE or LN Scale. 
+ +## Results + +| Seed | val_bpb | delta vs exp00 | +|------|---------|----------------| +| 42 | TBD | TBD | diff --git a/records/phase3/exp01c_ema-only_from-exp27/logs.txt b/records/phase3/exp01c_ema-only_from-exp27/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp01c_ema-only_from-exp27/run.sh b/records/phase3/exp01c_ema-only_from-exp27/run.sh new file mode 100755 index 0000000000..6ed223a78e --- /dev/null +++ b/records/phase3/exp01c_ema-only_from-exp27/run.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# ============================================================ +# Exp01c: EMA ONLY (ablation) — from Exp27 +# Isolated test of EMA (decay=0.997) replacing SWA. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp01c_ema-only_from-exp27" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export EMA_ENABLED=1 +export EMA_DECAY=0.997 +export RUN_ID="${EXP_NAME}_seed${SEED}" +echo "=== ${EXP_NAME} seed=${SEED} ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs_seed${SEED}.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp01c_ema-only_from-exp27/save_model.py 
b/records/phase3/exp01c_ema-only_from-exp27/save_model.py new file mode 100644 index 0000000000..f28ff2c12e --- /dev/null +++ b/records/phase3/exp01c_ema-only_from-exp27/save_model.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp01c_ema-only_from-exp27.""" +import argparse, json, os, sys, shutil + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-pt", type=str, default="final_model.pt") + parser.add_argument("--output-dir", type=str, default="model_checkpoint") + args = parser.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + import train_gpt as tg + hp = tg.Hyperparameters() + config = {k: getattr(hp, k) for k in ["vocab_size","num_layers","model_dim","num_heads","num_kv_heads","mlp_mult","tie_embeddings","tied_embed_init_std","logit_softcap","rope_base","qk_gain_init","bigram_vocab_size","bigram_dim","unique_layers","train_seq_len"]} + config["ema_decay"] = hp.ema_decay + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + if os.path.exists(args.model_pt): + shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) + quant_path = args.model_pt.replace(".pt", ".int8.ptz") + if os.path.exists(quant_path): + shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) + print(f"Checkpoint saved to {args.output_dir}/") + +if __name__ == "__main__": + main() diff --git a/records/phase3/exp01c_ema-only_from-exp27/train_gpt.py b/records/phase3/exp01c_ema-only_from-exp27/train_gpt.py new file mode 100644 index 0000000000..3bf2f75f82 --- /dev/null +++ b/records/phase3/exp01c_ema-only_from-exp27/train_gpt.py @@ -0,0 +1,1335 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + 
vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = 
int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = 
# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION
# -----------------------------

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables for converting token NLL into bits-per-byte.

    Returns (base_bytes, has_leading_space, is_boundary_token), each of length
    max(sp.vocab_size(), vocab_size), placed on ``device``.
    """
    sp_size = int(sp.vocab_size())
    table_len = max(sp_size, vocab_size)
    base_bytes = np.zeros((table_len,), dtype=np.int16)
    leading_space = np.zeros((table_len,), dtype=np.bool_)
    boundary = np.ones((table_len,), dtype=np.bool_)
    for tid in range(sp_size):
        # Control/unknown/unused ids stay marked as boundaries and count zero bytes.
        if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid):
            continue
        boundary[tid] = False
        if sp.is_byte(tid):
            base_bytes[tid] = 1
            continue
        piece = sp.id_to_piece(tid)
        if piece.startswith("\u2581"):  # SentencePiece word-start marker
            leading_space[tid] = True
            piece = piece[1:]
        base_bytes[tid] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes, dtype=torch.int16, device=device),
        torch.tensor(leading_space, dtype=torch.bool, device=device),
        torch.tensor(boundary, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all shards matching ``pattern``, trimmed to a multiple of
    ``seq_len`` plus one trailing token (so targets can shift by one)."""
    shard_paths = [Path(p) for p in sorted(glob.glob(pattern))]
    if not shard_paths:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    tokens = torch.cat([load_data_shard(path) for path in shard_paths]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Fixed-chunk validation pass.

    Returns (mean token NLL, bits-per-byte); sums are reduced across ranks
    when distributed. The model is restored to train mode before returning.
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    seqs_per_batch = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Each rank scores a contiguous slice of the validation sequences.
    first_seq = (total_seqs * rank) // world_size
    last_seq = (total_seqs * (rank + 1)) // world_size
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_total = torch.zeros((), device=device, dtype=torch.float64)
    byte_total = torch.zeros((), device=device, dtype=torch.float64)
    model.eval()
    with torch.inference_mode():
        for seq_lo in range(first_seq, last_seq, seqs_per_batch):
            seq_hi = min(seq_lo + seqs_per_batch, last_seq)
            lo = seq_lo * args.train_seq_len
            hi = seq_hi * args.train_seq_len + 1  # +1 token for the shifted targets
            window = val_tokens[lo:hi].to(device=device, dtype=torch.int64, non_blocking=True)
            x = window[:-1].reshape(-1, args.train_seq_len)
            y = window[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            n_tokens = float(y.numel())
            loss_sum += batch_loss.to(torch.float64) * n_tokens
            token_total += n_tokens
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            # Bytes per target token; the leading-space byte is charged only when
            # the previous token is not a boundary token.
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            byte_total += token_bytes.to(torch.float64).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_total, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_total, op=dist.ReduceOp.SUM)
    mean_loss = loss_sum / token_total
    bits_per_token = mean_loss.item() / math.log(2.0)
    tokens_per_byte = token_total.item() / byte_total.item()
    model.train()
    return float(mean_loss.item()), float(bits_per_token * tokens_per_byte)
y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) 
+ clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in 
CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + 
self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + 
self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, 
self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, 
class Block(nn.Module):
    """Pre-norm transformer block with learned residual mixing and scales.

    ``resid_mix`` blends the running residual with the post-embedding stream
    x0; ``attn_scale``/``mlp_scale`` gate the two sub-layer outputs per channel.
    """

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # Row 0 weights the residual stream, row 1 the embedding stream x0.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    """U-Net-style GPT with encoder/decoder skip connections, optional layer
    sharing, smear gate, bigram hash embedding, and soft-capped logits.

    REFACTOR: ``forward`` and ``forward_logits`` previously duplicated the
    whole backbone and logit-capping code line-for-line; both now share the
    private helpers ``_backbone`` and ``_capped_logits``. External interfaces
    are unchanged.
    """

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: create unique_layers physical blocks, map num_layers
        # virtual layers onto them.
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
                for _ in range(n_unique)
            ]
        )
        # Mapping from virtual layer index -> physical block index; extra
        # virtual layers all share the last physical block.
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero-init flagged projections, orthogonal-init large matrices, and
        downscale residual-path projections by 1/sqrt(2 * depth)."""
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    # NOTE(review): this re-initializes bigram.proj (zeroed in
                    # its own __init__) with an orthogonal matrix — confirm
                    # that is intended.
                    nn.init.orthogonal_(module.weight, gain=1.0)
                # For zero-initialized projections this scaling is a no-op.
                if ".proj." in name or name.endswith(".proj"):
                    with torch.no_grad():
                        module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def _backbone(self, input_ids: Tensor) -> Tensor:
        """Shared trunk: embeddings -> smear -> encoder/decoder blocks with
        skip connections -> final norm. Returns (B, T, model_dim)."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # embedding stream fed to every block via resid_mix
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        return self.final_norm(x)

    def _capped_logits(self, x: Tensor) -> Tensor:
        """Project hidden states to vocabulary logits and apply tanh soft-capping."""
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return the mean cross-entropy over all target positions."""
        h = self._backbone(input_ids)
        h = h.reshape(-1, h.size(-1))
        logits = self._capped_logits(h)
        return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Return soft-capped logits of shape (B, T, vocab_size)."""
        return self._capped_logits(self._backbone(input_ids))
device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 
0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) 
+ + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: 
float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + 
torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + 
pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # EMA: exponential moving average of weights every step + if args.ema_enabled and ema_state is not None: + decay = args.ema_decay + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply EMA weights + if args.ema_enabled and ema_state is not None: + log0(f"ema:applying 
decay={args.ema_decay}") + current_state = base_model.state_dict() + ema_load = { + name: tensor.to(dtype=current_state[name].dtype) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(ema_load, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + 
calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + 
f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, 
batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/phase3/exp01d_xsa-only_from-exp27/Inference.ipynb b/records/phase3/exp01d_xsa-only_from-exp27/Inference.ipynb new file mode 100644 index 0000000000..8a406b185c --- /dev/null +++ b/records/phase3/exp01d_xsa-only_from-exp27/Inference.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import json\n", + "import io\n", + "import math\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "\n", + "# Config — change these paths as needed\n", + "EXPERIMENT_DIR = \"records/phase2/exp27_leaky-relu-sq_from-exp18\"\n", + "MODEL_PATH = \"final_model.pt\" # or path to saved checkpoint\n", + "TOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", + "\n", + 
"print(f\"Device: {DEVICE}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build model\n", + "model = tg.GPT(\n", + " vocab_size=args.vocab_size,\n", + " num_layers=args.num_layers,\n", + " model_dim=args.model_dim,\n", + " num_heads=args.num_heads,\n", + " num_kv_heads=args.num_kv_heads,\n", + " mlp_mult=args.mlp_mult,\n", + " tie_embeddings=args.tie_embeddings,\n", + " tied_embed_init_std=args.tied_embed_init_std,\n", + " logit_softcap=args.logit_softcap,\n", + " rope_base=args.rope_base,\n", + " qk_gain_init=args.qk_gain_init,\n", + " bigram_vocab_size=args.bigram_vocab_size,\n", + " bigram_dim=args.bigram_dim,\n", + " unique_layers=args.unique_layers,\n", + ")\n", + "\n", + "# Load state dict\n", + "state_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\n", + "model.load_state_dict(state_dict, strict=True)\n", + "model = model.to(DEVICE).eval()\n", + "\n", + "n_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Model loaded: {n_params:,} 
parameters on {DEVICE}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. (Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} 
MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n 
if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp01d_xsa-only_from-exp27/README.md b/records/phase3/exp01d_xsa-only_from-exp27/README.md new file mode 100644 index 0000000000..8f028ed8ad --- /dev/null +++ b/records/phase3/exp01d_xsa-only_from-exp27/README.md @@ -0,0 +1,11 @@ +# XSA Only (Ablation) + +## Purpose + +Isolated ablation: measures the marginal effect of XSA on last 4 layers from exp27 baseline, keeping SWA and full RoPE. Critical to test since community says XSA without EMA can hurt (-0.003 BPB). 
+ +## Results + +| Seed | val_bpb | delta vs exp00 | +|------|---------|----------------| +| 42 | TBD | TBD | diff --git a/records/phase3/exp01d_xsa-only_from-exp27/logs.txt b/records/phase3/exp01d_xsa-only_from-exp27/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp01d_xsa-only_from-exp27/run.sh b/records/phase3/exp01d_xsa-only_from-exp27/run.sh new file mode 100755 index 0000000000..952bb8b9a8 --- /dev/null +++ b/records/phase3/exp01d_xsa-only_from-exp27/run.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# ============================================================ +# Exp01d: XSA ONLY (ablation) — from Exp27 +# Isolated test of XSA on last 4 layers, keeping SWA and full RoPE. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp01d_xsa-only_from-exp27" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export SWA_START_FRAC=0.2 +export SWA_EVERY=50 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export XSA_LAST_N=4 +export RUN_ID="${EXP_NAME}_seed${SEED}" +echo "=== ${EXP_NAME} seed=${SEED} ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs_seed${SEED}.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git 
a/records/phase3/exp01d_xsa-only_from-exp27/save_model.py b/records/phase3/exp01d_xsa-only_from-exp27/save_model.py new file mode 100644 index 0000000000..68f98f6849 --- /dev/null +++ b/records/phase3/exp01d_xsa-only_from-exp27/save_model.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp01d_xsa-only_from-exp27.""" +import argparse, json, os, sys, shutil + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-pt", type=str, default="final_model.pt") + parser.add_argument("--output-dir", type=str, default="model_checkpoint") + args = parser.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + import train_gpt as tg + hp = tg.Hyperparameters() + config = {k: getattr(hp, k) for k in ["vocab_size","num_layers","model_dim","num_heads","num_kv_heads","mlp_mult","tie_embeddings","tied_embed_init_std","logit_softcap","rope_base","qk_gain_init","bigram_vocab_size","bigram_dim","unique_layers","train_seq_len"]} + config["xsa_last_n"] = hp.xsa_last_n + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + if os.path.exists(args.model_pt): + shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) + quant_path = args.model_pt.replace(".pt", ".int8.ptz") + if os.path.exists(quant_path): + shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) + print(f"Checkpoint saved to {args.output_dir}/") + +if __name__ == "__main__": + main() diff --git a/records/phase3/exp01d_xsa-only_from-exp27/train_gpt.py b/records/phase3/exp01d_xsa-only_from-exp27/train_gpt.py new file mode 100644 index 0000000000..84aa32d278 --- /dev/null +++ b/records/phase3/exp01d_xsa-only_from-exp27/train_gpt.py @@ -0,0 +1,1352 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.

Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines.
"""

from __future__ import annotations

import copy
import glob
import io
import math
import os
import random
import subprocess
import sys
import time
import uuid
import zlib
from pathlib import Path

# Optional dependency: prefer zstd compression when available, else fall back
# to stdlib zlib. Only the flag is decided here; usage is elsewhere in the file.
try:
    import zstandard
    _COMPRESSOR = "zstd"
except ImportError:
    _COMPRESSOR = "zlib"

import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP

# -----------------------------
# HYPERPARAMETERS
# -----------------------------

class Hyperparameters:
    """All run configuration, read once at class-definition time from
    environment variables (each attribute has a hard-coded default)."""

    # --- data / run identity ---
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    train_files = os.path.join(data_path, "fineweb_train_*.bin")   # glob pattern
    val_files = os.path.join(data_path, "fineweb_val_*.bin")       # glob pattern
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 42))

    # --- evaluation / logging cadence ---
    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))   # global tokens per validation pass
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500))       # steps between validations
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100))

    # --- schedule / batch geometry ---
    iterations = int(os.environ.get("ITERATIONS", 20000))
    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000))
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))            # throwaway steps for torch.compile/allocator warmup
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432))  # global tokens per optimizer step
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))  # 10-minute track budget
    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))         # initial per-head query gain
    xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))                 # how many of the last unique blocks use the "xsa" attention variant

    # --- model architecture ---
    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
    num_layers = int(os.environ.get("NUM_LAYERS", 11))                # virtual depth
    unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10))          # physical blocks; extras share the last block
    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))             # GQA: fewer KV heads than query heads
    model_dim = int(os.environ.get("MODEL_DIM", 512))
    num_heads = int(os.environ.get("NUM_HEADS", 8))
    mlp_mult = float(os.environ.get("MLP_MULT", 3.0))
    mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq")
    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))      # tanh soft cap on final logits

    # --- optimizer settings ---
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    head_lr = float(os.environ.get("HEAD_LR", 0.008))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03))
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.02))              # Muon lr for 2-D block weights
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.02))              # AdamW lr for gains/scales/low-dim params
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))  # Newton-Schulz iterations per step
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500))
    momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1")))  # cyclic Muon momentum after warmup
    momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85))
    momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95))
    momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
    weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04))

    # --- sliding-window evaluation ---
    eval_stride = int(os.environ.get("EVAL_STRIDE", 64))
    eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32))

    # --- bigram hash embedding ---
    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240))
    bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))

    # --- stochastic weight averaging ---
    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1")))
    swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4))     # fraction of run before SWA snapshots begin
    swa_every = int(os.environ.get("SWA_EVERY", 50))

# -----------------------------
# MUON OPTIMIZER
# -----------------------------

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Run `steps` iterations of a quintic Newton-Schulz-style polynomial
    X <- aX + (bA + c*A@A)X with A = X X^T, after Frobenius-normalizing G
    in bfloat16. Per the Muon recipe this drives a 2-D matrix toward an
    (approximately) orthogonal one with the same row/column space.

    The wide orientation (rows <= cols) is enforced by transposing tall
    inputs and transposing back at the end.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)  # tuned quintic coefficients from the Muon reference
    X = G.bfloat16()
    X /= X.norm() + eps  # normalize so the iteration converges; eps guards an all-zero G
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Muon: SGD-momentum whose 2-D parameter updates are orthogonalized via
    Newton-Schulz before application. Work is sharded across ranks (each rank
    orthogonalizes a strided subset of params) and merged with one all_reduce.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
        # All knobs live in the param-group defaults so groups can override them.
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # One flat bf16 buffer for the whole group; ranks fill disjoint
            # slices (their shard) and the all_reduce below sums them together
            # (non-owned slices stay zero on each rank).
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                # Round-robin sharding: rank r handles params with i % world_size == r.
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        # Nesterov lookahead: use g + momentum * buf instead of buf.
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Rescale tall matrices so update magnitude is shape-independent.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                curr += p.numel()  # every rank advances the offset, owned or not

            if distributed:
                # Sum merges the disjoint shards into a full update vector on all ranks.
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            wd = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                if wd > 0:
                    # Decoupled weight decay, applied before the Muon update.
                    p.data.mul_(1.0 - lr * wd)
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION
# -----------------------------

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables used to convert token-level loss into
    bits-per-byte:

    - base_bytes: UTF-8 byte length of each piece (1 for byte-fallback tokens),
      excluding the leading-space marker.
    - has_leading_space: piece starts with the SentencePiece word marker U+2581.
    - is_boundary_token: control/unknown/unused tokens; a leading space after
      one of these is not counted as an extra byte.

    Tables are sized max(sp.vocab_size(), vocab_size) so model ids are in range.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue  # stays marked as a boundary token with 0 bytes
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("\u2581"):  # SentencePiece "lower one eighth block" word-start marker
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching `pattern` and trim to a
    whole number of seq_len sequences plus one extra token (for targets).

    Raises FileNotFoundError if no shard matches, ValueError if too short.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Non-overlapping validation pass; returns (mean token loss, bits-per-byte).
    Each rank scores a contiguous slice of the sequences; sums are all_reduced.
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Balanced contiguous split of sequences across ranks.
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # float64 accumulators to avoid precision loss over many batches.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1  # +1 for shifted targets
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()  # model returns mean CE loss
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting for bits-per-byte: each target token contributes
            # its base UTF-8 bytes, plus one byte for its leading space unless
            # the previous token was a boundary (control/unk/unused) token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


# -----------------------------
# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed)
# -----------------------------

# Names of small "control" tensors (gains, scales, mixes) that are always kept
# in float rather than quantized. Overridable via environment variable.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale",
    ).split(",")
    if pattern
)
# Tensors kept at fp16 precision instead of being quantized to ints.
FP16_KEEP_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",")
    if pattern
)
# Legacy int8 path: names kept as float32 (defaults to the control patterns).
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
# Outlier clipping quantile applied before int8 rounding.
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Storage size of `t` in bytes (numel * element size)."""
    return int(t.numel()) * int(t.element_size())

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Symmetric int8 quantization with quantile clipping.

    2-D tensors get a per-row scale (fp16); other shapes get one global
    fp32 scale. Returns (int8 tensor, scale).
    """
    t32 = t.float()
    if t32.ndim == 2:
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            # NOTE(review): torch.empty leaves clip_abs uninitialized when
            # numel()==0 but shape[0]>0 — scales would be garbage; confirm
            # whether zero-row inputs can reach this path.
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale


def _classify_param(name: str) -> str:
    """Bucket a state-dict key into embed / mlp / bigram / attn / other,
    used to pick a quantization category. First match wins."""
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    if "bigram" in name:
        return "bigram"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"

def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
    """Symmetric low-bit quantization without outlier clipping (abs-max
    scaling). Values are rounded into [-(clip_range+1), clip_range] — the full
    signed range for the given bit width (clip_range=31 -> int6, 15 -> int5) —
    and stored in an int8 container. 2-D tensors get per-row fp16 scales.
    """
    t32 = t.float()
    if t32.ndim == 2:
        row_max = t32.abs().amax(dim=1)
        scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16)
        # fp16 cast can flush 1e-12 to zero; re-clamp to the smallest normal fp16.
        scale = scale.clamp_min(torch.finfo(torch.float16).tiny)
        q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8)
        return q, scale
    amax = t32.abs().max().item()
    scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8)
    return q, scale

def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
    """Quantize a model state dict with mixed precision per parameter category.

    Parameters in `int6_cats` get int5/int6 per-row quantization; control
    tensors stay fp32, FP16_KEEP patterns stay fp16, small or non-float
    tensors pass through, and everything else falls back to int8.
    Returns (result dict with ".q"/".scale" entries, metadata dict).
    """
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = _classify_param(name)
        # Small (<= 8192 elements) or integer tensors are stored directly.
        if not t.is_floating_point() or t.numel() <= 8192:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = 
min(remaining, avail)
                chunks.append(self.tokens[self.pos : self.pos + k])
                self.pos += k
                remaining -= k
        # Fast path: a single chunk needs no concatenation.
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)


class DistributedTokenLoader:
    """Serves per-rank training batches from a shared token stream.

    Every rank reads the same global chunk and slices out its own span, so
    all ranks stay in lockstep without communication.
    """

    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return (inputs, targets), each shaped (local_tokens // seq_len, seq_len)."""
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1  # +1 so targets can be shifted by one
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)


# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""

    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    """nn.Linear whose weight/bias are cast to the input dtype at call time
    (weights are kept in fp32; activations may be bf16 under autocast)."""

    def forward(self, x: Tensor) -> Tensor:
        w = self.weight.to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, w, bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    """Promote sub-2D parameters and named control tensors back to float32
    (used after casting the whole model to bf16)."""
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    """Rotary position embeddings with a cos/sin cache keyed on
    (seq_len, device); the fp32 cache is cast to the requested dtype on return."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        # Rebuild the cache when seq_len or device changes (dtype is handled
        # by the cast on return, so it is deliberately not part of the key).
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]  # (1, 1, seq, dim/2)
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate the two halves of the last dimension by the cached angles
    (half-split RoPE convention)."""
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    """Causal attention with grouped-query KV heads (GQA), QK RMS-norm,
    RoPE, a learned per-head query gain, and an optional "xsa" variant that
    subtracts value/seq_len from the attention output (see forward).
    """

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, use_xsa: bool = False):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        self.use_xsa = use_xsa
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # flag consumed by GPT._init_weights
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        # (B, S, D) -> (B, H, S, head_dim)
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm: RMS-normalize queries and keys per head before RoPE.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        # Learned per-head temperature on the (unit-norm) queries.
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        if self.use_xsa:
            # "xsa": subtract v / seq_len from the output — presumably a
            # uniform-attention baseline correction; KV heads are expanded to
            # match query heads first. TODO(review): confirm intent of 1/seqlen.
            if self.num_kv_heads != self.num_heads:
                rep = self.num_heads // self.num_kv_heads
                v_expanded = v.repeat_interleave(rep, dim=1)
            else:
                v_expanded = v
            y = y - v_expanded / seqlen
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    """Two-layer MLP. Activation is leaky_relu(slope=0.5) followed by square,
    so negatives map to (0.5*x)^2 rather than 0 (a smoothed "relu squared")."""

    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True  # output projection starts at zero

    def forward(self, x: Tensor) -> Tensor:
        x = F.leaky_relu(self.fc(x), negative_slope=0.5)
        return self.proj(x.square())


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""
    def __init__(self, dim: int):
        super().__init__()
        # gate=0 -> sigmoid=0.5? No: per-channel logits start at 0, i.e. a
        # 50/50 blend weight of 0.5 via sigmoid below.
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        # First position blends with zeros (no previous token).
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev


class BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table."""
    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)  # contribution starts at zero
        # Optional projection when the bigram dim differs from the model dim.
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Hash (prev, curr) token pairs into [0, bigram_vocab_size - 2];
        position 0 (no previous token) gets the reserved bucket `mod`."""
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod  # reserved bucket for the first position
        # Multiplicative hash of the pair, folded into [0, mod).
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class Block(nn.Module):
    """Transformer block with learned per-channel attention/MLP output scales
    and a learned per-channel mix of the running residual with the embedding
    stream x0."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, use_xsa: bool = False):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, use_xsa=use_xsa)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights x, resid_mix[1] weights x0; starts as identity.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    """Decoder-only LM with:
    - optional bigram hash embeddings added to token embeddings,
    - a smear gate blending each position with its predecessor,
    - U-Net style encoder/decoder halves joined by learned skip weights,
    - layer sharing (`num_layers` virtual layers over `unique_layers` blocks),
    - tied or separate LM head with tanh logit soft-capping.
    """

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
        xsa_last_n: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        # U-Net split of the virtual depth: first half "encoder", rest "decoder".
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init,
                      use_xsa=(i >= n_unique - xsa_last_n))
                for i in range(n_unique)
            ]
        )
        # Mapping from virtual layer index → physical block index
        # Last virtual layer shares with the last physical block
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        """Orthogonal init for large 2-D weights, zeros for flagged outputs,
        small normal init for tied embeddings; projections are shrunk by
        1/sqrt(2 * depth) so residual variance stays bounded."""
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                    if ".proj." in name or name.endswith(".proj"):
                        with torch.no_grad():
                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Training/eval entry point: returns the mean cross-entropy loss
        over all positions (logits are tanh soft-capped first)."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # embedding stream fed to every block via resid_mix
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                # Skip connections consumed LIFO (innermost encoder -> first decoder).
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Same pipeline as forward() but returns soft-capped logits
        (B, S, vocab) instead of a loss; used by sliding-window eval."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)


def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window evaluation: windows of seq_len start every `stride`
    tokens, and only the final `stride` tokens of each window are scored
    (the first window scores all its tokens), so every scored token sees
    close to a full context. Returns (mean token loss, bits-per-byte).
    Windows are split across ranks; sums are all_reduced at the end.
    """
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    # Drop a trailing window shorter than `stride` (nothing new to score),
    # but always keep the very first window.
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # Zero-padded batch; per-window valid length tracked in wlens.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score only the last `stride` tokens, except for window 0.
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                # Same byte accounting as eval_val, restricted to scored tokens.
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
            # Periodic progress print from rank 0 only (local running estimate).
            if rank == 0 and (bi // batch_seqs) % 50 == 0:
                done = min(bi + batch_seqs, len(my_windows))
                pct = done / len(my_windows) * 100
                running_bpb = 0.0
                if token_count.item() > 0:
                    rl = (loss_sum / token_count).item()
                    running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item())
                print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True)

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte


# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    # Rebound below so the compiled version is used inside Muon.step.
    global zeropower_via_newtonschulz5

    # Own source code is logged for reproducibility (see log0 calls below).
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # torchrun-style env detection; defaults run single-process.
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    # Global batch is sized for 8 GPUs; fewer ranks compensate via grad accumulation.
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, 
has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": 
[base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 
- frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = 
bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = 
act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta 
= mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + 
t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/phase3/exp02_ln-scale_from-exp01/Inference.ipynb b/records/phase3/exp02_ln-scale_from-exp01/Inference.ipynb new file mode 100644 index 0000000000..64d060f768 --- /dev/null +++ b/records/phase3/exp02_ln-scale_from-exp01/Inference.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = 
\"records/phase3/exp02_ln-scale_from-exp01\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. (Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp02_ln-scale_from-exp01/README.md b/records/phase3/exp02_ln-scale_from-exp01/README.md new file mode 100644 index 0000000000..aa46634959 --- /dev/null +++ b/records/phase3/exp02_ln-scale_from-exp01/README.md @@ -0,0 +1,24 @@ +# LN Scale (1/√(layer+1)) Damping + +## Score: val_bpb = TBD + +## Hypothesis + +Post-RMSNorm output scaling by `1/√(layer_idx+1)` damps deeper layer contributions, preventing later layers from overwriting early representations. Community data shows ~0.003 BPB improvement. Zero additional parameters. + +## Changes from exp01 + +- `Block.__init__` now takes `layer_idx`, computes `self.ln_scale = 1/√(layer_idx+1)` +- `Block.forward` multiplies both `attn_scale` and `mlp_scale` residuals by `ln_scale` + +## Architecture + +Inherits from exp01 (Partial RoPE 25%) + adds LN Scale damping. + +## Expected Impact + +~0.003 BPB improvement over exp01. + +## Results + +TBD — awaiting A100 run. 
diff --git a/records/phase3/exp02_ln-scale_from-exp01/logs.txt b/records/phase3/exp02_ln-scale_from-exp01/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp02_ln-scale_from-exp01/run.sh b/records/phase3/exp02_ln-scale_from-exp01/run.sh new file mode 100755 index 0000000000..c45fcb6c91 --- /dev/null +++ b/records/phase3/exp02_ln-scale_from-exp01/run.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# ============================================================ +# Exp02: LN Scale (1/√(layer+1)) — from Exp01 +# Damps deeper layer contributions to prevent overwriting early +# representations. Multiplies attn/mlp residual by 1/√(layer+1). +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp02_ln-scale_from-exp01" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export SWA_START_FRAC=0.2 +export SWA_EVERY=50 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export ROPE_FRAC=0.25 +export RUN_ID="${EXP_NAME}" +echo "=== ${EXP_NAME} ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp02_ln-scale_from-exp01/save_model.py 
b/records/phase3/exp02_ln-scale_from-exp01/save_model.py new file mode 100644 index 0000000000..6f5093b292 --- /dev/null +++ b/records/phase3/exp02_ln-scale_from-exp01/save_model.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp02_ln-scale_from-exp01.""" + +import argparse +import json +import os +import sys +import shutil + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-pt", type=str, default="final_model.pt") + parser.add_argument("--output-dir", type=str, default="model_checkpoint") + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + import train_gpt as tg + + hp = tg.Hyperparameters() + config = { + "vocab_size": hp.vocab_size, + "num_layers": hp.num_layers, + "model_dim": hp.model_dim, + "num_heads": hp.num_heads, + "num_kv_heads": hp.num_kv_heads, + "mlp_mult": hp.mlp_mult, + "tie_embeddings": hp.tie_embeddings, + "tied_embed_init_std": hp.tied_embed_init_std, + "logit_softcap": hp.logit_softcap, + "rope_base": hp.rope_base, + "qk_gain_init": hp.qk_gain_init, + "bigram_vocab_size": hp.bigram_vocab_size, + "bigram_dim": hp.bigram_dim, + "unique_layers": hp.unique_layers, + "rope_frac": hp.rope_frac, + "train_seq_len": hp.train_seq_len, + } + + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + if os.path.exists(args.model_pt): + shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) + quant_path = args.model_pt.replace(".pt", ".int8.ptz") + if os.path.exists(quant_path): + shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) + + print(f"Checkpoint saved to {args.output_dir}/") + +if __name__ == "__main__": + main() diff --git a/records/phase3/exp02_ln-scale_from-exp01/train_gpt.py b/records/phase3/exp02_ln-scale_from-exp01/train_gpt.py new file mode 100644 index 0000000000..178d3c7b8d --- /dev/null +++ 
b/records/phase3/exp02_ln-scale_from-exp01/train_gpt.py @@ -0,0 +1,1350 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 
786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = 
float(os.environ.get("GRAD_CLIP_NORM", 0.3))
# NOTE(review): the assignments below continue the Hyperparameters class that
# begins on the previous source line; class indentation was lost in the diff
# mangling and is not reconstructable from this chunk alone.
weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04))

eval_stride = int(os.environ.get("EVAL_STRIDE", 64))
eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32))

bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240))
bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))

swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1")))
swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4))
swa_every = int(os.environ.get("SWA_EVERY", 50))

# -----------------------------
# MUON OPTIMIZER
# -----------------------------

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximately orthogonalize G (map its singular values toward 1) with a
    quintic Newton-Schulz iteration run in bfloat16.

    The input is first scaled by its Frobenius norm (+ eps) so the iteration
    converges; tall matrices are transposed so the Gram products use the
    smaller dimension, and transposed back on return.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum + Newton-Schulz-orthogonalized updates for 2-D weight matrices.

    Work is sharded round-robin across ranks: rank r computes the update for
    every param i with i % world_size == r, writes it into a flat bfloat16
    buffer at that param's fixed offset, and a SUM all-reduce merges the
    disjoint slices before all ranks apply all updates.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Shape-aware scale: wide-vs-tall matrices get sqrt(rows/cols) boost.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # BUG FIX: the offset must advance for EVERY param, not only the
                # ones this rank owns (and not only those with a grad). The old
                # placement inside the `if` made each rank pack its updates at
                # the start of the buffer, so slices collided after the
                # all-reduce and the unpacking loop below — which advances per
                # param — read the wrong slices.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            wd = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                if wd > 0:
                    p.data.mul_(1.0 - lr * wd)  # decoupled weight decay
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION
# -----------------------------

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables for bits-per-byte accounting.

    Returns (base_bytes, has_leading_space, is_boundary_token) tensors on
    `device`: UTF-8 byte length of each piece (1 for raw byte tokens), whether
    the piece starts with the SentencePiece meta-space (U+2581), and whether
    the token is a control/unknown/unused boundary token.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("\u2581"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


# (continued on the next source chunk:)
# def load_validation_tokens(pattern: str, seq_len: int) ->
Tensor:
    # (continuation of `def load_validation_tokens(pattern: str, seq_len: int) ->`,
    # whose header starts on the previous source line)
    # Concatenate all matching shards and trim to a whole number of seq_len
    # sequences, keeping one extra token so targets can be shifted by one.
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Non-overlapping validation pass sharded across ranks.

    Returns (mean token NLL, bits-per-byte). Bytes are counted from the LUTs:
    each target token contributes its UTF-8 length plus one byte when it
    carries a leading meta-space and the previous token is not a boundary.
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # Contiguous slice of sequences assigned to this rank.
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # float64 accumulators keep the cross-rank reduction numerically exact.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1  # +1 for the shifted targets
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting for bits-per-byte (see docstring).
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    # bits/byte = bits/token * tokens/byte
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


# -----------------------------
# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed)
# -----------------------------

# Small "control" tensors (gates, scales, mixes) that are kept in float.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale",
    ).split(",")
    if pattern
)
# Tensors kept in fp16 regardless of category (quantization-sensitive).
FP16_KEEP_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
# Clip at this percentile of |w| before rounding, to resist outliers.
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
# (body of tensor_nbytes, whose header ends the previous source line:
# storage size in bytes = element count * element size)
    return int(t.numel()) * int(t.element_size())

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Symmetric int8 quantization with percentile clipping.

    2-D tensors get a per-row scale (fp16); everything else gets one global
    fp32 scale. Returns (int8 codes, scale).
    """
    t32 = t.float()
    if t32.ndim == 2:
        # Per-row clip at INT8_CLIP_Q so a few outliers don't blow up the scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale


def _classify_param(name: str) -> str:
    """Bucket a state-dict key into embed / mlp / bigram / attn / other."""
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    if "bigram" in name:
        return "bigram"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"

def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
    """Per-row abs-max intN quantization (asymmetric range -(N+1)..N, e.g.
    -32..31 for int6). Scales are stored in fp16."""
    t32 = t.float()
    if t32.ndim == 2:
        row_max = t32.abs().amax(dim=1)
        scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16)
        # Guard against fp16 underflow of the already-clamped scale.
        scale = scale.clamp_min(torch.finfo(torch.float16).tiny)
        q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8)
        return q, scale
    amax = t32.abs().max().item()
    scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8)
    return q, scale

def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
    """Quantize a state dict with mixed precision.

    Small (<=8192 numel) or non-float tensors pass through; control tensors
    stay fp32; FP16_KEEP patterns stay fp16; categories in `int6_cats` get
    per-row int6 (int5 for MLP); everything else falls back to int8.
    Returns (packed tensors, per-name metadata).
    """
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = _classify_param(name)
        if not t.is_floating_point() or t.numel() <= 8192:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue
        if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS):
            result[name] = t.to(dtype=torch.float16).contiguous()
            meta[name] = "passthrough_fp16"
            continue
        if cat in int6_cats and t.ndim >= 1:
            clip = 15 if cat == "mlp" else 31  # int5 for MLP, int6 for attention
            q, s = quantize_intN_per_row(t, clip_range=clip)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"}
        else:
            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}
    return result, meta

def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """Invert mixed_quantize_int6 back to the dtypes of `template_sd`."""
    out: dict[str, Tensor] = {}
    for
name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# 
----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + 
    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25):
        """GQA attention with QK-RMSNorm, per-head learned query gain, and
        partial RoPE applied to the first `rope_frac` of each head dim."""
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2)  # even, at least 2
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # residual branch starts as identity
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.rope_dims, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm: per-head RMS normalization before RoPE and gain.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        # Partial RoPE: apply rotary only to first rope_dims dimensions
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:]
        k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:]
        q_rope = apply_rotary_emb(q_rope, cos, sin)
        k_rope = apply_rotary_emb(k_rope, cos, sin)
        q = torch.cat([q_rope, q_pass], dim=-1)
        k = torch.cat([k_rope, k_pass], dim=-1)
        # Learned per-head temperature on the (normalized) queries.
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True

    def forward(self, x: Tensor) -> Tensor:
        # NOTE(review): config labels this "relu_sq", but the code squares a
        # leaky ReLU (slope 0.5), so negative inputs contribute (0.5*x)^2 > 0
        # rather than 0 — confirm this is the intended activation.
        x = F.leaky_relu(self.fc(x), negative_slope=0.5)
        return self.proj(x.square())


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""
    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        # sigmoid(0) = 0.5 at init: an even blend of current and previous token.
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev


class BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table."""
    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)  # start as a no-op contribution
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Multiplicative-XOR hash of (prev, curr) pairs into [0, vocab-1);
        position 0 (no previous token) maps to the reserved bucket `mod`."""
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h *
# (continuation of BigramHashEmbedding.forward: `return h * ` ends the previous
# source line — the learned scalar scales the bigram contribution)
        self.scale.to(dtype=h.dtype)


class Block(nn.Module):
    """Pre-norm transformer block with learned per-channel branch scales, a
    learned residual mix against the initial embedding x0, and a fixed
    1/sqrt(layer_idx+1) depth scale on both branches."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights the running stream, resid_mix[1] weights x0.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
        # Plain float (not a Parameter): fixed depth-dependent damping.
        self.ln_scale = 1.0 / math.sqrt(layer_idx + 1)

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    """U-Net-style GPT: encoder-half outputs are added back (with learned skip
    weights) in the decoder half; virtual layers may share physical blocks."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
        rope_frac: float = 0.25,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_frac=rope_frac, layer_idx=i)
                for i in range(n_unique)
            ]
        )
        # Mapping from virtual layer index → physical block index
        # Last virtual layer shares with the last physical block
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero-init residual projections / head, orthogonal-init large
        matrices, normal-init tied embeddings; scale residual projections by
        1/sqrt(2 * virtual depth)."""
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                    # NOTE(review): nesting of the following `if` under this
                    # elif is reconstructed — indentation was lost in the diff.
                    if ".proj."
# (continuation of `_init_weights`: `if ".proj." ` ends the previous source line)
                    in name or name.endswith(".proj"):
                        with torch.no_grad():
                            # Damp residual-branch projections by depth.
                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Training entry point: returns mean cross-entropy over all targets.

        NOTE(review): the trunk below is duplicated verbatim in
        forward_logits(); consider factoring a shared helper.
        """
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # initial embedding, re-injected by every Block via resid_mix
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                # U-Net skip: last-in-first-out pairing of encoder outputs.
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # Soft-cap keeps logits in (-cap, cap) for stability.
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Inference path: same trunk as forward(), returns soft-capped logits."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)


def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window validation: windows advance by `stride`, and only the
    last `stride` positions of each window are scored (full left context).
    Returns (mean token NLL, bits-per-byte), reduced across ranks.
    """
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    # Keep a window only if it contributes at least `stride` scored tokens
    # (the very first window is always kept and scored in full).
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # Zero-padded batch; per-window valid lengths tracked in wlens.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score only the trailing `stride` tokens (all of window 0).
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] &
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 
0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + rope_frac=args.rope_frac, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + 
matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 
else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + 
+ step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + 
muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" 
+ ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use 
val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = 
zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = 
eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/phase3/exp03_ema_from-exp02/Inference.ipynb b/records/phase3/exp03_ema_from-exp02/Inference.ipynb new file mode 100644 index 0000000000..9d1e51e217 --- /dev/null +++ b/records/phase3/exp03_ema_from-exp02/Inference.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = \"records/phase3/exp03_ema_from-exp02\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" 
+ }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. 
(Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp03_ema_from-exp02/README.md b/records/phase3/exp03_ema_from-exp02/README.md new file mode 100644 index 0000000000..d2ab07c469 --- /dev/null +++ b/records/phase3/exp03_ema_from-exp02/README.md @@ -0,0 +1,26 @@ +# EMA (Exponential Moving Average) Replacing SWA + +## Score: val_bpb = TBD + +## Hypothesis + +EMA with decay=0.997 outperforms SWA by ~0.003 BPB (community verified, 3-seed). EMA updates every step (smoother averaging) vs SWA's periodic snapshots. Critical prerequisite for XSA (exp04). + +## Changes from exp02 + +- Removed SWA hyperparameters (`swa_enabled`, `swa_start_frac`, `swa_every`) +- Added `ema_enabled=True`, `ema_decay=0.997` +- EMA shadow weights updated every training step: `ema = decay * ema + (1-decay) * weights` +- EMA weights loaded before final eval (replaces SWA averaging) + +## Architecture + +Inherits from exp02 (Partial RoPE + LN Scale) + EMA replacing SWA. + +## Expected Impact + +~0.003 BPB improvement over exp02. Also unlocks synergy with XSA. + +## Results + +TBD — awaiting A100 run. 
diff --git a/records/phase3/exp03_ema_from-exp02/logs.txt b/records/phase3/exp03_ema_from-exp02/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp03_ema_from-exp02/run.sh b/records/phase3/exp03_ema_from-exp02/run.sh new file mode 100755 index 0000000000..37aa1a4360 --- /dev/null +++ b/records/phase3/exp03_ema_from-exp02/run.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# ============================================================ +# Exp03: EMA (decay=0.997) replacing SWA — from Exp02 +# Exponential moving average of weights updated every step. +# Community confirmed EMA > SWA by 0.003 BPB. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp03_ema_from-exp02" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export ROPE_FRAC=0.25 +export EMA_ENABLED=1 +export EMA_DECAY=0.997 +export RUN_ID="${EXP_NAME}" +echo "=== ${EXP_NAME} ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp03_ema_from-exp02/save_model.py b/records/phase3/exp03_ema_from-exp02/save_model.py new file mode 100644 index 0000000000..b8220e3172 
--- /dev/null +++ b/records/phase3/exp03_ema_from-exp02/save_model.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp03_ema_from-exp02.""" + +import argparse +import json +import os +import sys +import shutil + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-pt", type=str, default="final_model.pt") + parser.add_argument("--output-dir", type=str, default="model_checkpoint") + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + import train_gpt as tg + + hp = tg.Hyperparameters() + config = { + "vocab_size": hp.vocab_size, + "num_layers": hp.num_layers, + "model_dim": hp.model_dim, + "num_heads": hp.num_heads, + "num_kv_heads": hp.num_kv_heads, + "mlp_mult": hp.mlp_mult, + "tie_embeddings": hp.tie_embeddings, + "tied_embed_init_std": hp.tied_embed_init_std, + "logit_softcap": hp.logit_softcap, + "rope_base": hp.rope_base, + "qk_gain_init": hp.qk_gain_init, + "bigram_vocab_size": hp.bigram_vocab_size, + "bigram_dim": hp.bigram_dim, + "unique_layers": hp.unique_layers, + "rope_frac": hp.rope_frac, + "ema_decay": hp.ema_decay, + "train_seq_len": hp.train_seq_len, + } + + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + if os.path.exists(args.model_pt): + shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) + quant_path = args.model_pt.replace(".pt", ".int8.ptz") + if os.path.exists(quant_path): + shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) + + print(f"Checkpoint saved to {args.output_dir}/") + +if __name__ == "__main__": + main() diff --git a/records/phase3/exp03_ema_from-exp02/train_gpt.py b/records/phase3/exp03_ema_from-exp02/train_gpt.py new file mode 100644 index 0000000000..ce1c19fea2 --- /dev/null +++ b/records/phase3/exp03_ema_from-exp02/train_gpt.py @@ -0,0 +1,1345 @@ +""" +The `train_gpt.py` and 
`train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = 
float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) 
# --- tail of class Hyperparameters (evaluation / bigram / EMA knobs) ---
# NOTE(review): patch flattening lost indentation; in train_gpt.py these
# assignments are class attributes of Hyperparameters.
eval_stride = int(os.environ.get("EVAL_STRIDE", 64))
eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32))

bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240))
bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))

ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1")))
ema_decay = float(os.environ.get("EMA_DECAY", 0.997))

# -----------------------------
# MUON OPTIMIZER
# -----------------------------

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximate the orthogonal factor of G with a quintic Newton-Schulz iteration.

    Runs in bfloat16 for speed; the (a, b, c) coefficients come from the Muon
    reference implementation. Tall matrices are handled by transposing first.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps  # normalize so the iteration stays in its convergence basin
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum + Newton-Schulz-orthogonalized updates (Muon).

    Work is sharded across ranks: rank r orthogonalizes every world_size-th
    parameter, writes its slice into a shared flat buffer, and the buffers are
    summed with all_reduce so every rank applies the full update.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay),
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss if given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5  # rescale tall matrices
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # BUG FIX: the flat-buffer offset must advance for EVERY
                # parameter, not only rank-owned ones with grads; the read loop
                # below advances unconditionally, so a conditional increment
                # here unpacks misaligned slices when world_size > 1 or a
                # parameter has no gradient.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            wd = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                if wd > 0:
                    p.data.mul_(1.0 - lr * wd)  # decoupled weight decay
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION
# -----------------------------

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables for bits-per-byte evaluation.

    Returns (base_bytes, has_leading_space, is_boundary_token) tensors indexed
    by token id: UTF-8 byte length of the piece (sans the U+2581 space marker),
    whether the piece carries a leading space, and whether the token is a
    control/unknown/unused boundary token.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("\u2581"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Load all validation shards matching *pattern* (continues in next chunk)."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += 
batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = 
( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = 
t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: 
                # Current shard exhausted: rotate to the next file and retry.
                self._advance_file()
                continue
            k = min(remaining, avail)
            chunks.append(self.tokens[self.pos : self.pos + k])
            self.pos += k
            remaining -= k
        # Fast path: a single contiguous slice needs no concatenation.
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)


class DistributedTokenLoader:
    """Deals disjoint contiguous token spans to each rank from one TokenStream."""

    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return (inputs, shifted targets) for this rank's micro-batch.

        Every rank draws the same world_size-wide chunk from the stream and
        slices out its own span; +1 token provides the shifted target.
        """
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)  # inputs
        y = local[1:].reshape(-1, seq_len)   # next-token targets
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)


# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""

    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps  # None lets F.rms_norm use its default epsilon

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    """Linear layer that casts weight/bias to the input dtype at call time,
    so parameters can be kept in fp32 while activations run in bf16."""

    def forward(self, x: Tensor) -> Tensor:
        w = self.weight.to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, w, bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    """Cast low-dim / control parameters back to fp32 (after a model-wide bf16 cast)."""
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    """Rotary position embedding with a cached cos/sin table (continues in next chunk)."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        # Standard RoPE inverse frequencies over even dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2) # even, at least 2 + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = 
Rotary(self.rope_dims, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        # Project and split heads: q has num_heads, k/v have num_kv_heads (GQA).
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm: RMS-normalize queries and keys per head before attention.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        # Partial RoPE: apply rotary only to first rope_dims dimensions
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:]
        k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:]
        q_rope = apply_rotary_emb(q_rope, cos, sin)
        k_rope = apply_rotary_emb(k_rope, cos, sin)
        q = torch.cat([q_rope, q_pass], dim=-1)
        k = torch.cat([k_rope, k_pass], dim=-1)
        # Learned per-head gain applied to queries.
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    """Feed-forward block: leaky-ReLU (slope 0.5) followed by squaring, then projection.

    NOTE(review): despite the "relu_sq" hyperparameter name, the activation is
    leaky_relu(x, 0.5)**2, which is nonzero for negative inputs — confirm intended.
    """

    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True  # output projection starts at zero (residual-friendly)

    def forward(self, x: Tensor) -> Tensor:
        x = F.leaky_relu(self.fc(x), negative_slope=0.5)
        return self.proj(x.square())


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""

    def __init__(self, dim: int):
        super().__init__()
        # Per-channel gate logits; zeros -> sigmoid 0.5, i.e. an even blend at init.
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        # Shift right by one token; the first position blends with zeros.
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev


class
BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table."""

    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)  # bigram path contributes nothing at init
        # Optional projection when the bigram width differs from model_dim.
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        # Learned global gain on the bigram contribution.
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Map each (prev, current) token pair to a bucket in [0, bigram_vocab_size - 1).

        Position 0 has no predecessor and goes to the reserved bucket `mod`
        (= bigram_vocab_size - 1). torch's % follows the sign of the divisor,
        so results stay non-negative even if the int32 products wrap.
        """
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class Block(nn.Module):
    """Transformer block with learned residual scales and an x0 shortcut mix
    (forward continues in the next chunk)."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac)
        self.mlp = MLP(dim, mlp_mult)
        # Per-channel gains on the attention / MLP residual branches.
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] scales the running residual, resid_mix[1] the embedding x0.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
        # Depth-dependent damping of both branches: 1/sqrt(layer_idx + 1).
        self.ln_scale = 1.0 / math.sqrt(layer_idx + 1)

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] *
attn_out + x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + unique_layers: int = 0, + rope_frac: float = 0.25, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them + n_unique = unique_layers if unique_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_frac=rope_frac, layer_idx=i) + for i in range(n_unique) + ] + ) + # Mapping from virtual layer index → physical block index + # Last virtual layer shares with the last physical block + self._layer_map = list(range(min(n_unique, num_layers))) + while len(self._layer_map) < num_layers: + self._layer_map.append(n_unique - 1) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if 
self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self._layer_map) # virtual depth, not physical blocks + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in 
range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = 
int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + 
raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + rope_frac=args.rope_frac, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if 
base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + 
log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = 
train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # EMA: exponential moving average of weights every step + if args.ema_enabled and ema_state is not None: + decay = args.ema_decay + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply EMA weights + if args.ema_enabled and ema_state is not None: + log0(f"ema:applying decay={args.ema_decay}") + current_state = base_model.state_dict() + ema_load = { + name: tensor.to(dtype=current_state[name].dtype) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(ema_load, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def 
hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in 
base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after 
dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/phase3/exp04_xsa4_from-exp03/Inference.ipynb b/records/phase3/exp04_xsa4_from-exp03/Inference.ipynb new file mode 100644 index 0000000000..dec1276ff7 --- /dev/null +++ b/records/phase3/exp04_xsa4_from-exp03/Inference.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = \"records/phase3/exp04_xsa4_from-exp03\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n xsa_last_n=args.xsa_last_n,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. (Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp04_xsa4_from-exp03/README.md b/records/phase3/exp04_xsa4_from-exp03/README.md new file mode 100644 index 0000000000..3d051e28d8 --- /dev/null +++ b/records/phase3/exp04_xsa4_from-exp03/README.md @@ -0,0 +1,28 @@ +# Exclusive Self-Attention (XSA) on Last 4 Layers + +## Score: val_bpb = TBD + +## Hypothesis + +XSA removes the self-value bias from attention output by subtracting `v_i / seq_len` from each position. This forces the model to rely on information from other tokens rather than copying its own value. Applied to last 4 of 10 unique layers (not all). Zero additional parameters, slight compute overhead. + +Community shows XSA + EMA is the "prerequisite stack" for all frontier techniques. EMA without XSA loses 0.003 BPB; EMA with XSA gains 0.003 BPB. 
+ +## Changes from exp03 + +- `CausalSelfAttention` gains `use_xsa` flag +- When enabled, subtracts `v_expanded / seqlen` after SDPA (removes self-value contribution) +- `Block` passes `use_xsa` to attention +- `GPT` enables XSA on last `xsa_last_n=4` physical blocks + +## Architecture + +Inherits from exp03 (Partial RoPE + LN Scale + EMA) + XSA on last 4 layers. + +## Expected Impact + +~0.01-0.02 BPB improvement over exp03. Also synergizes with EMA. + +## Results + +TBD — awaiting A100 run. diff --git a/records/phase3/exp04_xsa4_from-exp03/logs.txt b/records/phase3/exp04_xsa4_from-exp03/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp04_xsa4_from-exp03/run.sh b/records/phase3/exp04_xsa4_from-exp03/run.sh new file mode 100755 index 0000000000..d673b67919 --- /dev/null +++ b/records/phase3/exp04_xsa4_from-exp03/run.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# ============================================================ +# Exp04: XSA on last 4 layers — from Exp03 +# Exclusive Self-Attention: subtract self-value from attention +# output, forcing reliance on cross-token information. +# Critical synergy with EMA (exp03). 
+# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp04_xsa4_from-exp03" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export ROPE_FRAC=0.25 +export EMA_ENABLED=1 +export EMA_DECAY=0.997 +export XSA_LAST_N=4 +export RUN_ID="${EXP_NAME}" +echo "=== ${EXP_NAME} ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp04_xsa4_from-exp03/save_model.py b/records/phase3/exp04_xsa4_from-exp03/save_model.py new file mode 100644 index 0000000000..9d7cd8b817 --- /dev/null +++ b/records/phase3/exp04_xsa4_from-exp03/save_model.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp04_xsa4_from-exp03.""" + +import argparse +import json +import os +import sys +import shutil + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-pt", type=str, default="final_model.pt") + parser.add_argument("--output-dir", type=str, default="model_checkpoint") + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + import 
train_gpt as tg + + hp = tg.Hyperparameters() + config = { + "vocab_size": hp.vocab_size, + "num_layers": hp.num_layers, + "model_dim": hp.model_dim, + "num_heads": hp.num_heads, + "num_kv_heads": hp.num_kv_heads, + "mlp_mult": hp.mlp_mult, + "tie_embeddings": hp.tie_embeddings, + "tied_embed_init_std": hp.tied_embed_init_std, + "logit_softcap": hp.logit_softcap, + "rope_base": hp.rope_base, + "qk_gain_init": hp.qk_gain_init, + "bigram_vocab_size": hp.bigram_vocab_size, + "bigram_dim": hp.bigram_dim, + "unique_layers": hp.unique_layers, + "rope_frac": hp.rope_frac, + "ema_decay": hp.ema_decay, + "xsa_last_n": hp.xsa_last_n, + "train_seq_len": hp.train_seq_len, + } + + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + if os.path.exists(args.model_pt): + shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) + quant_path = args.model_pt.replace(".pt", ".int8.ptz") + if os.path.exists(quant_path): + shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) + + print(f"Checkpoint saved to {args.output_dir}/") + +if __name__ == "__main__": + main() diff --git a/records/phase3/exp04_xsa4_from-exp03/train_gpt.py b/records/phase3/exp04_xsa4_from-exp03/train_gpt.py new file mode 100644 index 0000000000..ee05fcd60f --- /dev/null +++ b/records/phase3/exp04_xsa4_from-exp03/train_gpt.py @@ -0,0 +1,1362 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = 
int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = 
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximate the orthogonal factor of G via a quintic Newton-Schulz iteration.

    Runs in bfloat16 for speed. The coefficients (a, b, c) define the polynomial
    x -> a*x + b*x^3 + c*x^5 applied to the singular values each step, pushing
    them toward 1 without needing high-precision convergence.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps  # Frobenius norm bound ensures top singular value <= 1
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T  # iterate on the wide orientation for cheaper matmuls
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum + Newton-Schulz orthogonalization for 2-D weight matrices.

    In the distributed case each rank orthogonalizes a strided subset of the
    parameters (param i is owned by rank i % world_size); the per-parameter
    updates are then combined with a single flat all_reduce.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay),
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure loss if given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5  # rescale for non-square matrices
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # BUGFIX: advance the offset for EVERY parameter, not only the ones
                # this rank owns. Previously `curr += p.numel()` was inside the `if`,
                # so each rank packed its params starting at offset 0; after the
                # all_reduce the slots collided, and the apply loop below (which
                # walks offsets over ALL params) read the wrong slices. The offset
                # must depend on the parameter index alone so every rank agrees
                # on the slot layout. This also fixes single-process misalignment
                # whenever some p.grad is None.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            wd = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                if wd > 0:
                    p.data.mul_(1.0 - lr * wd)  # decoupled weight decay
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss
rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += 
token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 
127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = 
t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: 
int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or 
self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, use_xsa: bool = False): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + self.use_xsa = use_xsa + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2) # even, at least 2 + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.rope_dims, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, 
class MLP(nn.Module):
    """Feed-forward block: leaky-ReLU-then-square activation between two linears."""

    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        width = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, width, bias=False)
        self.proj = CastedLinear(width, dim, bias=False)
        self.proj._zero_init = True  # flagged for zero init by the model's weight init

    def forward(self, x: Tensor) -> Tensor:
        hidden = F.leaky_relu(self.fc(x), negative_slope=0.5)
        return self.proj(hidden.square())


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""

    def __init__(self, dim: int):
        super().__init__()
        # zero init => sigmoid(0) = 0.5, an even blend at the start of training
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        blend = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        shifted = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - blend) * x + blend * shifted


class BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table."""

    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)  # start as a no-op contribution
        if bigram_dim != model_dim:
            self.proj = CastedLinear(bigram_dim, model_dim, bias=False)
            nn.init.zeros_(self.proj.weight)
        else:
            self.proj = None
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Map (prev, cur) token pairs to table indices; position 0 gets the reserved slot."""
        ids = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        hashed = torch.empty_like(ids)
        hashed[..., 0] = mod  # first position has no predecessor
        hashed[..., 1:] = torch.bitwise_xor(36313 * ids[..., 1:], 27191 * ids[..., :-1]) % mod
        return hashed.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)
math.sqrt(layer_idx + 1) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + unique_layers: int = 0, + rope_frac: float = 0.25, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them + n_unique = unique_layers if unique_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + rope_frac=rope_frac, layer_idx=i, use_xsa=(i >= n_unique - xsa_last_n)) + for i in range(n_unique) + ] + ) + # Mapping from virtual 
layer index → physical block index + # Last virtual layer shares with the last physical block + self._layer_map = list(range(min(n_unique, num_layers))) + while len(self._layer_map) < num_layers: + self._layer_map.append(n_unique - 1) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self._layer_map) # virtual depth, not physical blocks + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * 
torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, 
device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global 
zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + 
random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + rope_frac=args.rope_frac, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if 
distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + 
optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = 
train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + 
f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # EMA: exponential moving average of weights every step + if args.ema_enabled and ema_state is not None: + decay = args.ema_decay + for name, t in base_model.state_dict().items(): + 
ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply EMA weights + if args.ema_enabled and ema_state is not None: + log0(f"ema:applying decay={args.ema_decay}") + current_state = base_model.state_dict() + ema_load = { + name: tensor.to(dtype=current_state[name].dtype) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(ema_load, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + 
# AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if 
module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = 
inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/phase3/exp05_late-qat_from-exp04/Inference.ipynb b/records/phase3/exp05_late-qat_from-exp04/Inference.ipynb new file mode 100644 index 0000000000..badc853e52 --- /dev/null +++ b/records/phase3/exp05_late-qat_from-exp04/Inference.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = \"records/phase3/exp05_late-qat_from-exp04\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n xsa_last_n=args.xsa_last_n,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. (Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp05_late-qat_from-exp04/README.md b/records/phase3/exp05_late-qat_from-exp04/README.md new file mode 100644 index 0000000000..c398acb06a --- /dev/null +++ b/records/phase3/exp05_late-qat_from-exp04/README.md @@ -0,0 +1,26 @@ +# Late Quantization-Aware Training (QAT) + +## Score: val_bpb = TBD + +## Hypothesis + +Applying fake quantization (quantize → dequantize with STE) only when `lr_scale < 0.1` (final ~4% of training) lets the model learn robust quantized configurations without corrupting early convergence. Community shows this closes the quantization gap from ~0.023 to ~0.007 BPB. + +## Changes from exp04 + +- Added `qat_threshold=0.1` hyperparameter +- Before each forward pass during warmdown (when `lr_scale < threshold`): fake-quantize all 2D weight matrices using the same int5/int6 clip ranges as post-training quantization +- Uses `_classify_param` to match int5 (MLP) vs int6 (attn) clip ranges +- STE: gradient flows through the quantized weights unchanged + +## Architecture + +Inherits from exp04 (Partial RoPE + LN Scale + EMA + XSA) + Late QAT. 
+ +## Expected Impact + +Reduce quantization penalty from ~0.02 to ~0.007 BPB. + +## Results + +TBD — awaiting A100 run. diff --git a/records/phase3/exp05_late-qat_from-exp04/logs.txt b/records/phase3/exp05_late-qat_from-exp04/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp05_late-qat_from-exp04/run.sh b/records/phase3/exp05_late-qat_from-exp04/run.sh new file mode 100755 index 0000000000..cef8bfe0c1 --- /dev/null +++ b/records/phase3/exp05_late-qat_from-exp04/run.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# ============================================================ +# Exp05: Late QAT (STE when lr_scale < 0.1) — from Exp04 +# Quantization-Aware Training via Straight-Through Estimator. +# Activates only in final ~4% of training (warmdown tail). +# Closes quantization gap from ~0.023 to ~0.007 BPB. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp05_late-qat_from-exp04" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export ROPE_FRAC=0.25 +export EMA_ENABLED=1 +export EMA_DECAY=0.997 +export XSA_LAST_N=4 +export QAT_THRESHOLD=0.1 +export RUN_ID="${EXP_NAME}" +echo "=== ${EXP_NAME} ===" 
+python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp05_late-qat_from-exp04/save_model.py b/records/phase3/exp05_late-qat_from-exp04/save_model.py new file mode 100644 index 0000000000..d82d838f9e --- /dev/null +++ b/records/phase3/exp05_late-qat_from-exp04/save_model.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp05_late-qat_from-exp04.""" + +import argparse +import json +import os +import sys +import shutil + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-pt", type=str, default="final_model.pt") + parser.add_argument("--output-dir", type=str, default="model_checkpoint") + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + import train_gpt as tg + + hp = tg.Hyperparameters() + config = { + "vocab_size": hp.vocab_size, + "num_layers": hp.num_layers, + "model_dim": hp.model_dim, + "num_heads": hp.num_heads, + "num_kv_heads": hp.num_kv_heads, + "mlp_mult": hp.mlp_mult, + "tie_embeddings": hp.tie_embeddings, + "tied_embed_init_std": hp.tied_embed_init_std, + "logit_softcap": hp.logit_softcap, + "rope_base": hp.rope_base, + "qk_gain_init": hp.qk_gain_init, + "bigram_vocab_size": hp.bigram_vocab_size, + "bigram_dim": hp.bigram_dim, + "unique_layers": hp.unique_layers, + "rope_frac": hp.rope_frac, + "ema_decay": hp.ema_decay, + "xsa_last_n": hp.xsa_last_n, + "qat_threshold": hp.qat_threshold, + "train_seq_len": hp.train_seq_len, + } + + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + if os.path.exists(args.model_pt): + shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) + quant_path = args.model_pt.replace(".pt", ".int8.ptz") + if os.path.exists(quant_path): + shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) + + 
print(f"Checkpoint saved to {args.output_dir}/") + +if __name__ == "__main__": + main() diff --git a/records/phase3/exp05_late-qat_from-exp04/train_gpt.py b/records/phase3/exp05_late-qat_from-exp04/train_gpt.py new file mode 100644 index 0000000000..26fdc29926 --- /dev/null +++ b/records/phase3/exp05_late-qat_from-exp04/train_gpt.py @@ -0,0 +1,1380 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = 
int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", 0.1)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = 
bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = 
group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = 
True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = 
min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + 
"INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for 
name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# 
----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + 
def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, use_xsa: bool = False): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + self.use_xsa = use_xsa + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2) # even, at least 2 + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.rope_dims, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + # Partial RoPE: apply rotary only to first rope_dims dimensions + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] + k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = 
F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + # XSA: subtract self-value to force reliance on cross-token information + if self.use_xsa: + # Expand v for GQA if needed + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + v_expanded = v.repeat_interleave(rep, dim=1) + else: + v_expanded = v + # Self-attention weight for diagonal is 1/seq_len in expectation, + # but we subtract the actual self-value contribution: v_i / scale + # Simpler approach: subtract v_i projected orthogonally + y = y - v_expanded / seqlen + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + 
if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0, use_xsa: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac, use_xsa=use_xsa) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale = 1.0 / math.sqrt(layer_idx + 1) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: 
float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + unique_layers: int = 0, + rope_frac: float = 0.25, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them + n_unique = unique_layers if unique_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + rope_frac=rope_frac, layer_idx=i, use_xsa=(i >= n_unique - xsa_last_n)) + for i in range(n_unique) + ] + ) + # Mapping from virtual layer index → physical block index + # Last virtual layer shares with the last physical block + self._layer_map = list(range(min(n_unique, num_layers))) + while len(self._layer_map) < num_layers: + self._layer_map.append(n_unique - 1) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self._layer_map) # virtual depth, not physical blocks + for name, module in self.named_modules(): + if 
isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + 
logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + 
scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not 
torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, 
has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + rope_frac=args.rope_frac, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = 
[{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + 
train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Late QAT: apply fake quantization via STE when lr_scale < threshold + qat_active = args.qat_threshold > 0 and scale < args.qat_threshold + if qat_active: + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + cat = _classify_param(name) + clip = 15 if cat == "mlp" else 31 # match int5/int6 + with torch.no_grad(): + t32 = param.float() + row_max = t32.abs().amax(dim=1).clamp_min(1e-12) + s = row_max / clip + q = torch.clamp(torch.round(t32 / s[:, None]), -(clip + 1), clip) + dq = (q * s[:, None]).to(param.dtype) + # 
STE: replace param data but keep grad flowing + param.data.copy_(dq) + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # EMA: exponential moving average of weights every step + if args.ema_enabled and ema_state is not None: + decay = args.ema_decay + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) + + should_log_train = ( + 
args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply EMA weights + if args.ema_enabled and ema_state is not None: + log0(f"ema:applying decay={args.ema_decay}") + current_state = base_model.state_dict() + ema_load = { + name: tensor.to(dtype=current_state[name].dtype) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(ema_load, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = 
float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = 
module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + 
".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/phase3/exp06_gptq_from-exp05/Inference.ipynb b/records/phase3/exp06_gptq_from-exp05/Inference.ipynb new file mode 100644 index 0000000000..7874850879 --- /dev/null +++ b/records/phase3/exp06_gptq_from-exp05/Inference.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — 
Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = \"records/phase3/exp06_gptq_from-exp05\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n xsa_last_n=args.xsa_last_n,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. (Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp06_gptq_from-exp05/README.md b/records/phase3/exp06_gptq_from-exp05/README.md new file mode 100644 index 0000000000..55ca35476f --- /dev/null +++ b/records/phase3/exp06_gptq_from-exp05/README.md @@ -0,0 +1,26 @@ +# GPTQ Second-Order Quantization + +## Score: val_bpb = TBD + +## Hypothesis + +GPTQ uses Hessian information (X^T X from calibration data) to minimize quantization error via column-by-column second-order optimization. Community shows full GPTQ at 1.1154 vs lite at 1.1228 — ~0.007 BPB improvement over naive quantization. + +## Changes from exp05 + +- Added `gptq_quantize_layer()`: Cholesky-based GPTQ with per-row scaling +- Hessian collection via forward hooks on calibration data (8 batches) +- `mixed_quantize_int6` now accepts `gptq_hessians` dict — uses GPTQ when Hessian available, falls back to naive otherwise +- AWQ + GPTQ stacked (AWQ scales columns before GPTQ quantizes) + +## Architecture + +Inherits from exp05 (Partial RoPE + LN Scale + EMA + XSA + Late QAT) + GPTQ post-training. + +## Expected Impact + +~0.007 BPB improvement in quantized model quality. 
+ +## Results + +TBD — awaiting A100 run. diff --git a/records/phase3/exp06_gptq_from-exp05/logs.txt b/records/phase3/exp06_gptq_from-exp05/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp06_gptq_from-exp05/run.sh b/records/phase3/exp06_gptq_from-exp05/run.sh new file mode 100755 index 0000000000..8c7537b451 --- /dev/null +++ b/records/phase3/exp06_gptq_from-exp05/run.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# ============================================================ +# Exp06: GPTQ second-order quantization — from Exp05 +# Replaces naive per-row quantization with Hessian-aware GPTQ. +# Uses calibration data to minimize quantization error via +# second-order column-by-column optimization. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp06_gptq_from-exp05" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export ROPE_FRAC=0.25 +export EMA_ENABLED=1 +export EMA_DECAY=0.997 +export XSA_LAST_N=4 +export QAT_THRESHOLD=0.1 +export USE_GPTQ=1 +export GPTQ_DAMP=0.01 +export GPTQ_CALIB_BATCHES=8 +export RUN_ID="${EXP_NAME}" +echo "=== ${EXP_NAME} ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee 
"${SCRIPT_DIR}/logs.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp06_gptq_from-exp05/save_model.py b/records/phase3/exp06_gptq_from-exp05/save_model.py new file mode 100644 index 0000000000..e5abe53bfe --- /dev/null +++ b/records/phase3/exp06_gptq_from-exp05/save_model.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp06_gptq_from-exp05.""" + +import argparse +import json +import os +import sys +import shutil + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-pt", type=str, default="final_model.pt") + parser.add_argument("--output-dir", type=str, default="model_checkpoint") + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + import train_gpt as tg + + hp = tg.Hyperparameters() + config = { + "vocab_size": hp.vocab_size, + "num_layers": hp.num_layers, + "model_dim": hp.model_dim, + "num_heads": hp.num_heads, + "num_kv_heads": hp.num_kv_heads, + "mlp_mult": hp.mlp_mult, + "tie_embeddings": hp.tie_embeddings, + "tied_embed_init_std": hp.tied_embed_init_std, + "logit_softcap": hp.logit_softcap, + "rope_base": hp.rope_base, + "qk_gain_init": hp.qk_gain_init, + "bigram_vocab_size": hp.bigram_vocab_size, + "bigram_dim": hp.bigram_dim, + "unique_layers": hp.unique_layers, + "rope_frac": hp.rope_frac, + "ema_decay": hp.ema_decay, + "xsa_last_n": hp.xsa_last_n, + "qat_threshold": hp.qat_threshold, + "use_gptq": hp.use_gptq, + "train_seq_len": hp.train_seq_len, + } + + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + if os.path.exists(args.model_pt): + shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) + quant_path = args.model_pt.replace(".pt", ".int8.ptz") + if os.path.exists(quant_path): + shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) + + print(f"Checkpoint saved to 
{args.output_dir}/") + +if __name__ == "__main__": + main() diff --git a/records/phase3/exp06_gptq_from-exp05/train_gpt.py b/records/phase3/exp06_gptq_from-exp05/train_gpt.py new file mode 100644 index 0000000000..bcb506608a --- /dev/null +++ b/records/phase3/exp06_gptq_from-exp05/train_gpt.py @@ -0,0 +1,1485 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + 
iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", 0.1)) + use_gptq = bool(int(os.environ.get("USE_GPTQ", "1"))) + gptq_damp = float(os.environ.get("GPTQ_DAMP", 0.01)) + gptq_calib_batches = int(os.environ.get("GPTQ_CALIB_BATCHES", 8)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximately orthogonalize G with a quintic Newton-Schulz iteration.

    Runs in bfloat16 on a Frobenius-normalized copy of G. The matrix is kept
    wide (rows <= cols) during iteration by transposing tall inputs, and
    transposed back before returning. The quintic coefficients trade exact
    convergence for a fast, flat-spectrum map (as used by the Muon optimizer).
    """
    coef_a, coef_b, coef_c = 3.4445, -4.7750, 2.0315
    work = G.bfloat16()
    # Normalize so the singular values land inside the iteration's basin.
    work /= work.norm() + eps
    needs_flip = G.size(0) > G.size(1)
    if needs_flip:
        work = work.T
    step = 0
    while step < steps:
        gram = work @ work.T
        poly = coef_b * gram + coef_c * gram @ gram
        work = coef_a * work + poly @ work
        step += 1
    if needs_flip:
        return work.T
    return work
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables for tokenizer-agnostic byte accounting.

    Returns three tensors indexed by token id:
      - base_bytes (int16): UTF-8 byte length of the token's text, with the
        leading U+2581 marker stripped first; byte-fallback tokens count as 1.
      - has_leading_space (bool): True when the piece starts with U+2581,
        SentencePiece's word-boundary marker.
      - is_boundary_token (bool): True for control/unknown/unused ids (they
        carry no text and terminate a word).
    Tables are sized max(sp vocab, model vocab) so model ids beyond the
    tokenizer vocabulary index safely with boundary/zero defaults.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        # Control/unknown/unused ids contribute no bytes and stay "boundary".
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            # Byte-fallback tokens always encode exactly one byte.
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("\u2581"):
            # The marker denotes a leading space; it is not counted itself.
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )
torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", 
"tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + + +def gptq_quantize_layer(W: Tensor, H: Tensor, clip_range: int = 31, damp: float = 0.01) -> tuple[Tensor, Tensor]: + """GPTQ: second-order weight quantization using Hessian information. + W: [out_features, in_features] weight matrix + H: [in_features, in_features] Hessian (X^T X from calibration) + Returns (q_int8, scale_fp16) same as quantize_intN_per_row. 
+ """ + W = W.float() + out_dim, in_dim = W.shape + # Per-row scale (same as naive) + row_max = W.abs().amax(dim=1).clamp_min(1e-12) + scale = (row_max / clip_range).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + scale_f = scale.float() + + # Dampen Hessian diagonal for numerical stability + diag = torch.diag(H) + H = H + damp * diag.mean() * torch.eye(in_dim, device=H.device, dtype=H.dtype) + + # Cholesky decomposition + try: + L = torch.linalg.cholesky(H) + Linv = torch.linalg.inv(L) + Hinv = Linv.T @ Linv + except RuntimeError: + # Fallback to naive if Cholesky fails + return quantize_intN_per_row(W, clip_range) + + Q = torch.zeros_like(W) + E = torch.zeros_like(W) + W_copy = W.clone() + + # Process columns sequentially (GPTQ algorithm) + for j in range(in_dim): + w_col = W_copy[:, j] + d = Hinv[j, j].clamp_min(1e-12) + q_col = torch.clamp(torch.round(w_col / scale_f), -(clip_range + 1), clip_range) + Q[:, j] = q_col + err = (w_col - q_col * scale_f) / d + E[:, j] = err + # Update remaining columns + if j + 1 < in_dim: + W_copy[:, j + 1:] -= err[:, None] * Hinv[j, j + 1:][None, :] + + return Q.to(torch.int8), scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + gptq_hessians: dict[str, Tensor] | None = None, gptq_damp: float = 0.01): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + # Build a mapping from state_dict weight keys to module names for GPTQ lookup + # e.g. 
"blocks.0.attn.c_q.weight" -> "blocks.0.attn.c_q" + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + # Try GPTQ if Hessian available + module_name = name.replace(".weight", "") if name.endswith(".weight") else None + if gptq_hessians and module_name and module_name in gptq_hessians and t.ndim == 2: + H = gptq_hessians[module_name] + q, s = gptq_quantize_layer(t, H, clip_range=clip, damp=gptq_damp) + else: + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + 
else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else 
def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Apply rotary position embedding in split-halves layout.

    The last dimension of x is split into two halves (lo, hi) that form the
    (real, imaginary) components of each rotation pair; both are rotated by
    the per-position angle tables `cos`/`sin` and concatenated back.
    """
    split = x.size(-1) // 2
    lo = x[..., :split]
    hi = x[..., split:]
    rotated_lo = lo * cos + hi * sin
    rotated_hi = hi * cos - lo * sin
    return torch.cat((rotated_lo, rotated_hi), dim=-1)
    def forward(self, x: Tensor) -> Tensor:
        """Causal self-attention over x of shape (batch, seq, dim).

        Pipeline: QKV projections -> per-head RMS norm of q and k -> partial
        RoPE on the first `rope_dims` channels -> learned per-head query gain
        -> scaled-dot-product attention (causal, GQA-aware) -> optional XSA
        self-value subtraction -> output projection. Returns the same shape
        as `x`.
        """
        bsz, seqlen, dim = x.shape
        # (batch, heads, seq, head_dim) layout for scaled_dot_product_attention.
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm: normalize each head's query/key vectors before RoPE.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        # Partial RoPE: apply rotary only to first rope_dims dimensions
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:]
        k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:]
        q_rope = apply_rotary_emb(q_rope, cos, sin)
        k_rope = apply_rotary_emb(k_rope, cos, sin)
        q = torch.cat([q_rope, q_pass], dim=-1)
        k = torch.cat([k_rope, k_pass], dim=-1)
        # Learned per-head query temperature (initialized to qk_gain_init).
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        # XSA: subtract self-value to force reliance on cross-token information
        if self.use_xsa:
            # Expand v for GQA if needed
            if self.num_kv_heads != self.num_heads:
                rep = self.num_heads // self.num_kv_heads
                v_expanded = v.repeat_interleave(rep, dim=1)
            else:
                v_expanded = v
            # Self-attention weight for diagonal is 1/seq_len in expectation,
            # but we subtract the actual self-value contribution: v_i / scale
            # Simpler approach: subtract v_i projected orthogonally
            # NOTE(review): v_i/seqlen is a mean-field approximation of the
            # diagonal attention mass, not the exact per-position weight —
            # confirm this matches the intended XSA formulation.
            y = y - v_expanded / seqlen
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)
self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0, use_xsa: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac, use_xsa=use_xsa) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale = 1.0 / math.sqrt(layer_idx + 1) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + unique_layers: int = 0, + rope_frac: float = 0.25, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + 
self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them + n_unique = unique_layers if unique_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + rope_frac=rope_frac, layer_idx=i, use_xsa=(i >= n_unique - xsa_last_n)) + for i in range(n_unique) + ] + ) + # Mapping from virtual layer index → physical block index + # Last virtual layer shares with the last physical block + self._layer_map = list(range(min(n_unique, num_layers))) + while len(self._layer_map) < num_layers: + self._layer_map.append(n_unique - 1) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self._layer_map) # virtual depth, not physical blocks + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + 
device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 
0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + rope_frac=args.rope_frac, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not 
None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if 
args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = 
{name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Late QAT: apply fake quantization via STE when lr_scale < threshold + qat_active = args.qat_threshold > 0 and scale < args.qat_threshold + if qat_active: + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + cat = _classify_param(name) + clip = 15 if cat == "mlp" else 31 # match int5/int6 + with torch.no_grad(): + t32 = param.float() + row_max = t32.abs().amax(dim=1).clamp_min(1e-12) + s = row_max / clip + q = torch.clamp(torch.round(t32 / s[:, None]), -(clip + 1), clip) + dq = (q * s[:, None]).to(param.dtype) + # STE: replace param data but keep grad flowing + param.data.copy_(dq) + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = 
micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # EMA: exponential moving average of weights every step + if args.ema_enabled and ema_state is not None: + decay = args.ema_decay + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + 
f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply EMA weights + if args.ema_enabled and ema_state is not None: + log0(f"ema:applying decay={args.ema_decay}") + current_state = base_model.state_dict() + ema_load = { + name: tensor.to(dtype=current_state[name].dtype) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(ema_load, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + 
act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # GPTQ: collect Hessians 
for second-order quantization + gptq_hessians: dict[str, Tensor] = {} + if args.use_gptq: + log0(f"gptq:collecting Hessians (calib_batches={args.gptq_calib_batches})") + gptq_hooks = [] + gptq_inp_cache: dict[str, list[Tensor]] = {} + + def make_gptq_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + x_2d = x.detach().float().reshape(-1, x.shape[-1]) + if name not in gptq_inp_cache: + gptq_inp_cache[name] = [] + gptq_inp_cache[name].append(x_2d.cpu()) + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and module.weight.ndim == 2 and module.weight.numel() > 65536: + gptq_hooks.append(module.register_forward_hook(make_gptq_hook(name))) + + base_model.eval() + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(args.gptq_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in gptq_hooks: + h.remove() + + # Build Hessians: H = X^T X / n_samples + for name, inp_list in gptq_inp_cache.items(): + X = torch.cat(inp_list, dim=0) + n = X.shape[0] + H = (X.T @ X) / n + gptq_hessians[name] = H + del gptq_inp_cache + log0(f"gptq:collected Hessians for {len(gptq_hessians)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, 
quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}, gptq_hessians=gptq_hessians, gptq_damp=args.gptq_damp) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window 
eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/phase3/exp07_parallel-muon_from-exp06/Inference.ipynb b/records/phase3/exp07_parallel-muon_from-exp06/Inference.ipynb new file mode 100644 index 0000000000..749d7a1c22 --- /dev/null +++ b/records/phase3/exp07_parallel-muon_from-exp06/Inference.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parameter Golf — Model Inference\n", + "\n", + "Load a trained model checkpoint and generate text.\n", + "\n", + "**Prerequisites**:\n", + "- A trained model (run `train_gpt.py` first)\n", + "- The experiment's `train_gpt.py` (for model class definitions)\n", + "- Tokenizer files in `data/tokenizers/`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as 
F\nimport numpy as np\n\nEXPERIMENT_DIR = \"records/phase3/exp07_parallel-muon_from-exp06\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Import model classes from training script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the training script to get model architecture\n", + "sys.path.insert(0, EXPERIMENT_DIR)\n", + "import train_gpt as tg\n", + "\n", + "# Load hyperparameters\n", + "args = tg.Hyperparameters()\n", + "print(f\"Model config:\")\n", + "print(f\" vocab_size: {args.vocab_size}\")\n", + "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", + "print(f\" model_dim: {args.model_dim}\")\n", + "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", + "print(f\" mlp_mult: {args.mlp_mult}\")\n", + "print(f\" seq_len: {args.train_seq_len}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Build model and load weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n xsa_last_n=args.xsa_last_n,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2b. (Alternative) Load from quantized artifact\n", + "\n", + "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.Load(TOKENIZER_PATH)\n", + "\n", + "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", + "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", + "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Generation utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate from a prompt\n", + "prompt = \"The history of artificial intelligence began\"\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=200,\n", + " temperature=0.8,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"\\nGenerated:\\n{output}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Try different prompts\n", + "prompts = [\n", + " \"In a small village by the sea, there lived\",\n", + " \"The most important scientific discovery of the 21st century\",\n", + " \"def fibonacci(n):\\n\",\n", + " \"Once upon a time\",\n", + "]\n", + "\n", + "for p in prompts:\n", + " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Prompt: {p}\")\n", + " print(f\"Output: {out}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Compute perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_texts = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", + " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", + "]\n", + "\n", + "for text in test_texts:\n", + " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", + " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Interactive generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive — change the prompt and re-run\n", + "prompt = \"Today I learned that\" # <-- Edit this!\n", + "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", + "max_tokens = 150\n", + "\n", + "output = generate(\n", + " model, sp, prompt,\n", + " max_new_tokens=max_tokens,\n", + " temperature=temperature,\n", + " top_k=50,\n", + " top_p=0.9,\n", + " device=DEVICE,\n", + " seq_len=args.train_seq_len,\n", + ")\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/records/phase3/exp07_parallel-muon_from-exp06/README.md b/records/phase3/exp07_parallel-muon_from-exp06/README.md new file mode 100644 index 0000000000..95ee7dda54 --- /dev/null +++ b/records/phase3/exp07_parallel-muon_from-exp06/README.md @@ -0,0 +1,26 @@ +# Parallel Muon (Batched Newton-Schulz) + +## Score: val_bpb = TBD + +## Hypothesis + +Grouping weight matrices by shape and running batched Newton-Schulz orthogonalization reduces per-step overhead. Community achieves 83ms/step with parameter banking. On 1×GPU, this is a minor speedup; on 8×H100, it's critical (more steps in 10 minutes = lower BPB). + +## Changes from exp06 + +- Added `zeropower_via_newtonschulz5_batched()` for 3D tensor [batch, rows, cols] +- Muon optimizer now groups params by shape, batches same-shape matrices for NS5 +- Momentum applied first to all grads, then batched NS5, then all-reduce +- Correctness: should produce identical val_bpb to exp06 on 1×GPU (same math, different execution order) + +## Architecture + +Inherits from exp06 (all features) + Parallel Muon optimizer. 
+ +## Expected Impact + +Same val_bpb as exp06 on 1×GPU. Faster ms/step → more training steps within wallclock on multi-GPU. + +## Results + +TBD — awaiting A100 run. diff --git a/records/phase3/exp07_parallel-muon_from-exp06/logs.txt b/records/phase3/exp07_parallel-muon_from-exp06/logs.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/records/phase3/exp07_parallel-muon_from-exp06/run.sh b/records/phase3/exp07_parallel-muon_from-exp06/run.sh new file mode 100755 index 0000000000..673f205692 --- /dev/null +++ b/records/phase3/exp07_parallel-muon_from-exp06/run.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# ============================================================ +# Exp07: Parallel Muon (batched Newton-Schulz) — from Exp06 +# Groups weight matrices by shape for batched NS5 orthogonalization. +# Critical for 8xH100 throughput (83ms/step target). +# On 1xGPU: validates correctness, minor speedup from batching. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXP_NAME="exp07_parallel-muon_from-exp06" +cd /workspace/parameter-golf +export EVAL_STRIDE=0 +export SEED="${SEED:-42}" +export ITERATIONS="${ITERATIONS:-20000}" +export WARMUP_STEPS="${WARMUP_STEPS:-20}" +export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" +export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" +export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" +export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" +export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +export NUM_LAYERS=11 +export UNIQUE_LAYERS=10 +export MLP_ACTIVATION=leaky_relu_sq +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MOMENTUM_CYCLIC=1 +export MOMENTUM_MIN=0.85 +export MOMENTUM_MAX=0.95 +export MOMENTUM_CYCLE_PERIOD=50 +export AWQ_ENABLED=1 +export AWQ_ALPHA=0.5 +export ROPE_FRAC=0.25 +export EMA_ENABLED=1 +export EMA_DECAY=0.997 +export 
XSA_LAST_N=4 +export QAT_THRESHOLD=0.1 +export USE_GPTQ=1 +export GPTQ_DAMP=0.01 +export GPTQ_CALIB_BATCHES=8 +export RUN_ID="${EXP_NAME}" +echo "=== ${EXP_NAME} ===" +python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" +echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp07_parallel-muon_from-exp06/save_model.py b/records/phase3/exp07_parallel-muon_from-exp06/save_model.py new file mode 100644 index 0000000000..f8b8d261e2 --- /dev/null +++ b/records/phase3/exp07_parallel-muon_from-exp06/save_model.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +"""Save a trained model checkpoint for exp07_parallel-muon_from-exp06.""" + +import argparse +import json +import os +import sys +import shutil + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-pt", type=str, default="final_model.pt") + parser.add_argument("--output-dir", type=str, default="model_checkpoint") + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + import train_gpt as tg + + hp = tg.Hyperparameters() + config = { + "vocab_size": hp.vocab_size, + "num_layers": hp.num_layers, + "model_dim": hp.model_dim, + "num_heads": hp.num_heads, + "num_kv_heads": hp.num_kv_heads, + "mlp_mult": hp.mlp_mult, + "tie_embeddings": hp.tie_embeddings, + "tied_embed_init_std": hp.tied_embed_init_std, + "logit_softcap": hp.logit_softcap, + "rope_base": hp.rope_base, + "qk_gain_init": hp.qk_gain_init, + "bigram_vocab_size": hp.bigram_vocab_size, + "bigram_dim": hp.bigram_dim, + "unique_layers": hp.unique_layers, + "rope_frac": hp.rope_frac, + "ema_decay": hp.ema_decay, + "xsa_last_n": hp.xsa_last_n, + "qat_threshold": hp.qat_threshold, + "use_gptq": hp.use_gptq, + "train_seq_len": hp.train_seq_len, + } + + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + if os.path.exists(args.model_pt): + shutil.copy2(args.model_pt, 
os.path.join(args.output_dir, "model.pt")) + quant_path = args.model_pt.replace(".pt", ".int8.ptz") + if os.path.exists(quant_path): + shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) + + print(f"Checkpoint saved to {args.output_dir}/") + +if __name__ == "__main__": + main() diff --git a/records/phase3/exp07_parallel-muon_from-exp06/train_gpt.py b/records/phase3/exp07_parallel-muon_from-exp06/train_gpt.py new file mode 100644 index 0000000000..ba77372e88 --- /dev/null +++ b/records/phase3/exp07_parallel-muon_from-exp06/train_gpt.py @@ -0,0 +1,1542 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", 0.1)) + use_gptq = bool(int(os.environ.get("USE_GPTQ", "1"))) + gptq_damp = float(os.environ.get("GPTQ_DAMP", 0.01)) + gptq_calib_batches = 
int(os.environ.get("GPTQ_CALIB_BATCHES", 8)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) + momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) + momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) + momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + 
bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + +# ----------------------------- +# MUON OPTIMIZER (Parallel / Batched) +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +def zeropower_via_newtonschulz5_batched(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz for 3D tensor [batch, rows, cols].""" + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + # Normalize each matrix in the batch + norms = X.flatten(1).norm(dim=1)[:, None, None].clamp_min(eps) + X = X / norms + transposed = X.size(1) > X.size(2) + if transposed: + X = X.transpose(1, 2) + for _ in range(steps): + A = X @ X.transpose(1, 2) + B = b * A + c * A @ A + X = a * X + B @ X + return X.transpose(1, 2) if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + # Group parameters by shape for batched Newton-Schulz + self._shape_groups: dict[tuple[int, int], list[int]] | None = None + + def _build_shape_groups(self, params: list) -> dict[tuple[int, int], list[int]]: + groups: dict[tuple[int, int], list[int]] = {} + for i, p in enumerate(params): + shape = (p.shape[0], p.shape[1]) + groups.setdefault(shape, []).append(i) + return groups + + @torch.no_grad() + def 
step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + # Build shape groups on first call + if self._shape_groups is None: + self._shape_groups = self._build_shape_groups(params) + + # Apply momentum to all grads first + nesterov_grads: list[Tensor | None] = [None] * len(params) + for i, p in enumerate(params): + if p.grad is None: + continue + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + nesterov_grads[i] = g.add(buf, alpha=momentum) + else: + nesterov_grads[i] = buf.clone() + + # Batched Newton-Schulz per shape group + updates: list[Tensor | None] = [None] * len(params) + for shape, indices in self._shape_groups.items(): + # Filter to indices assigned to this rank (for distributed) and with grads + my_indices = [idx for idx in indices if idx % world_size == rank and nesterov_grads[idx] is not None] + if not my_indices: + continue + if len(my_indices) > 1: + # Batch: stack into 3D and process together + batch = torch.stack([nesterov_grads[idx] for idx in my_indices]) + batch_out = zeropower_via_newtonschulz5_batched(batch, steps=backend_steps) + scale = max(1, shape[0] / shape[1]) ** 0.5 + for j, idx in enumerate(my_indices): + updates[idx] = batch_out[j] * scale + else: + # Single matrix: use original function + idx = my_indices[0] + g = zeropower_via_newtonschulz5(nesterov_grads[idx], steps=backend_steps) + g *= max(1, shape[0] / shape[1]) ** 0.5 + updates[idx] = g + + # 
All-reduce updates + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if updates[i] is not None: + updates_flat[curr : curr + p.numel()] = updates[i].reshape(-1).bfloat16() + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found 
for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = 
float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = 
t.float()  # NOTE(review): fragment — the enclosing def (an int8 quantizer returning (q, scale)) starts before this view; `t32` is bound here
    if t32.ndim == 2:
        # Per-row int8: clip each row at the INT8_CLIP_Q quantile of |row|, then map to [-127, 127].
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        # Floor on scale avoids divide-by-~0 rows; stored scale dtype is project-configured.
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
    # Non-2D tensors: one global clip/scale instead of per-row.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale


def _classify_param(name: str) -> str:
    """Bucket a state-dict parameter name into a coarse category.

    Categories drive the quantization policy in mixed_quantize_int6 and the
    late-QAT clip choice: "embed", "mlp", "bigram", "attn", or "other".
    Order matters: mlp/bigram are matched before the attn patterns, so
    ".mlp.proj" classifies as "mlp", not "attn".
    """
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    if "bigram" in name:
        return "bigram"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"

def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
    """Naive symmetric intN quantization with per-row absmax scaling.

    clip_range=31 gives int6 levels [-32, 31]; clip_range=15 gives int5.
    Returns (q, scale): q is int8 storage, scale is fp16 (per-row for 2-D
    tensors, scalar otherwise).
    """
    t32 = t.float()
    if t32.ndim == 2:
        row_max = t32.abs().amax(dim=1)
        scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16)
        # Re-clamp after the fp16 cast: 1e-12 underflows to 0 in fp16.
        scale = scale.clamp_min(torch.finfo(torch.float16).tiny)
        q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8)
        return q, scale
    amax = t32.abs().max().item()
    scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8)
    return q, scale


def gptq_quantize_layer(W: Tensor, H: Tensor, clip_range: int = 31, damp: float = 0.01) -> tuple[Tensor, Tensor]:
    """GPTQ: second-order weight quantization using Hessian information.

    W: [out_features, in_features] weight matrix
    H: [in_features, in_features] Hessian (X^T X from calibration)
    Returns (q_int8, scale_fp16) same as quantize_intN_per_row.
    """
    W = W.float()
    out_dim, in_dim = W.shape
    # Per-row scale (same as naive)
    row_max = W.abs().amax(dim=1).clamp_min(1e-12)
    scale = (row_max / clip_range).to(torch.float16)
    scale = scale.clamp_min(torch.finfo(torch.float16).tiny)
    scale_f = scale.float()

    # Dampen Hessian diagonal for numerical stability
    diag = torch.diag(H)
    H = H + damp * diag.mean() * torch.eye(in_dim, device=H.device, dtype=H.dtype)

    # Cholesky decomposition; Hinv = (L L^T)^{-1} = L^{-T} L^{-1}
    try:
        L = torch.linalg.cholesky(H)
        Linv = torch.linalg.inv(L)
        Hinv = Linv.T @ Linv
    except RuntimeError:
        # Fallback to naive if Cholesky fails
        return quantize_intN_per_row(W, clip_range)

    Q = torch.zeros_like(W)
    E = torch.zeros_like(W)  # per-column quantization errors (kept for inspection; not returned)
    W_copy = W.clone()

    # Process columns sequentially (GPTQ algorithm): quantize column j, then
    # spread its scaled error onto the not-yet-quantized columns via Hinv.
    for j in range(in_dim):
        w_col = W_copy[:, j]
        d = Hinv[j, j].clamp_min(1e-12)
        q_col = torch.clamp(torch.round(w_col / scale_f), -(clip_range + 1), clip_range)
        Q[:, j] = q_col
        err = (w_col - q_col * scale_f) / d
        E[:, j] = err
        # Update remaining columns
        if j + 1 < in_dim:
            W_copy[:, j + 1:] -= err[:, None] * Hinv[j, j + 1:][None, :]

    return Q.to(torch.int8), scale

def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str],
    gptq_hessians: dict[str, Tensor] | None = None, gptq_damp: float = 0.01):
    """Quantize a state dict with a per-category mixed policy.

    Categories in `int6_cats` get int6 (int5 for "mlp") per-row quantization,
    optionally GPTQ-refined when a Hessian is supplied; everything else that
    is large and floating-point falls back to the int8 quantizer. Small
    tensors, control tensors, and fp16-keep tensors pass through.
    Returns (result, meta): `result` holds "<name>.q"/"<name>.scale" pairs or
    passthrough tensors; `meta` records the per-name treatment for
    dequantize_mixed_int6.
    """
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    # Build a mapping from state_dict weight keys to module names for GPTQ lookup
    # e.g. "blocks.0.attn.c_q.weight" -> "blocks.0.attn.c_q"
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = _classify_param(name)
        # Small tensors: not worth quantizing; fp16 halves their cost anyway.
        # NOTE(review): this check precedes the control-tensor check, so a
        # control tensor with <= 8192 elements is stored fp16, not fp32.
        if not t.is_floating_point() or t.numel() <= 8192:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue
        if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS):
            result[name] = t.to(dtype=torch.float16).contiguous()
            meta[name] = "passthrough_fp16"
            continue
        if cat in int6_cats and t.ndim >= 1:
            clip = 15 if cat == "mlp" else 31  # int5 for MLP, int6 for attention
            # Try GPTQ if Hessian available
            module_name = name.replace(".weight", "") if name.endswith(".weight") else None
            if gptq_hessians and module_name and module_name in gptq_hessians and t.ndim == 2:
                H = gptq_hessians[module_name]
                q, s = gptq_quantize_layer(t, H, clip_range=clip, damp=gptq_damp)
            else:
                q, s = quantize_intN_per_row(t, clip_range=clip)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"}
        else:
            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}
    return result, meta

def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
    template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """Invert mixed_quantize_int6, restoring each tensor to its template dtype.

    `template_sd` supplies the target names and dtypes (e.g. a freshly built
    model's state_dict). Raises KeyError if `meta`/`result` lack an entry for
    a template name.
    """
    out: dict[str, Tensor] = {}
    for name, orig in template_sd.items():
        info = meta[name]
        orig_dtype = orig.dtype
        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
            t = result[name]
            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
                t = t.to(orig_dtype)
            out[name] = t
            continue
        q, s = result[name + ".q"], result[name + ".scale"]
        if s.ndim > 0:
            # Per-row scale: broadcast over all trailing dims of q.
            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
        else:
            out[name] = (q.float() * float(s.item())).to(orig_dtype)
    return out


# -----------------------------
# DATA LOADING
# -----------------------------

def load_data_shard(file: Path) -> Tensor:
    # NOTE(review): the source is garbled from here through `_advance_file`
    # below — an angle-bracketed span (presumably the np.dtype("<...") string
    # through the "def _advance_file(self) ->" signature, including the rest
    # of this function and the TokenStream class header/__init__) was lost in
    # extraction. Recover the original from the patch before relying on this.
    header_bytes = 256 * np.dtype(" None:  # garbled line — see note above
        self.file_idx = (self.file_idx + 1) % len(self.files)
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def take(self, n: int) -> Tensor:
        """Return the next `n` tokens, advancing across shard files as needed."""
        chunks: list[Tensor] = []
        remaining = n
        while remaining > 0:
            avail = self.tokens.numel() - self.pos
            if avail <= 0:
                self._advance_file()
                continue
            k = min(remaining, avail)
            chunks.append(self.tokens[self.pos : self.pos + k])
            self.pos += k
            remaining -= k
        # Avoid a copy when the request fit inside one shard.
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)


class DistributedTokenLoader:
    """Serve per-rank training batches from one shared token stream.

    Every rank reads the same shard sequence and slices out its own
    contiguous span, so no inter-process communication is needed.
    """
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return (x, y) int64 tensors of shape [local_tokens // seq_len, seq_len].

        The +1 span gives each rank the extra token needed for the shifted
        target sequence. Assumes local_tokens is divisible by seq_len.
        """
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)


# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    """Weight-free RMS normalization (thin wrapper over F.rms_norm)."""
    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    """Linear layer whose fp32 master weights are cast to the input dtype per call."""
    def forward(self, x: Tensor) -> Tensor:
        w = self.weight.to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else
None  # continuation of CastedLinear.forward's `... if self.bias is not None else` above
        return F.linear(x, w, bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    """Promote scalar/vector params and control tensors back to fp32 in place.

    Used after a module-wide .bfloat16() cast so that gains, norms, and other
    low-dimensional control parameters keep full precision.
    """
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    """Rotary position embedding tables, cached per (seq_len, device)."""
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        """Return (cos, sin) of shape [1, 1, seq_len, dim//2] in `dtype`."""
        # Recompute only when the cache is cold or seq_len/device changed;
        # the dtype cast happens on every call, so dtype is not a cache key.
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate the two halves of the last dim of x by the (cos, sin) angles.

    Sign convention is the transpose of the usual rotation; it is consistent
    between q and k, which is all attention requires.
    """
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    """Causal attention with GQA, QK RMS-norm, per-head query gain, partial
    RoPE, and optional XSA (self-value subtraction)."""
    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, use_xsa: bool = False):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        self.use_xsa = use_xsa
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2)  # even, at least 2
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # marker consumed by GPT._init_weights
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.rope_dims, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm stabilizes attention logits; q_gain restores a learnable scale.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        # Partial RoPE: apply rotary only to first rope_dims dimensions
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:]
        k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:]
        q_rope = apply_rotary_emb(q_rope, cos, sin)
        k_rope = apply_rotary_emb(k_rope, cos, sin)
        q = torch.cat([q_rope, q_pass], dim=-1)
        k = torch.cat([k_rope, k_pass], dim=-1)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        # XSA: subtract self-value to force reliance on cross-token information
        if self.use_xsa:
            # Expand v for GQA if needed
            if self.num_kv_heads != self.num_heads:
                rep = self.num_heads // self.num_kv_heads
                v_expanded = v.repeat_interleave(rep, dim=1)
            else:
                v_expanded = v
            # Self-attention weight for diagonal is 1/seq_len in expectation,
            # but we subtract the actual self-value contribution: v_i / scale
            # Simpler approach: subtract v_i projected orthogonally
            y = y - v_expanded / seqlen
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    """Two-layer MLP with a leaky-ReLU-then-square activation ("relusq")."""
    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True  # marker consumed by GPT._init_weights

    def forward(self, x: Tensor) -> Tensor:
        x = F.leaky_relu(self.fc(x), negative_slope=0.5)
        return self.proj(x.square())


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""
    def __init__(self, dim: int):
        super().__init__()
        # Zero init -> sigmoid(0)=0.5 blend at start of training.
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        # Position 0 has no predecessor; it blends with zeros.
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev


class BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table."""
    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)
        # Projection only needed when the table dim differs from the model dim.
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Hash (prev, cur) token pairs into [0, bigram_vocab_size-1).

        Index mod (the reserved top bucket) is used for position 0, which has
        no predecessor. int32 products stay below 2**31 for SP vocab sizes.
        """
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h =
self.embed(self.bigram_hash(token_ids))  # continuation of BigramHashEmbedding.forward's `h =` above
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class Block(nn.Module):
    """One transformer block: x0-residual mix, attention, MLP, learned
    per-channel residual scales damped by 1/sqrt(layer_idx+1)."""
    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0, use_xsa: bool = False):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac, use_xsa=use_xsa)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights the running stream, resid_mix[1] the layer-0 input.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
        self.ln_scale = 1.0 / math.sqrt(layer_idx + 1)  # depth-dependent residual damping (constant, not learned)

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    """Decoder-only transformer with U-Net skip connections, smear gate,
    optional hashed-bigram embeddings, optional layer sharing (virtual depth
    over fewer physical blocks), tied or untied LM head, and logit softcap.
    """
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
        rope_frac: float = 0.25,
        xsa_last_n: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init,
                      rope_frac=rope_frac, layer_idx=i, use_xsa=(i >= n_unique - xsa_last_n))
                for i in range(n_unique)
            ]
        )
        # Mapping from virtual layer index → physical block index
        # Last virtual layer shares with the last physical block
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero-init flagged projections, orthogonal-init other large matrices.

        Matrices flagged `_zero_init` (attn/mlp proj, untied lm_head) start at
        zero; other >=64x64 Linear weights get orthogonal init.
        """
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                    # NOTE(review): placement of this depth-rescale relative to
                    # the elif branch is ambiguous in the garbled source; since
                    # proj layers carry _zero_init and take the zeros branch,
                    # it appears largely inert either way — confirm against the
                    # original file.
                    if ".proj." in name or name.endswith(".proj"):
                        with torch.no_grad():
                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return mean cross-entropy of softcapped logits against target_ids."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # preserved layer-0 stream fed to every block's resid_mix
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                # U-Net: pop the matching encoder activation (LIFO pairing).
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh softcap bounds logits to (-softcap, softcap).
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Same pipeline as forward(), but returns softcapped logits
        [batch, seq, vocab] instead of the loss (used for eval/calibration)."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)


# Sliding-window validation (signature continues in the next chunk).
def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window validation: returns (mean NLL per token, bits per byte).

    Windows of train_seq_len start every `stride` tokens; only the last
    `stride` targets of each window are scored (all targets for the first
    window), so every token is scored exactly once with maximal left context.
    Work is sharded across ranks by window index and all-reduced at the end.
    The *_lut tensors map token ids to byte counts / space flags for the bpb
    denominator.
    """
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    # Keep only windows that contribute >= stride scored tokens (except ws=0).
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # Short tail windows are zero-padded; padding is excluded by wlen.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score only the last `stride` targets, except for the first window.
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                # Bytes per target token, plus one for a leading space that is
                # emitted only when the previous token is not a boundary token.
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
            if rank == 0 and (bi // batch_seqs) % 50 == 0:
                done = min(bi + batch_seqs, len(my_windows))
                pct = done / len(my_windows) * 100
                running_bpb = 0.0
                if token_count.item() > 0:
                    rl = (loss_sum / token_count).item()
                    running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item())
                print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True)

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    # bpb = (bits/token) * (tokens/byte)
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte


# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    """Training entry point (continues beyond this chunk).

    Sets up distributed state, seeds, tokenizer/data, model, and optimizers;
    the remainder of the function (training loop, EMA, quantized export) lies
    past the end of this view.
    """
    global zeropower_via_newtonschulz5

    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # torchrun sets RANK/WORLD_SIZE; absent -> single-process run.
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    # Global batch is fixed at 8 micro-steps; each rank accumulates its share.
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank ==
0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + unique_layers=args.unique_layers, + rope_frac=args.rope_frac, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not 
None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if 
args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = 
{name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Late QAT: apply fake quantization via STE when lr_scale < threshold + qat_active = args.qat_threshold > 0 and scale < args.qat_threshold + if qat_active: + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + cat = _classify_param(name) + clip = 15 if cat == "mlp" else 31 # match int5/int6 + with torch.no_grad(): + t32 = param.float() + row_max = t32.abs().amax(dim=1).clamp_min(1e-12) + s = row_max / clip + q = torch.clamp(torch.round(t32 / s[:, None]), -(clip + 1), clip) + dq = (q * s[:, None]).to(param.dtype) + # STE: replace param data but keep grad flowing + param.data.copy_(dq) + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = 
micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if step < args.muon_momentum_warmup_steps: + # Linear warmup phase + frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + elif args.momentum_cyclic: + # Cyclic momentum: triangle wave between momentum_min and momentum_max + period = args.momentum_cycle_period * 2 + pos = (step % period) / period + if pos < 0.5: + muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) + else: + muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) + else: + muon_momentum = args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # EMA: exponential moving average of weights every step + if args.ema_enabled and ema_state is not None: + decay = args.ema_decay + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + 
f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply EMA weights + if args.ema_enabled and ema_state is not None: + log0(f"ema:applying decay={args.ema_decay}") + current_state = base_model.state_dict() + ema_load = { + name: tensor.to(dtype=current_state[name].dtype) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(ema_load, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # AWQ: Activation-aware weight scaling before quantization + awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) + awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) + if awq_enabled: + log0(f"awq:calibrating alpha={awq_alpha}") + # Step 1: Collect per-channel activation magnitudes via hooks + 
act_stats: dict[str, Tensor] = {} + hooks = [] + def make_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + # Mean absolute activation per channel (last dim) + s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) + if name in act_stats: + act_stats[name] = act_stats[name] + s + else: + act_stats[name] = s + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Step 2: Run calibration batches (use val data, 4 batches) + base_model.eval() + n_calib_batches = 4 + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(n_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in hooks: + h.remove() + + # Normalize activation stats + for name in act_stats: + act_stats[name] = act_stats[name] / n_calib_batches + + # Step 3: Scale weight columns by s^alpha + awq_scales = {} + with torch.no_grad(): + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: + s = act_stats[name].to(module.weight.device) + s = s.clamp_min(1e-6) + scale = s.pow(awq_alpha) + # Scale weight columns (input channels) + if module.weight.shape[1] == scale.shape[0]: + module.weight.data = module.weight.data * scale.unsqueeze(0) + awq_scales[name] = scale.cpu() + # Store inverse scale to apply to inputs at inference + # We'll fold this into the state dict + + log0(f"awq:scaled {len(awq_scales)} layers") + + # GPTQ: collect Hessians 
for second-order quantization + gptq_hessians: dict[str, Tensor] = {} + if args.use_gptq: + log0(f"gptq:collecting Hessians (calib_batches={args.gptq_calib_batches})") + gptq_hooks = [] + gptq_inp_cache: dict[str, list[Tensor]] = {} + + def make_gptq_hook(name): + def hook_fn(module, input, output): + x = input[0] if isinstance(input, tuple) else input + if x.ndim >= 2: + x_2d = x.detach().float().reshape(-1, x.shape[-1]) + if name not in gptq_inp_cache: + gptq_inp_cache[name] = [] + gptq_inp_cache[name].append(x_2d.cpu()) + return hook_fn + + for name, module in base_model.named_modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and module.weight.ndim == 2 and module.weight.numel() > 65536: + gptq_hooks.append(module.register_forward_hook(make_gptq_hook(name))) + + base_model.eval() + calib_seq_len = args.train_seq_len + calib_batch_seqs = 16 + calib_tokens_per_batch = calib_batch_seqs * calib_seq_len + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for cb in range(args.gptq_calib_batches): + start = cb * calib_tokens_per_batch + end = start + calib_tokens_per_batch + 1 + if end > val_tokens.numel(): + break + local = val_tokens[start:end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, calib_seq_len) + base_model.forward_logits(x) + + for h in gptq_hooks: + h.remove() + + # Build Hessians: H = X^T X / n_samples + for name, inp_list in gptq_inp_cache.items(): + X = torch.cat(inp_list, dim=0) + n = X.shape[0] + H = (X.T @ X) / n + gptq_hessians[name] = H + del gptq_inp_cache + log0(f"gptq:collected Hessians for {len(gptq_hessians)} layers") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Store AWQ inverse scales in the state dict for inference compensation + if awq_enabled and awq_scales: + for lname, scale in awq_scales.items(): + sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) + quant_result, 
quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}, gptq_hessians=gptq_hessians, gptq_damp=args.gptq_damp) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + # AWQ: undo column scaling after dequantization + if awq_enabled: + awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] + for inv_key in awq_inv_keys: + inv_scale = deq_state.pop(inv_key).float() + layer_name = inv_key.replace("_awq_inv_scale.", "") + weight_key = layer_name + ".weight" + if weight_key in deq_state: + # Undo: W_orig = W_scaled * inv_scale (per column) + deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) + deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) + log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") + # Remove any remaining AWQ keys before loading + deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window 
eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned From ba67f53e633c88ea259342c216585a63905f2b44 Mon Sep 17 00:00:00 2001 From: SPThole Date: Mon, 13 Apr 2026 23:47:28 +0530 Subject: [PATCH 07/10] updt --- .../Inference.ipynb | 281 --- .../exp00_baseline-rerun_exp27/README.md | 22 - .../exp00_baseline-rerun_exp27/logs.txt | 0 .../phase3/exp00_baseline-rerun_exp27/run.sh | 38 - .../exp00_baseline-rerun_exp27/save_model.py | 52 - .../exp00_baseline-rerun_exp27/train_gpt.py | 1340 -------------- .../Inference.ipynb | 238 --- .../exp01_partial-rope_from-exp27/README.md | 33 - .../exp01_partial-rope_from-exp27/logs.txt | 0 .../exp01_partial-rope_from-exp27/run.sh | 40 - .../save_model.py | 68 - .../train_gpt.py | 1349 -------------- .../Inference.ipynb | 281 --- .../exp01b_ln-scale-only_from-exp27/README.md | 11 - .../exp01b_ln-scale-only_from-exp27/logs.txt | 0 .../exp01b_ln-scale-only_from-exp27/run.sh | 37 - .../save_model.py | 25 
- .../train_gpt.py | 1341 -------------- .../Inference.ipynb | 281 --- .../exp01c_ema-only_from-exp27/README.md | 11 - .../exp01c_ema-only_from-exp27/logs.txt | 0 .../phase3/exp01c_ema-only_from-exp27/run.sh | 37 - .../exp01c_ema-only_from-exp27/save_model.py | 26 - .../exp01c_ema-only_from-exp27/train_gpt.py | 1335 -------------- .../Inference.ipynb | 281 --- .../exp01d_xsa-only_from-exp27/README.md | 11 - .../exp01d_xsa-only_from-exp27/logs.txt | 0 .../phase3/exp01d_xsa-only_from-exp27/run.sh | 38 - .../exp01d_xsa-only_from-exp27/save_model.py | 26 - .../exp01d_xsa-only_from-exp27/train_gpt.py | 1352 --------------- .../exp02_ln-scale_from-exp01/Inference.ipynb | 238 --- .../exp02_ln-scale_from-exp01/README.md | 24 - .../phase3/exp02_ln-scale_from-exp01/logs.txt | 0 .../phase3/exp02_ln-scale_from-exp01/run.sh | 39 - .../exp02_ln-scale_from-exp01/save_model.py | 53 - .../exp02_ln-scale_from-exp01/train_gpt.py | 1350 --------------- .../exp03_ema_from-exp02/Inference.ipynb | 238 --- records/phase3/exp03_ema_from-exp02/README.md | 26 - records/phase3/exp03_ema_from-exp02/logs.txt | 0 records/phase3/exp03_ema_from-exp02/run.sh | 39 - .../phase3/exp03_ema_from-exp02/save_model.py | 54 - .../phase3/exp03_ema_from-exp02/train_gpt.py | 1345 -------------- .../exp04_xsa4_from-exp03/Inference.ipynb | 238 --- .../phase3/exp04_xsa4_from-exp03/README.md | 28 - records/phase3/exp04_xsa4_from-exp03/logs.txt | 0 records/phase3/exp04_xsa4_from-exp03/run.sh | 41 - .../exp04_xsa4_from-exp03/save_model.py | 55 - .../phase3/exp04_xsa4_from-exp03/train_gpt.py | 1362 --------------- .../exp05_late-qat_from-exp04/Inference.ipynb | 238 --- .../exp05_late-qat_from-exp04/README.md | 26 - .../phase3/exp05_late-qat_from-exp04/logs.txt | 0 .../phase3/exp05_late-qat_from-exp04/run.sh | 42 - .../exp05_late-qat_from-exp04/save_model.py | 56 - .../exp05_late-qat_from-exp04/train_gpt.py | 1380 --------------- .../exp06_gptq_from-exp05/Inference.ipynb | 238 --- 
.../phase3/exp06_gptq_from-exp05/README.md | 26 - records/phase3/exp06_gptq_from-exp05/logs.txt | 0 records/phase3/exp06_gptq_from-exp05/run.sh | 45 - .../exp06_gptq_from-exp05/save_model.py | 57 - .../phase3/exp06_gptq_from-exp05/train_gpt.py | 1485 ---------------- .../Inference.ipynb | 238 --- .../exp07_parallel-muon_from-exp06/README.md | 26 - .../exp07_parallel-muon_from-exp06/logs.txt | 0 .../exp07_parallel-muon_from-exp06/run.sh | 45 - .../save_model.py | 57 - .../train_gpt.py | 1542 ----------------- 66 files changed, 19185 deletions(-) delete mode 100644 records/phase3/exp00_baseline-rerun_exp27/Inference.ipynb delete mode 100644 records/phase3/exp00_baseline-rerun_exp27/README.md delete mode 100644 records/phase3/exp00_baseline-rerun_exp27/logs.txt delete mode 100755 records/phase3/exp00_baseline-rerun_exp27/run.sh delete mode 100644 records/phase3/exp00_baseline-rerun_exp27/save_model.py delete mode 100644 records/phase3/exp00_baseline-rerun_exp27/train_gpt.py delete mode 100644 records/phase3/exp01_partial-rope_from-exp27/Inference.ipynb delete mode 100644 records/phase3/exp01_partial-rope_from-exp27/README.md delete mode 100644 records/phase3/exp01_partial-rope_from-exp27/logs.txt delete mode 100755 records/phase3/exp01_partial-rope_from-exp27/run.sh delete mode 100644 records/phase3/exp01_partial-rope_from-exp27/save_model.py delete mode 100644 records/phase3/exp01_partial-rope_from-exp27/train_gpt.py delete mode 100644 records/phase3/exp01b_ln-scale-only_from-exp27/Inference.ipynb delete mode 100644 records/phase3/exp01b_ln-scale-only_from-exp27/README.md delete mode 100644 records/phase3/exp01b_ln-scale-only_from-exp27/logs.txt delete mode 100755 records/phase3/exp01b_ln-scale-only_from-exp27/run.sh delete mode 100644 records/phase3/exp01b_ln-scale-only_from-exp27/save_model.py delete mode 100644 records/phase3/exp01b_ln-scale-only_from-exp27/train_gpt.py delete mode 100644 records/phase3/exp01c_ema-only_from-exp27/Inference.ipynb delete mode 100644 
records/phase3/exp01c_ema-only_from-exp27/README.md delete mode 100644 records/phase3/exp01c_ema-only_from-exp27/logs.txt delete mode 100755 records/phase3/exp01c_ema-only_from-exp27/run.sh delete mode 100644 records/phase3/exp01c_ema-only_from-exp27/save_model.py delete mode 100644 records/phase3/exp01c_ema-only_from-exp27/train_gpt.py delete mode 100644 records/phase3/exp01d_xsa-only_from-exp27/Inference.ipynb delete mode 100644 records/phase3/exp01d_xsa-only_from-exp27/README.md delete mode 100644 records/phase3/exp01d_xsa-only_from-exp27/logs.txt delete mode 100755 records/phase3/exp01d_xsa-only_from-exp27/run.sh delete mode 100644 records/phase3/exp01d_xsa-only_from-exp27/save_model.py delete mode 100644 records/phase3/exp01d_xsa-only_from-exp27/train_gpt.py delete mode 100644 records/phase3/exp02_ln-scale_from-exp01/Inference.ipynb delete mode 100644 records/phase3/exp02_ln-scale_from-exp01/README.md delete mode 100644 records/phase3/exp02_ln-scale_from-exp01/logs.txt delete mode 100755 records/phase3/exp02_ln-scale_from-exp01/run.sh delete mode 100644 records/phase3/exp02_ln-scale_from-exp01/save_model.py delete mode 100644 records/phase3/exp02_ln-scale_from-exp01/train_gpt.py delete mode 100644 records/phase3/exp03_ema_from-exp02/Inference.ipynb delete mode 100644 records/phase3/exp03_ema_from-exp02/README.md delete mode 100644 records/phase3/exp03_ema_from-exp02/logs.txt delete mode 100755 records/phase3/exp03_ema_from-exp02/run.sh delete mode 100644 records/phase3/exp03_ema_from-exp02/save_model.py delete mode 100644 records/phase3/exp03_ema_from-exp02/train_gpt.py delete mode 100644 records/phase3/exp04_xsa4_from-exp03/Inference.ipynb delete mode 100644 records/phase3/exp04_xsa4_from-exp03/README.md delete mode 100644 records/phase3/exp04_xsa4_from-exp03/logs.txt delete mode 100755 records/phase3/exp04_xsa4_from-exp03/run.sh delete mode 100644 records/phase3/exp04_xsa4_from-exp03/save_model.py delete mode 100644 
records/phase3/exp04_xsa4_from-exp03/train_gpt.py delete mode 100644 records/phase3/exp05_late-qat_from-exp04/Inference.ipynb delete mode 100644 records/phase3/exp05_late-qat_from-exp04/README.md delete mode 100644 records/phase3/exp05_late-qat_from-exp04/logs.txt delete mode 100755 records/phase3/exp05_late-qat_from-exp04/run.sh delete mode 100644 records/phase3/exp05_late-qat_from-exp04/save_model.py delete mode 100644 records/phase3/exp05_late-qat_from-exp04/train_gpt.py delete mode 100644 records/phase3/exp06_gptq_from-exp05/Inference.ipynb delete mode 100644 records/phase3/exp06_gptq_from-exp05/README.md delete mode 100644 records/phase3/exp06_gptq_from-exp05/logs.txt delete mode 100755 records/phase3/exp06_gptq_from-exp05/run.sh delete mode 100644 records/phase3/exp06_gptq_from-exp05/save_model.py delete mode 100644 records/phase3/exp06_gptq_from-exp05/train_gpt.py delete mode 100644 records/phase3/exp07_parallel-muon_from-exp06/Inference.ipynb delete mode 100644 records/phase3/exp07_parallel-muon_from-exp06/README.md delete mode 100644 records/phase3/exp07_parallel-muon_from-exp06/logs.txt delete mode 100755 records/phase3/exp07_parallel-muon_from-exp06/run.sh delete mode 100644 records/phase3/exp07_parallel-muon_from-exp06/save_model.py delete mode 100644 records/phase3/exp07_parallel-muon_from-exp06/train_gpt.py diff --git a/records/phase3/exp00_baseline-rerun_exp27/Inference.ipynb b/records/phase3/exp00_baseline-rerun_exp27/Inference.ipynb deleted file mode 100644 index 8a406b185c..0000000000 --- a/records/phase3/exp00_baseline-rerun_exp27/Inference.ipynb +++ /dev/null @@ -1,281 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in 
`data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "import json\n", - "import io\n", - "import math\n", - "import torch\n", - "import torch.nn.functional as F\n", - "import numpy as np\n", - "\n", - "# Config — change these paths as needed\n", - "EXPERIMENT_DIR = \"records/phase2/exp27_leaky-relu-sq_from-exp18\"\n", - "MODEL_PATH = \"final_model.pt\" # or path to saved checkpoint\n", - "TOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\n", - "DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", - "\n", - "print(f\"Device: {DEVICE}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Build model\n", - "model = tg.GPT(\n", - " vocab_size=args.vocab_size,\n", - " num_layers=args.num_layers,\n", - " model_dim=args.model_dim,\n", - " num_heads=args.num_heads,\n", - " num_kv_heads=args.num_kv_heads,\n", - " mlp_mult=args.mlp_mult,\n", - " tie_embeddings=args.tie_embeddings,\n", - " tied_embed_init_std=args.tied_embed_init_std,\n", - " logit_softcap=args.logit_softcap,\n", - " rope_base=args.rope_base,\n", - " qk_gain_init=args.qk_gain_init,\n", - " bigram_vocab_size=args.bigram_vocab_size,\n", - " bigram_dim=args.bigram_dim,\n", - " unique_layers=args.unique_layers,\n", - ")\n", - "\n", - "# Load state dict\n", - "state_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\n", - "model.load_state_dict(state_dict, strict=True)\n", - "model = model.to(DEVICE).eval()\n", - "\n", - "n_params = sum(p.numel() for p in model.parameters())\n", - "print(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. (Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp00_baseline-rerun_exp27/README.md b/records/phase3/exp00_baseline-rerun_exp27/README.md deleted file mode 100644 index b100a72f00..0000000000 --- a/records/phase3/exp00_baseline-rerun_exp27/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# Baseline Rerun (Exp27 on A100) - -## Purpose - -Control experiment. Establishes the exact val_bpb of exp27 on A100 hardware with 2 seeds to measure variance. All phase3 experiments are compared against this. 
- -## Protocol - -Run twice: -```bash -SEED=42 bash run.sh -SEED=1337 bash run.sh -``` - -## Results - -| Seed | val_bpb | train_time | steps | -|------|---------|------------|-------| -| 42 | TBD | TBD | TBD | -| 1337 | TBD | TBD | TBD | -| **mean** | TBD | | | -| **std** | TBD | | | diff --git a/records/phase3/exp00_baseline-rerun_exp27/logs.txt b/records/phase3/exp00_baseline-rerun_exp27/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp00_baseline-rerun_exp27/run.sh b/records/phase3/exp00_baseline-rerun_exp27/run.sh deleted file mode 100755 index 9ce33c0470..0000000000 --- a/records/phase3/exp00_baseline-rerun_exp27/run.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp00: Baseline rerun of Exp27 on A100 -# Control experiment. Run with SEED=42 and SEED=1337 to -# establish variance baseline for all phase3 comparisons. -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp00_baseline-rerun_exp27" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export SWA_START_FRAC=0.2 -export SWA_EVERY=50 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export 
RUN_ID="${EXP_NAME}_seed${SEED}" -echo "=== ${EXP_NAME} seed=${SEED} ===" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs_seed${SEED}.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp00_baseline-rerun_exp27/save_model.py b/records/phase3/exp00_baseline-rerun_exp27/save_model.py deleted file mode 100644 index 0188ff381c..0000000000 --- a/records/phase3/exp00_baseline-rerun_exp27/save_model.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp00_baseline-rerun_exp27.""" - -import argparse -import json -import os -import sys -import shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - import train_gpt as tg - - hp = tg.Hyperparameters() - config = { - "vocab_size": hp.vocab_size, - "num_layers": hp.num_layers, - "model_dim": hp.model_dim, - "num_heads": hp.num_heads, - "num_kv_heads": hp.num_kv_heads, - "mlp_mult": hp.mlp_mult, - "tie_embeddings": hp.tie_embeddings, - "tied_embed_init_std": hp.tied_embed_init_std, - "logit_softcap": hp.logit_softcap, - "rope_base": hp.rope_base, - "qk_gain_init": hp.qk_gain_init, - "bigram_vocab_size": hp.bigram_vocab_size, - "bigram_dim": hp.bigram_dim, - "unique_layers": hp.unique_layers, - "train_seq_len": hp.train_seq_len, - } - - with open(os.path.join(args.output_dir, "config.json"), "w") as f: - json.dump(config, f, indent=2) - - if os.path.exists(args.model_pt): - shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) - - print(f"Checkpoint saved to 
{args.output_dir}/") - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp00_baseline-rerun_exp27/train_gpt.py b/records/phase3/exp00_baseline-rerun_exp27/train_gpt.py deleted file mode 100644 index 1b3a9d0087..0000000000 --- a/records/phase3/exp00_baseline-rerun_exp27/train_gpt.py +++ /dev/null @@ -1,1340 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = 
int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - 
momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = 
group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - 
torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len 
- raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if 
pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." 
not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for 
name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# 
----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - 
def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def 
forward(self, x: Tensor) -> Tensor: - x = F.leaky_relu(self.fc(x), negative_slope=0.5) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, 
dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(n_unique) - ] - ) - # Mapping from virtual 
layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * 
torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, 
device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global 
zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - 
random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = 
list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in 
base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, 
grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = 
torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - 
swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # 
AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if 
module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = 
inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/phase3/exp01_partial-rope_from-exp27/Inference.ipynb b/records/phase3/exp01_partial-rope_from-exp27/Inference.ipynb deleted file mode 100644 index cdb5ae3d87..0000000000 --- a/records/phase3/exp01_partial-rope_from-exp27/Inference.ipynb +++ /dev/null @@ -1,238 +0,0 @@ -{ - 
"cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in `data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\n# Config — change these paths as needed\nEXPERIMENT_DIR = \"records/phase3/exp01_partial-rope_from-exp27\"\nMODEL_PATH = \"final_model.pt\" # or path to saved checkpoint\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. (Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp01_partial-rope_from-exp27/README.md b/records/phase3/exp01_partial-rope_from-exp27/README.md deleted file mode 100644 index edd4cb6204..0000000000 --- a/records/phase3/exp01_partial-rope_from-exp27/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# Partial RoPE (25% of Head Dimensions) - -## Score: val_bpb = TBD - -## Hypothesis - -Applying RoPE to only 16/64 head dimensions allows the remaining 48 dimensions to learn position-independent semantic similarity. Community data shows ~0.005 BPB improvement. Zero additional parameters, zero compute overhead. 
- -## Change from exp27 - -Single architectural change in `CausalSelfAttention`: -- `rope_frac=0.25`: RoPE applied to first 16 dims, remaining 48 pass through unchanged -- `Rotary` module initialized with `dim=16` instead of `dim=64` - -## Architecture - -| Parameter | Value | -|-----------|-------| -| num_layers | 11 (10 unique) | -| model_dim | 512 | -| num_heads | 8, num_kv_heads | 4 | -| head_dim | 64 (16 RoPE + 48 pass-through) | -| mlp_mult | 3.0 (hidden=1536) | -| mlp_activation | LeakyReLU(0.5)² | -| rope_frac | 0.25 | - -## Expected Impact - -~0.005 BPB improvement over exp27 baseline (1.3345 → ~1.330) - -## Results - -TBD — awaiting A100 run. diff --git a/records/phase3/exp01_partial-rope_from-exp27/logs.txt b/records/phase3/exp01_partial-rope_from-exp27/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp01_partial-rope_from-exp27/run.sh b/records/phase3/exp01_partial-rope_from-exp27/run.sh deleted file mode 100755 index 0e3274fbc3..0000000000 --- a/records/phase3/exp01_partial-rope_from-exp27/run.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp01: Partial RoPE (25% of head dims) — from Exp27 -# Apply RoPE to only 16/64 head dimensions. Remaining 48 dims -# attend without position encoding, learning semantic similarity -# independent of distance. Zero extra parameters. 
-# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp01_partial-rope_from-exp27" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export SWA_START_FRAC=0.2 -export SWA_EVERY=50 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export ROPE_FRAC=0.25 -export RUN_ID="${EXP_NAME}" -echo "=== ${EXP_NAME} ===" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp01_partial-rope_from-exp27/save_model.py b/records/phase3/exp01_partial-rope_from-exp27/save_model.py deleted file mode 100644 index a978e39ce5..0000000000 --- a/records/phase3/exp01_partial-rope_from-exp27/save_model.py +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp01_partial-rope_from-exp27. 
- -Usage (after training): - python save_model.py [--model-pt final_model.pt] [--output-dir model_checkpoint] -""" - -import argparse -import json -import os -import sys -import shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - - # Import training script for hyperparameters - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - import train_gpt as tg - - hp = tg.Hyperparameters() - config = { - "vocab_size": hp.vocab_size, - "num_layers": hp.num_layers, - "model_dim": hp.model_dim, - "num_heads": hp.num_heads, - "num_kv_heads": hp.num_kv_heads, - "mlp_mult": hp.mlp_mult, - "tie_embeddings": hp.tie_embeddings, - "tied_embed_init_std": hp.tied_embed_init_std, - "logit_softcap": hp.logit_softcap, - "rope_base": hp.rope_base, - "qk_gain_init": hp.qk_gain_init, - "bigram_vocab_size": hp.bigram_vocab_size, - "bigram_dim": hp.bigram_dim, - "unique_layers": hp.unique_layers, - "rope_frac": hp.rope_frac, - "train_seq_len": hp.train_seq_len, - } - - config_path = os.path.join(args.output_dir, "config.json") - with open(config_path, "w") as f: - json.dump(config, f, indent=2) - print(f"Saved config to {config_path}") - - if os.path.exists(args.model_pt): - dst = os.path.join(args.output_dir, "model.pt") - shutil.copy2(args.model_pt, dst) - print(f"Copied model to {dst}") - else: - print(f"Warning: {args.model_pt} not found. 
Run training first.") - - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - dst = os.path.join(args.output_dir, "model_quant.ptz") - shutil.copy2(quant_path, dst) - print(f"Copied quantized model to {dst}") - - print(f"\nCheckpoint saved to {args.output_dir}/") - - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp01_partial-rope_from-exp27/train_gpt.py b/records/phase3/exp01_partial-rope_from-exp27/train_gpt.py deleted file mode 100644 index cc3524c67e..0000000000 --- a/records/phase3/exp01_partial-rope_from-exp27/train_gpt.py +++ /dev/null @@ -1,1349 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
-""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = 
int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = 
int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - 
dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - 
rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += 
token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 
127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = 
t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: 
int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or 
self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2) # even, at least 2 - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.rope_dims, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = 
self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - # Partial RoPE: apply rotary only to first rope_dims dimensions - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] - k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] - q_rope = apply_rotary_emb(q_rope, cos, sin) - k_rope = apply_rotary_emb(k_rope, cos, sin) - q = torch.cat([q_rope, q_pass], dim=-1) - k = torch.cat([k_rope, k_pass], dim=-1) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = F.leaky_relu(self.fc(x), negative_slope=0.5) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, 
bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - 
qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - rope_frac: float = 0.25, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_frac=rope_frac) - for _ in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, 
"_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - 
logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += 
scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device 
= torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, 
args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - rope_frac=args.rope_frac, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: 
- tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for 
opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - 
stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic 
momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( 
- f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - 
if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": 
quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - 
log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/phase3/exp01b_ln-scale-only_from-exp27/Inference.ipynb b/records/phase3/exp01b_ln-scale-only_from-exp27/Inference.ipynb deleted file mode 100644 index 8a406b185c..0000000000 --- a/records/phase3/exp01b_ln-scale-only_from-exp27/Inference.ipynb +++ /dev/null @@ -1,281 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in `data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "import json\n", - "import io\n", - "import math\n", - "import torch\n", - "import torch.nn.functional as F\n", - "import numpy as np\n", - "\n", - "# Config — change these paths as needed\n", - "EXPERIMENT_DIR = 
\"records/phase2/exp27_leaky-relu-sq_from-exp18\"\n", - "MODEL_PATH = \"final_model.pt\" # or path to saved checkpoint\n", - "TOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\n", - "DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", - "\n", - "print(f\"Device: {DEVICE}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Build model\n", - "model = tg.GPT(\n", - " vocab_size=args.vocab_size,\n", - " num_layers=args.num_layers,\n", - " model_dim=args.model_dim,\n", - " num_heads=args.num_heads,\n", - " num_kv_heads=args.num_kv_heads,\n", - " mlp_mult=args.mlp_mult,\n", - " tie_embeddings=args.tie_embeddings,\n", - " tied_embed_init_std=args.tied_embed_init_std,\n", - " logit_softcap=args.logit_softcap,\n", - " rope_base=args.rope_base,\n", - " qk_gain_init=args.qk_gain_init,\n", - " bigram_vocab_size=args.bigram_vocab_size,\n", - " bigram_dim=args.bigram_dim,\n", - " unique_layers=args.unique_layers,\n", - ")\n", - "\n", - "# Load state dict\n", - "state_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\n", - "model.load_state_dict(state_dict, strict=True)\n", - "model = model.to(DEVICE).eval()\n", - "\n", - "n_params = sum(p.numel() for p in model.parameters())\n", - "print(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. (Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp01b_ln-scale-only_from-exp27/README.md b/records/phase3/exp01b_ln-scale-only_from-exp27/README.md deleted file mode 100644 index ecd371d8b8..0000000000 --- a/records/phase3/exp01b_ln-scale-only_from-exp27/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# LN Scale Only (Ablation) - -## Purpose - -Isolated ablation: measures the marginal effect of LN Scale (`1/√(layer+1)`) from exp27 baseline, without partial RoPE or any other changes. 
- -## Results - -| Seed | val_bpb | delta vs exp00 | -|------|---------|----------------| -| 42 | TBD | TBD | diff --git a/records/phase3/exp01b_ln-scale-only_from-exp27/logs.txt b/records/phase3/exp01b_ln-scale-only_from-exp27/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp01b_ln-scale-only_from-exp27/run.sh b/records/phase3/exp01b_ln-scale-only_from-exp27/run.sh deleted file mode 100755 index ca163a3bd4..0000000000 --- a/records/phase3/exp01b_ln-scale-only_from-exp27/run.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp01b: LN Scale ONLY (ablation) — from Exp27 -# Isolated test of 1/√(layer+1) damping without partial RoPE. -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp01b_ln-scale-only_from-exp27" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export SWA_START_FRAC=0.2 -export SWA_EVERY=50 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export RUN_ID="${EXP_NAME}_seed${SEED}" -echo "=== ${EXP_NAME} seed=${SEED} ===" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs_seed${SEED}.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git 
a/records/phase3/exp01b_ln-scale-only_from-exp27/save_model.py b/records/phase3/exp01b_ln-scale-only_from-exp27/save_model.py deleted file mode 100644 index c3b4fb7561..0000000000 --- a/records/phase3/exp01b_ln-scale-only_from-exp27/save_model.py +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp01b_ln-scale-only_from-exp27.""" -import argparse, json, os, sys, shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - os.makedirs(args.output_dir, exist_ok=True) - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - import train_gpt as tg - hp = tg.Hyperparameters() - config = {k: getattr(hp, k) for k in ["vocab_size","num_layers","model_dim","num_heads","num_kv_heads","mlp_mult","tie_embeddings","tied_embed_init_std","logit_softcap","rope_base","qk_gain_init","bigram_vocab_size","bigram_dim","unique_layers","train_seq_len"]} - with open(os.path.join(args.output_dir, "config.json"), "w") as f: - json.dump(config, f, indent=2) - if os.path.exists(args.model_pt): - shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) - print(f"Checkpoint saved to {args.output_dir}/") - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp01b_ln-scale-only_from-exp27/train_gpt.py b/records/phase3/exp01b_ln-scale-only_from-exp27/train_gpt.py deleted file mode 100644 index ec0da3a940..0000000000 --- a/records/phase3/exp01b_ln-scale-only_from-exp27/train_gpt.py +++ /dev/null @@ -1,1341 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - 
vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = 
int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g 
= g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = 
((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - 
prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else 
torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" 
- continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - 
chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 
- self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k 
= self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = F.leaky_relu(self.fc(x), negative_slope=0.5) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - 
self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, layer_idx: int = 0): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale = 1.0 / math.sqrt(layer_idx + 1) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise 
ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, layer_idx=i) - for i in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if 
".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - 
world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 
0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) 
- - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: 
float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == 
args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + 
(args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - 
if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - 
base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob 
= zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/phase3/exp01c_ema-only_from-exp27/Inference.ipynb b/records/phase3/exp01c_ema-only_from-exp27/Inference.ipynb deleted file mode 100644 index 8a406b185c..0000000000 --- a/records/phase3/exp01c_ema-only_from-exp27/Inference.ipynb +++ /dev/null @@ -1,281 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in `data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "import json\n", - "import io\n", - "import math\n", - "import torch\n", - "import torch.nn.functional as F\n", - "import numpy as np\n", - "\n", - "# Config — change these paths as needed\n", - "EXPERIMENT_DIR = \"records/phase2/exp27_leaky-relu-sq_from-exp18\"\n", - "MODEL_PATH = \"final_model.pt\" # or path to saved checkpoint\n", - "TOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\n", - "DEVICE = \"cuda\" if 
torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", - "\n", - "print(f\"Device: {DEVICE}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Build model\n", - "model = tg.GPT(\n", - " vocab_size=args.vocab_size,\n", - " num_layers=args.num_layers,\n", - " model_dim=args.model_dim,\n", - " num_heads=args.num_heads,\n", - " num_kv_heads=args.num_kv_heads,\n", - " mlp_mult=args.mlp_mult,\n", - " tie_embeddings=args.tie_embeddings,\n", - " tied_embed_init_std=args.tied_embed_init_std,\n", - " logit_softcap=args.logit_softcap,\n", - " rope_base=args.rope_base,\n", - " qk_gain_init=args.qk_gain_init,\n", - " bigram_vocab_size=args.bigram_vocab_size,\n", - " bigram_dim=args.bigram_dim,\n", - " unique_layers=args.unique_layers,\n", - ")\n", - "\n", - "# Load state dict\n", - "state_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\n", - "model.load_state_dict(state_dict, strict=True)\n", - "model = model.to(DEVICE).eval()\n", - "\n", - "n_params = sum(p.numel() for p in model.parameters())\n", - "print(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. (Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp01c_ema-only_from-exp27/README.md b/records/phase3/exp01c_ema-only_from-exp27/README.md deleted file mode 100644 index 47f860aad8..0000000000 --- a/records/phase3/exp01c_ema-only_from-exp27/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# EMA Only (Ablation) - -## Purpose - -Isolated ablation: measures the marginal effect of EMA (decay=0.997) replacing SWA, from exp27 baseline, without partial RoPE or LN Scale. 
- -## Results - -| Seed | val_bpb | delta vs exp00 | -|------|---------|----------------| -| 42 | TBD | TBD | diff --git a/records/phase3/exp01c_ema-only_from-exp27/logs.txt b/records/phase3/exp01c_ema-only_from-exp27/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp01c_ema-only_from-exp27/run.sh b/records/phase3/exp01c_ema-only_from-exp27/run.sh deleted file mode 100755 index 6ed223a78e..0000000000 --- a/records/phase3/exp01c_ema-only_from-exp27/run.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp01c: EMA ONLY (ablation) — from Exp27 -# Isolated test of EMA (decay=0.997) replacing SWA. -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp01c_ema-only_from-exp27" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export EMA_ENABLED=1 -export EMA_DECAY=0.997 -export RUN_ID="${EXP_NAME}_seed${SEED}" -echo "=== ${EXP_NAME} seed=${SEED} ===" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs_seed${SEED}.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git 
a/records/phase3/exp01c_ema-only_from-exp27/save_model.py b/records/phase3/exp01c_ema-only_from-exp27/save_model.py deleted file mode 100644 index f28ff2c12e..0000000000 --- a/records/phase3/exp01c_ema-only_from-exp27/save_model.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp01c_ema-only_from-exp27.""" -import argparse, json, os, sys, shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - os.makedirs(args.output_dir, exist_ok=True) - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - import train_gpt as tg - hp = tg.Hyperparameters() - config = {k: getattr(hp, k) for k in ["vocab_size","num_layers","model_dim","num_heads","num_kv_heads","mlp_mult","tie_embeddings","tied_embed_init_std","logit_softcap","rope_base","qk_gain_init","bigram_vocab_size","bigram_dim","unique_layers","train_seq_len"]} - config["ema_decay"] = hp.ema_decay - with open(os.path.join(args.output_dir, "config.json"), "w") as f: - json.dump(config, f, indent=2) - if os.path.exists(args.model_pt): - shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) - print(f"Checkpoint saved to {args.output_dir}/") - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp01c_ema-only_from-exp27/train_gpt.py b/records/phase3/exp01c_ema-only_from-exp27/train_gpt.py deleted file mode 100644 index 3bf2f75f82..0000000000 --- a/records/phase3/exp01c_ema-only_from-exp27/train_gpt.py +++ /dev/null @@ -1,1335 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - 
vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = 
int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = 
zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - 
if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = 
y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) 
- clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in 
CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - 
self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - 
self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, 
self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = F.leaky_relu(self.fc(x), negative_slope=0.5) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, 
dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = 
tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." 
in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - 
device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 
0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) 
- - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: 
float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - 
torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - 
pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # EMA: exponential moving average of weights every step - if args.ema_enabled and ema_state is not None: - decay = args.ema_decay - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply EMA weights - if args.ema_enabled and ema_state is not None: - log0(f"ema:applying 
decay={args.ema_decay}") - current_state = base_model.state_dict() - ema_load = { - name: tensor.to(dtype=current_state[name].dtype) - for name, tensor in ema_state.items() - } - base_model.load_state_dict(ema_load, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - 
calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - 
f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, 
batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/phase3/exp01d_xsa-only_from-exp27/Inference.ipynb b/records/phase3/exp01d_xsa-only_from-exp27/Inference.ipynb deleted file mode 100644 index 8a406b185c..0000000000 --- a/records/phase3/exp01d_xsa-only_from-exp27/Inference.ipynb +++ /dev/null @@ -1,281 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in `data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "import json\n", - "import io\n", - "import math\n", - "import torch\n", - "import torch.nn.functional as F\n", - "import numpy as np\n", - "\n", - "# Config — change these paths as needed\n", - "EXPERIMENT_DIR = \"records/phase2/exp27_leaky-relu-sq_from-exp18\"\n", - "MODEL_PATH = \"final_model.pt\" # or path to saved checkpoint\n", - "TOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\n", - "DEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", - "\n", - 
"print(f\"Device: {DEVICE}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Build model\n", - "model = tg.GPT(\n", - " vocab_size=args.vocab_size,\n", - " num_layers=args.num_layers,\n", - " model_dim=args.model_dim,\n", - " num_heads=args.num_heads,\n", - " num_kv_heads=args.num_kv_heads,\n", - " mlp_mult=args.mlp_mult,\n", - " tie_embeddings=args.tie_embeddings,\n", - " tied_embed_init_std=args.tied_embed_init_std,\n", - " logit_softcap=args.logit_softcap,\n", - " rope_base=args.rope_base,\n", - " qk_gain_init=args.qk_gain_init,\n", - " bigram_vocab_size=args.bigram_vocab_size,\n", - " bigram_dim=args.bigram_dim,\n", - " unique_layers=args.unique_layers,\n", - ")\n", - "\n", - "# Load state dict\n", - "state_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\n", - "model.load_state_dict(state_dict, strict=True)\n", - "model = model.to(DEVICE).eval()\n", - "\n", - "n_params = sum(p.numel() for p in model.parameters())\n", - "print(f\"Model loaded: {n_params:,} 
parameters on {DEVICE}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. (Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} 
MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n 
if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp01d_xsa-only_from-exp27/README.md b/records/phase3/exp01d_xsa-only_from-exp27/README.md deleted file mode 100644 index 8f028ed8ad..0000000000 --- a/records/phase3/exp01d_xsa-only_from-exp27/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# XSA Only (Ablation) - -## Purpose - -Isolated ablation: measures the marginal effect of XSA on last 4 layers from exp27 baseline, keeping SWA and full RoPE. Critical to test since community says XSA without EMA can hurt (-0.003 BPB). 
- -## Results - -| Seed | val_bpb | delta vs exp00 | -|------|---------|----------------| -| 42 | TBD | TBD | diff --git a/records/phase3/exp01d_xsa-only_from-exp27/logs.txt b/records/phase3/exp01d_xsa-only_from-exp27/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp01d_xsa-only_from-exp27/run.sh b/records/phase3/exp01d_xsa-only_from-exp27/run.sh deleted file mode 100755 index 952bb8b9a8..0000000000 --- a/records/phase3/exp01d_xsa-only_from-exp27/run.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp01d: XSA ONLY (ablation) — from Exp27 -# Isolated test of XSA on last 4 layers, keeping SWA and full RoPE. -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp01d_xsa-only_from-exp27" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export SWA_START_FRAC=0.2 -export SWA_EVERY=50 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export XSA_LAST_N=4 -export RUN_ID="${EXP_NAME}_seed${SEED}" -echo "=== ${EXP_NAME} seed=${SEED} ===" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs_seed${SEED}.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git 
a/records/phase3/exp01d_xsa-only_from-exp27/save_model.py b/records/phase3/exp01d_xsa-only_from-exp27/save_model.py deleted file mode 100644 index 68f98f6849..0000000000 --- a/records/phase3/exp01d_xsa-only_from-exp27/save_model.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp01d_xsa-only_from-exp27.""" -import argparse, json, os, sys, shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - os.makedirs(args.output_dir, exist_ok=True) - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - import train_gpt as tg - hp = tg.Hyperparameters() - config = {k: getattr(hp, k) for k in ["vocab_size","num_layers","model_dim","num_heads","num_kv_heads","mlp_mult","tie_embeddings","tied_embed_init_std","logit_softcap","rope_base","qk_gain_init","bigram_vocab_size","bigram_dim","unique_layers","train_seq_len"]} - config["xsa_last_n"] = hp.xsa_last_n - with open(os.path.join(args.output_dir, "config.json"), "w") as f: - json.dump(config, f, indent=2) - if os.path.exists(args.model_pt): - shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) - print(f"Checkpoint saved to {args.output_dir}/") - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp01d_xsa-only_from-exp27/train_gpt.py b/records/phase3/exp01d_xsa-only_from-exp27/train_gpt.py deleted file mode 100644 index 84aa32d278..0000000000 --- a/records/phase3/exp01d_xsa-only_from-exp27/train_gpt.py +++ /dev/null @@ -1,1352 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - 
xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 
32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - 
buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file 
in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - 
val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - 
if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - 
meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = 
min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - 
self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, use_xsa: bool = False): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - self.use_xsa = use_xsa - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = 
self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - if self.use_xsa: - if self.num_kv_heads != self.num_heads: - rep = self.num_heads // self.num_kv_heads - v_expanded = v.repeat_interleave(rep, dim=1) - else: - v_expanded = v - y = y - v_expanded / seqlen - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = F.leaky_relu(self.fc(x), negative_slope=0.5) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - 
self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, use_xsa: bool = False): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, use_xsa=use_xsa) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: 
bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - use_xsa=(i >= n_unique - xsa_last_n)) - for i in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, 
module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) 
- if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else 
max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - 
if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, 
has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": 
[base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 
- frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = 
bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = 
act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta 
= mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - 
t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/phase3/exp02_ln-scale_from-exp01/Inference.ipynb b/records/phase3/exp02_ln-scale_from-exp01/Inference.ipynb deleted file mode 100644 index 64d060f768..0000000000 --- a/records/phase3/exp02_ln-scale_from-exp01/Inference.ipynb +++ /dev/null @@ -1,238 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in `data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = 
\"records/phase3/exp02_ln-scale_from-exp01\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. (Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp02_ln-scale_from-exp01/README.md b/records/phase3/exp02_ln-scale_from-exp01/README.md deleted file mode 100644 index aa46634959..0000000000 --- a/records/phase3/exp02_ln-scale_from-exp01/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# LN Scale (1/√(layer+1)) Damping - -## Score: val_bpb = TBD - -## Hypothesis - -Post-RMSNorm output scaling by `1/√(layer_idx+1)` damps deeper layer contributions, preventing later layers from overwriting early representations. Community data shows ~0.003 BPB improvement. Zero additional parameters. - -## Changes from exp01 - -- `Block.__init__` now takes `layer_idx`, computes `self.ln_scale = 1/√(layer_idx+1)` -- `Block.forward` multiplies both `attn_scale` and `mlp_scale` residuals by `ln_scale` - -## Architecture - -Inherits from exp01 (Partial RoPE 25%) + adds LN Scale damping. - -## Expected Impact - -~0.003 BPB improvement over exp01. - -## Results - -TBD — awaiting A100 run. 
diff --git a/records/phase3/exp02_ln-scale_from-exp01/logs.txt b/records/phase3/exp02_ln-scale_from-exp01/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp02_ln-scale_from-exp01/run.sh b/records/phase3/exp02_ln-scale_from-exp01/run.sh deleted file mode 100755 index c45fcb6c91..0000000000 --- a/records/phase3/exp02_ln-scale_from-exp01/run.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp02: LN Scale (1/√(layer+1)) — from Exp01 -# Damps deeper layer contributions to prevent overwriting early -# representations. Multiplies attn/mlp residual by 1/√(layer+1). -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp02_ln-scale_from-exp01" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export SWA_START_FRAC=0.2 -export SWA_EVERY=50 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export ROPE_FRAC=0.25 -export RUN_ID="${EXP_NAME}" -echo "=== ${EXP_NAME} ===" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp02_ln-scale_from-exp01/save_model.py 
b/records/phase3/exp02_ln-scale_from-exp01/save_model.py deleted file mode 100644 index 6f5093b292..0000000000 --- a/records/phase3/exp02_ln-scale_from-exp01/save_model.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp02_ln-scale_from-exp01.""" - -import argparse -import json -import os -import sys -import shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - import train_gpt as tg - - hp = tg.Hyperparameters() - config = { - "vocab_size": hp.vocab_size, - "num_layers": hp.num_layers, - "model_dim": hp.model_dim, - "num_heads": hp.num_heads, - "num_kv_heads": hp.num_kv_heads, - "mlp_mult": hp.mlp_mult, - "tie_embeddings": hp.tie_embeddings, - "tied_embed_init_std": hp.tied_embed_init_std, - "logit_softcap": hp.logit_softcap, - "rope_base": hp.rope_base, - "qk_gain_init": hp.qk_gain_init, - "bigram_vocab_size": hp.bigram_vocab_size, - "bigram_dim": hp.bigram_dim, - "unique_layers": hp.unique_layers, - "rope_frac": hp.rope_frac, - "train_seq_len": hp.train_seq_len, - } - - with open(os.path.join(args.output_dir, "config.json"), "w") as f: - json.dump(config, f, indent=2) - - if os.path.exists(args.model_pt): - shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) - - print(f"Checkpoint saved to {args.output_dir}/") - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp02_ln-scale_from-exp01/train_gpt.py b/records/phase3/exp02_ln-scale_from-exp01/train_gpt.py deleted file mode 100644 index 178d3c7b8d..0000000000 --- 
a/records/phase3/exp02_ln-scale_from-exp01/train_gpt.py +++ /dev/null @@ -1,1350 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = 
int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = 
float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, 
dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, 
device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, 
args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 
-INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." 
not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for 
name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# 
----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - 
def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2) # even, at least 2 - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.rope_dims, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - # Partial RoPE: apply rotary only to first rope_dims dimensions - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] - k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] - q_rope = apply_rotary_emb(q_rope, cos, sin) - k_rope = apply_rotary_emb(k_rope, cos, sin) - q = torch.cat([q_rope, q_pass], dim=-1) - k = torch.cat([k_rope, k_pass], dim=-1) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, 
is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = F.leaky_relu(self.fc(x), negative_slope=0.5) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * 
self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale = 1.0 / math.sqrt(layer_idx + 1) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - rope_frac: float = 0.25, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - 
self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_frac=rope_frac, layer_idx=i) - for i in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." 
in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - 
device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 
0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - rope_frac=args.rope_frac, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - 
matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 
else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - 
- step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - 
muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" 
- ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use 
val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = 
zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = 
eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/phase3/exp03_ema_from-exp02/Inference.ipynb b/records/phase3/exp03_ema_from-exp02/Inference.ipynb deleted file mode 100644 index 9d1e51e217..0000000000 --- a/records/phase3/exp03_ema_from-exp02/Inference.ipynb +++ /dev/null @@ -1,238 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in `data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = \"records/phase3/exp03_ema_from-exp02\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: 
{DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. 
(Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp03_ema_from-exp02/README.md b/records/phase3/exp03_ema_from-exp02/README.md deleted file mode 100644 index d2ab07c469..0000000000 --- a/records/phase3/exp03_ema_from-exp02/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# EMA (Exponential Moving Average) Replacing SWA - -## Score: val_bpb = TBD - -## Hypothesis - -EMA with decay=0.997 outperforms SWA by ~0.003 BPB (community verified, 3-seed). EMA updates every step (smoother averaging) vs SWA's periodic snapshots. Critical prerequisite for XSA (exp04). - -## Changes from exp02 - -- Removed SWA hyperparameters (`swa_enabled`, `swa_start_frac`, `swa_every`) -- Added `ema_enabled=True`, `ema_decay=0.997` -- EMA shadow weights updated every training step: `ema = decay * ema + (1-decay) * weights` -- EMA weights loaded before final eval (replaces SWA averaging) - -## Architecture - -Inherits from exp02 (Partial RoPE + LN Scale) + EMA replacing SWA. - -## Expected Impact - -~0.003 BPB improvement over exp02. Also unlocks synergy with XSA. - -## Results - -TBD — awaiting A100 run. 
diff --git a/records/phase3/exp03_ema_from-exp02/logs.txt b/records/phase3/exp03_ema_from-exp02/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp03_ema_from-exp02/run.sh b/records/phase3/exp03_ema_from-exp02/run.sh deleted file mode 100755 index 37aa1a4360..0000000000 --- a/records/phase3/exp03_ema_from-exp02/run.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp03: EMA (decay=0.997) replacing SWA — from Exp02 -# Exponential moving average of weights updated every step. -# Community confirmed EMA > SWA by 0.003 BPB. -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp03_ema_from-exp02" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export ROPE_FRAC=0.25 -export EMA_ENABLED=1 -export EMA_DECAY=0.997 -export RUN_ID="${EXP_NAME}" -echo "=== ${EXP_NAME} ===" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp03_ema_from-exp02/save_model.py b/records/phase3/exp03_ema_from-exp02/save_model.py deleted file mode 100644 index 
b8220e3172..0000000000 --- a/records/phase3/exp03_ema_from-exp02/save_model.py +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp03_ema_from-exp02.""" - -import argparse -import json -import os -import sys -import shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - import train_gpt as tg - - hp = tg.Hyperparameters() - config = { - "vocab_size": hp.vocab_size, - "num_layers": hp.num_layers, - "model_dim": hp.model_dim, - "num_heads": hp.num_heads, - "num_kv_heads": hp.num_kv_heads, - "mlp_mult": hp.mlp_mult, - "tie_embeddings": hp.tie_embeddings, - "tied_embed_init_std": hp.tied_embed_init_std, - "logit_softcap": hp.logit_softcap, - "rope_base": hp.rope_base, - "qk_gain_init": hp.qk_gain_init, - "bigram_vocab_size": hp.bigram_vocab_size, - "bigram_dim": hp.bigram_dim, - "unique_layers": hp.unique_layers, - "rope_frac": hp.rope_frac, - "ema_decay": hp.ema_decay, - "train_seq_len": hp.train_seq_len, - } - - with open(os.path.join(args.output_dir, "config.json"), "w") as f: - json.dump(config, f, indent=2) - - if os.path.exists(args.model_pt): - shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) - - print(f"Checkpoint saved to {args.output_dir}/") - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp03_ema_from-exp02/train_gpt.py b/records/phase3/exp03_ema_from-exp02/train_gpt.py deleted file mode 100644 index ce1c19fea2..0000000000 --- a/records/phase3/exp03_ema_from-exp02/train_gpt.py +++ /dev/null @@ -1,1345 +0,0 @@ -""" 
-The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - 
max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = 
float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - 
state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files 
found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = 
float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = 
t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." 
not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for 
name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# 
----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - 
def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2) # even, at least 2 - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.rope_dims, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - # Partial RoPE: apply rotary only to first rope_dims dimensions - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] - k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] - q_rope = apply_rotary_emb(q_rope, cos, sin) - k_rope = apply_rotary_emb(k_rope, cos, sin) - q = torch.cat([q_rope, q_pass], dim=-1) - k = torch.cat([k_rope, k_pass], dim=-1) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, 
is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = F.leaky_relu(self.fc(x), negative_slope=0.5) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * 
self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale = 1.0 / math.sqrt(layer_idx + 1) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - rope_frac: float = 0.25, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - 
self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_frac=rope_frac, layer_idx=i) - for i in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." 
in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - 
device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 
0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - rope_frac=args.rope_frac, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - 
matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 
else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().cpu().clone() 
for name, t in base_model.state_dict().items()} - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and 
momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # EMA: exponential moving average of weights every step - if args.ema_enabled and ema_state is not None: - decay = args.ema_decay - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply EMA weights - if args.ema_enabled 
and ema_state is not None: - log0(f"ema:applying decay={args.ema_decay}") - current_state = base_model.state_dict() - ema_load = { - name: tensor.to(dtype=current_state[name].dtype) - for name, tensor in ema_state.items() - } - base_model.load_state_dict(ema_load, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = 
args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with 
open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, 
is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/phase3/exp04_xsa4_from-exp03/Inference.ipynb b/records/phase3/exp04_xsa4_from-exp03/Inference.ipynb deleted file mode 100644 index dec1276ff7..0000000000 --- a/records/phase3/exp04_xsa4_from-exp03/Inference.ipynb +++ /dev/null @@ -1,238 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in `data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = \"records/phase3/exp04_xsa4_from-exp03\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. 
Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n xsa_last_n=args.xsa_last_n,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. 
(Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp04_xsa4_from-exp03/README.md b/records/phase3/exp04_xsa4_from-exp03/README.md deleted file mode 100644 index 3d051e28d8..0000000000 --- a/records/phase3/exp04_xsa4_from-exp03/README.md +++ /dev/null @@ -1,28 +0,0 @@ -# Exclusive Self-Attention (XSA) on Last 4 Layers - -## Score: val_bpb = TBD - -## Hypothesis - -XSA removes the self-value bias from attention output by subtracting `v_i / seq_len` from each position. This forces the model to rely on information from other tokens rather than copying its own value. Applied to last 4 of 10 unique layers (not all). Zero additional parameters, slight compute overhead. - -Community shows XSA + EMA is the "prerequisite stack" for all frontier techniques. EMA without XSA loses 0.003 BPB; EMA with XSA gains 0.003 BPB. 
- -## Changes from exp03 - -- `CausalSelfAttention` gains `use_xsa` flag -- When enabled, subtracts `v_expanded / seqlen` after SDPA (removes self-value contribution) -- `Block` passes `use_xsa` to attention -- `GPT` enables XSA on last `xsa_last_n=4` physical blocks - -## Architecture - -Inherits from exp03 (Partial RoPE + LN Scale + EMA) + XSA on last 4 layers. - -## Expected Impact - -~0.01-0.02 BPB improvement over exp03. Also synergizes with EMA. - -## Results - -TBD — awaiting A100 run. diff --git a/records/phase3/exp04_xsa4_from-exp03/logs.txt b/records/phase3/exp04_xsa4_from-exp03/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp04_xsa4_from-exp03/run.sh b/records/phase3/exp04_xsa4_from-exp03/run.sh deleted file mode 100755 index d673b67919..0000000000 --- a/records/phase3/exp04_xsa4_from-exp03/run.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp04: XSA on last 4 layers — from Exp03 -# Exclusive Self-Attention: subtract self-value from attention -# output, forcing reliance on cross-token information. -# Critical synergy with EMA (exp03). 
-# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp04_xsa4_from-exp03" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export ROPE_FRAC=0.25 -export EMA_ENABLED=1 -export EMA_DECAY=0.997 -export XSA_LAST_N=4 -export RUN_ID="${EXP_NAME}" -echo "=== ${EXP_NAME} ===" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp04_xsa4_from-exp03/save_model.py b/records/phase3/exp04_xsa4_from-exp03/save_model.py deleted file mode 100644 index 9d7cd8b817..0000000000 --- a/records/phase3/exp04_xsa4_from-exp03/save_model.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp04_xsa4_from-exp03.""" - -import argparse -import json -import os -import sys -import shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - 
import train_gpt as tg - - hp = tg.Hyperparameters() - config = { - "vocab_size": hp.vocab_size, - "num_layers": hp.num_layers, - "model_dim": hp.model_dim, - "num_heads": hp.num_heads, - "num_kv_heads": hp.num_kv_heads, - "mlp_mult": hp.mlp_mult, - "tie_embeddings": hp.tie_embeddings, - "tied_embed_init_std": hp.tied_embed_init_std, - "logit_softcap": hp.logit_softcap, - "rope_base": hp.rope_base, - "qk_gain_init": hp.qk_gain_init, - "bigram_vocab_size": hp.bigram_vocab_size, - "bigram_dim": hp.bigram_dim, - "unique_layers": hp.unique_layers, - "rope_frac": hp.rope_frac, - "ema_decay": hp.ema_decay, - "xsa_last_n": hp.xsa_last_n, - "train_seq_len": hp.train_seq_len, - } - - with open(os.path.join(args.output_dir, "config.json"), "w") as f: - json.dump(config, f, indent=2) - - if os.path.exists(args.model_pt): - shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) - - print(f"Checkpoint saved to {args.output_dir}/") - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp04_xsa4_from-exp03/train_gpt.py b/records/phase3/exp04_xsa4_from-exp03/train_gpt.py deleted file mode 100644 index ee05fcd60f..0000000000 --- a/records/phase3/exp04_xsa4_from-exp03/train_gpt.py +++ /dev/null @@ -1,1362 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
-""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = 
int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) - ema_decay = 
float(os.environ.get("EMA_DECAY", 0.997)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - 
dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - 
rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += 
token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 
127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = 
t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: 
int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or 
self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, use_xsa: bool = False): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - self.use_xsa = use_xsa - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2) # even, at least 2 - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.rope_dims, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, 
self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - # Partial RoPE: apply rotary only to first rope_dims dimensions - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] - k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] - q_rope = apply_rotary_emb(q_rope, cos, sin) - k_rope = apply_rotary_emb(k_rope, cos, sin) - q = torch.cat([q_rope, q_pass], dim=-1) - k = torch.cat([k_rope, k_pass], dim=-1) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - # XSA: subtract self-value to force reliance on cross-token information - if self.use_xsa: - # Expand v for GQA if needed - if self.num_kv_heads != self.num_heads: - rep = self.num_heads // self.num_kv_heads - v_expanded = v.repeat_interleave(rep, dim=1) - else: - v_expanded = v - # Self-attention weight for diagonal is 1/seq_len in expectation, - # but we subtract the actual self-value contribution: v_i / scale - # Simpler approach: subtract v_i projected orthogonally - y = y - v_expanded / seqlen - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = F.leaky_relu(self.fc(x), negative_slope=0.5) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, 
dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0, use_xsa: bool = False): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac, use_xsa=use_xsa) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale = 1.0 / 
math.sqrt(layer_idx + 1) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - rope_frac: float = 0.25, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - rope_frac=rope_frac, layer_idx=i, use_xsa=(i >= n_unique - xsa_last_n)) - for i in range(n_unique) - ] - ) - # Mapping from virtual 
layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * 
torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, 
device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global 
zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - 
random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - rope_frac=args.rope_frac, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if 
distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - 
optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = 
train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - 
f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # EMA: exponential moving average of weights every step - if args.ema_enabled and ema_state is not None: - decay = args.ema_decay - for name, t in base_model.state_dict().items(): - 
ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply EMA weights - if args.ema_enabled and ema_state is not None: - log0(f"ema:applying decay={args.ema_decay}") - current_state = base_model.state_dict() - ema_load = { - name: tensor.to(dtype=current_state[name].dtype) - for name, tensor in ema_state.items() - } - base_model.load_state_dict(ema_load, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - 
# AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if 
module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = 
inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/phase3/exp05_late-qat_from-exp04/Inference.ipynb b/records/phase3/exp05_late-qat_from-exp04/Inference.ipynb deleted file mode 100644 index badc853e52..0000000000 --- a/records/phase3/exp05_late-qat_from-exp04/Inference.ipynb +++ /dev/null @@ -1,238 +0,0 @@ -{ - "cells": [ - { - 
"cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in `data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = \"records/phase3/exp05_late-qat_from-exp04\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n xsa_last_n=args.xsa_last_n,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. (Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp05_late-qat_from-exp04/README.md b/records/phase3/exp05_late-qat_from-exp04/README.md deleted file mode 100644 index c398acb06a..0000000000 --- a/records/phase3/exp05_late-qat_from-exp04/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# Late Quantization-Aware Training (QAT) - -## Score: val_bpb = TBD - -## Hypothesis - -Applying fake quantization (quantize → dequantize with STE) only when `lr_scale < 0.1` (final ~4% of training) lets the model learn robust quantized configurations without corrupting early convergence. Community shows this closes the quantization gap from ~0.023 to ~0.007 BPB. - -## Changes from exp04 - -- Added `qat_threshold=0.1` hyperparameter -- Before each forward pass during warmdown (when `lr_scale < threshold`): fake-quantize all 2D weight matrices using the same int5/int6 clip ranges as post-training quantization -- Uses `_classify_param` to match int5 (MLP) vs int6 (attn) clip ranges -- STE: gradient flows through the quantized weights unchanged - -## Architecture - -Inherits from exp04 (Partial RoPE + LN Scale + EMA + XSA) + Late QAT. 
- -## Expected Impact - -Reduce quantization penalty from ~0.02 to ~0.007 BPB. - -## Results - -TBD — awaiting A100 run. diff --git a/records/phase3/exp05_late-qat_from-exp04/logs.txt b/records/phase3/exp05_late-qat_from-exp04/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp05_late-qat_from-exp04/run.sh b/records/phase3/exp05_late-qat_from-exp04/run.sh deleted file mode 100755 index cef8bfe0c1..0000000000 --- a/records/phase3/exp05_late-qat_from-exp04/run.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp05: Late QAT (STE when lr_scale < 0.1) — from Exp04 -# Quantization-Aware Training via Straight-Through Estimator. -# Activates only in final ~4% of training (warmdown tail). -# Closes quantization gap from ~0.023 to ~0.007 BPB. -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp05_late-qat_from-exp04" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export ROPE_FRAC=0.25 -export EMA_ENABLED=1 -export EMA_DECAY=0.997 -export XSA_LAST_N=4 -export QAT_THRESHOLD=0.1 -export RUN_ID="${EXP_NAME}" -echo "=== ${EXP_NAME} 
===" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp05_late-qat_from-exp04/save_model.py b/records/phase3/exp05_late-qat_from-exp04/save_model.py deleted file mode 100644 index d82d838f9e..0000000000 --- a/records/phase3/exp05_late-qat_from-exp04/save_model.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp05_late-qat_from-exp04.""" - -import argparse -import json -import os -import sys -import shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - import train_gpt as tg - - hp = tg.Hyperparameters() - config = { - "vocab_size": hp.vocab_size, - "num_layers": hp.num_layers, - "model_dim": hp.model_dim, - "num_heads": hp.num_heads, - "num_kv_heads": hp.num_kv_heads, - "mlp_mult": hp.mlp_mult, - "tie_embeddings": hp.tie_embeddings, - "tied_embed_init_std": hp.tied_embed_init_std, - "logit_softcap": hp.logit_softcap, - "rope_base": hp.rope_base, - "qk_gain_init": hp.qk_gain_init, - "bigram_vocab_size": hp.bigram_vocab_size, - "bigram_dim": hp.bigram_dim, - "unique_layers": hp.unique_layers, - "rope_frac": hp.rope_frac, - "ema_decay": hp.ema_decay, - "xsa_last_n": hp.xsa_last_n, - "qat_threshold": hp.qat_threshold, - "train_seq_len": hp.train_seq_len, - } - - with open(os.path.join(args.output_dir, "config.json"), "w") as f: - json.dump(config, f, indent=2) - - if os.path.exists(args.model_pt): - shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) - - 
print(f"Checkpoint saved to {args.output_dir}/") - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp05_late-qat_from-exp04/train_gpt.py b/records/phase3/exp05_late-qat_from-exp04/train_gpt.py deleted file mode 100644 index 26fdc29926..0000000000 --- a/records/phase3/exp05_late-qat_from-exp04/train_gpt.py +++ /dev/null @@ -1,1380 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = 
int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) - qat_threshold = float(os.environ.get("QAT_THRESHOLD", 0.1)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = 
bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = 
group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = 
True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = 
min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - 
"INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." 
not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for 
name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# 
----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - 
def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, use_xsa: bool = False): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - self.use_xsa = use_xsa - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2) # even, at least 2 - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.rope_dims, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - # Partial RoPE: apply rotary only to first rope_dims dimensions - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] - k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] - q_rope = apply_rotary_emb(q_rope, cos, sin) - k_rope = apply_rotary_emb(k_rope, cos, sin) - q = torch.cat([q_rope, q_pass], dim=-1) - k = torch.cat([k_rope, k_pass], dim=-1) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = 
F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - # XSA: subtract self-value to force reliance on cross-token information - if self.use_xsa: - # Expand v for GQA if needed - if self.num_kv_heads != self.num_heads: - rep = self.num_heads // self.num_kv_heads - v_expanded = v.repeat_interleave(rep, dim=1) - else: - v_expanded = v - # Self-attention weight for diagonal is 1/seq_len in expectation, - # but we subtract the actual self-value contribution: v_i / scale - # Simpler approach: subtract v_i projected orthogonally - y = y - v_expanded / seqlen - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = F.leaky_relu(self.fc(x), negative_slope=0.5) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - 
if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0, use_xsa: bool = False): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac, use_xsa=use_xsa) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale = 1.0 / math.sqrt(layer_idx + 1) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: 
float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - rope_frac: float = 0.25, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - rope_frac=rope_frac, layer_idx=i, use_xsa=(i >= n_unique - xsa_last_n)) - for i in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if 
isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - 
logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - 
scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not 
torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, 
has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - rope_frac=args.rope_frac, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = 
[{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - 
train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - - # Late QAT: apply fake quantization via STE when lr_scale < threshold - qat_active = args.qat_threshold > 0 and scale < args.qat_threshold - if qat_active: - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - cat = _classify_param(name) - clip = 15 if cat == "mlp" else 31 # match int5/int6 - with torch.no_grad(): - t32 = param.float() - row_max = t32.abs().amax(dim=1).clamp_min(1e-12) - s = row_max / clip - q = torch.clamp(torch.round(t32 / s[:, None]), -(clip + 1), clip) - dq = (q * s[:, None]).to(param.dtype) - # 
STE: replace param data but keep grad flowing - param.data.copy_(dq) - - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # EMA: exponential moving average of weights every step - if args.ema_enabled and ema_state is not None: - decay = args.ema_decay - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) - - should_log_train = ( - 
args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply EMA weights - if args.ema_enabled and ema_state is not None: - log0(f"ema:applying decay={args.ema_decay}") - current_state = base_model.state_dict() - ema_load = { - name: tensor.to(dtype=current_state[name].dtype) - for name, tensor in ema_state.items() - } - base_model.load_state_dict(ema_load, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = 
float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = 
module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + 
".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/phase3/exp06_gptq_from-exp05/Inference.ipynb b/records/phase3/exp06_gptq_from-exp05/Inference.ipynb deleted file mode 100644 index 7874850879..0000000000 --- a/records/phase3/exp06_gptq_from-exp05/Inference.ipynb +++ /dev/null @@ -1,238 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — 
Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in `data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = \"records/phase3/exp06_gptq_from-exp05\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n xsa_last_n=args.xsa_last_n,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. (Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp06_gptq_from-exp05/README.md b/records/phase3/exp06_gptq_from-exp05/README.md deleted file mode 100644 index 55ca35476f..0000000000 --- a/records/phase3/exp06_gptq_from-exp05/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# GPTQ Second-Order Quantization - -## Score: val_bpb = TBD - -## Hypothesis - -GPTQ uses Hessian information (X^T X from calibration data) to minimize quantization error via column-by-column second-order optimization. Community shows full GPTQ at 1.1154 vs lite at 1.1228 — ~0.007 BPB improvement over naive quantization. - -## Changes from exp05 - -- Added `gptq_quantize_layer()`: Cholesky-based GPTQ with per-row scaling -- Hessian collection via forward hooks on calibration data (8 batches) -- `mixed_quantize_int6` now accepts `gptq_hessians` dict — uses GPTQ when Hessian available, falls back to naive otherwise -- AWQ + GPTQ stacked (AWQ scales columns before GPTQ quantizes) - -## Architecture - -Inherits from exp05 (Partial RoPE + LN Scale + EMA + XSA + Late QAT) + GPTQ post-training. - -## Expected Impact - -~0.007 BPB improvement in quantized model quality. 
- -## Results - -TBD — awaiting A100 run. diff --git a/records/phase3/exp06_gptq_from-exp05/logs.txt b/records/phase3/exp06_gptq_from-exp05/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp06_gptq_from-exp05/run.sh b/records/phase3/exp06_gptq_from-exp05/run.sh deleted file mode 100755 index 8c7537b451..0000000000 --- a/records/phase3/exp06_gptq_from-exp05/run.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp06: GPTQ second-order quantization — from Exp05 -# Replaces naive per-row quantization with Hessian-aware GPTQ. -# Uses calibration data to minimize quantization error via -# second-order column-by-column optimization. -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp06_gptq_from-exp05" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export ROPE_FRAC=0.25 -export EMA_ENABLED=1 -export EMA_DECAY=0.997 -export XSA_LAST_N=4 -export QAT_THRESHOLD=0.1 -export USE_GPTQ=1 -export GPTQ_DAMP=0.01 -export GPTQ_CALIB_BATCHES=8 -export RUN_ID="${EXP_NAME}" -echo "=== ${EXP_NAME} ===" -python3 "${SCRIPT_DIR}/train_gpt.py" 
2>&1 | tee "${SCRIPT_DIR}/logs.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp06_gptq_from-exp05/save_model.py b/records/phase3/exp06_gptq_from-exp05/save_model.py deleted file mode 100644 index e5abe53bfe..0000000000 --- a/records/phase3/exp06_gptq_from-exp05/save_model.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp06_gptq_from-exp05.""" - -import argparse -import json -import os -import sys -import shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - import train_gpt as tg - - hp = tg.Hyperparameters() - config = { - "vocab_size": hp.vocab_size, - "num_layers": hp.num_layers, - "model_dim": hp.model_dim, - "num_heads": hp.num_heads, - "num_kv_heads": hp.num_kv_heads, - "mlp_mult": hp.mlp_mult, - "tie_embeddings": hp.tie_embeddings, - "tied_embed_init_std": hp.tied_embed_init_std, - "logit_softcap": hp.logit_softcap, - "rope_base": hp.rope_base, - "qk_gain_init": hp.qk_gain_init, - "bigram_vocab_size": hp.bigram_vocab_size, - "bigram_dim": hp.bigram_dim, - "unique_layers": hp.unique_layers, - "rope_frac": hp.rope_frac, - "ema_decay": hp.ema_decay, - "xsa_last_n": hp.xsa_last_n, - "qat_threshold": hp.qat_threshold, - "use_gptq": hp.use_gptq, - "train_seq_len": hp.train_seq_len, - } - - with open(os.path.join(args.output_dir, "config.json"), "w") as f: - json.dump(config, f, indent=2) - - if os.path.exists(args.model_pt): - shutil.copy2(args.model_pt, os.path.join(args.output_dir, "model.pt")) - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) - - print(f"Checkpoint saved to 
{args.output_dir}/") - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp06_gptq_from-exp05/train_gpt.py b/records/phase3/exp06_gptq_from-exp05/train_gpt.py deleted file mode 100644 index bcb506608a..0000000000 --- a/records/phase3/exp06_gptq_from-exp05/train_gpt.py +++ /dev/null @@ -1,1485 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) 
- - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) - qat_threshold = float(os.environ.get("QAT_THRESHOLD", 0.1)) - use_gptq = bool(int(os.environ.get("USE_GPTQ", "1"))) - gptq_damp = float(os.environ.get("GPTQ_DAMP", 0.01)) - gptq_calib_batches = int(os.environ.get("GPTQ_CALIB_BATCHES", 8)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 
0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if 
distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - 
continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with 
torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", 
"tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." 
not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - - -def gptq_quantize_layer(W: Tensor, H: Tensor, clip_range: int = 31, damp: float = 0.01) -> tuple[Tensor, Tensor]: - """GPTQ: second-order weight quantization using Hessian information. - W: [out_features, in_features] weight matrix - H: [in_features, in_features] Hessian (X^T X from calibration) - Returns (q_int8, scale_fp16) same as quantize_intN_per_row. 
- """ - W = W.float() - out_dim, in_dim = W.shape - # Per-row scale (same as naive) - row_max = W.abs().amax(dim=1).clamp_min(1e-12) - scale = (row_max / clip_range).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - scale_f = scale.float() - - # Dampen Hessian diagonal for numerical stability - diag = torch.diag(H) - H = H + damp * diag.mean() * torch.eye(in_dim, device=H.device, dtype=H.dtype) - - # Cholesky decomposition - try: - L = torch.linalg.cholesky(H) - Linv = torch.linalg.inv(L) - Hinv = Linv.T @ Linv - except RuntimeError: - # Fallback to naive if Cholesky fails - return quantize_intN_per_row(W, clip_range) - - Q = torch.zeros_like(W) - E = torch.zeros_like(W) - W_copy = W.clone() - - # Process columns sequentially (GPTQ algorithm) - for j in range(in_dim): - w_col = W_copy[:, j] - d = Hinv[j, j].clamp_min(1e-12) - q_col = torch.clamp(torch.round(w_col / scale_f), -(clip_range + 1), clip_range) - Q[:, j] = q_col - err = (w_col - q_col * scale_f) / d - E[:, j] = err - # Update remaining columns - if j + 1 < in_dim: - W_copy[:, j + 1:] -= err[:, None] * Hinv[j, j + 1:][None, :] - - return Q.to(torch.int8), scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], - gptq_hessians: dict[str, Tensor] | None = None, gptq_damp: float = 0.01): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - # Build a mapping from state_dict weight keys to module names for GPTQ lookup - # e.g. 
"blocks.0.attn.c_q.weight" -> "blocks.0.attn.c_q" - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - # Try GPTQ if Hessian available - module_name = name.replace(".weight", "") if name.endswith(".weight") else None - if gptq_hessians and module_name and module_name in gptq_hessians and t.ndim == 2: - H = gptq_hessians[module_name] - q, s = gptq_quantize_layer(t, H, clip_range=clip, damp=gptq_damp) - else: - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - 
else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else 
None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, use_xsa: bool = False): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - 
self.use_xsa = use_xsa - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2) # even, at least 2 - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.rope_dims, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - # Partial RoPE: apply rotary only to first rope_dims dimensions - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] - k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] - q_rope = apply_rotary_emb(q_rope, cos, sin) - k_rope = apply_rotary_emb(k_rope, cos, sin) - q = torch.cat([q_rope, q_pass], dim=-1) - k = torch.cat([k_rope, k_pass], dim=-1) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - # XSA: subtract self-value to force reliance on cross-token information - if self.use_xsa: - # Expand v for GQA if needed - if self.num_kv_heads != self.num_heads: - rep = self.num_heads // self.num_kv_heads - v_expanded = v.repeat_interleave(rep, dim=1) - else: - v_expanded = v - # Self-attention weight for diagonal is 1/seq_len in expectation, - # 
but we subtract the actual self-value contribution: v_i / scale - # Simpler approach: subtract v_i projected orthogonally - y = y - v_expanded / seqlen - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = F.leaky_relu(self.fc(x), negative_slope=0.5) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = 
self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0, use_xsa: bool = False): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac, use_xsa=use_xsa) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale = 1.0 / math.sqrt(layer_idx + 1) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - rope_frac: float = 0.25, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - 
self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - rope_frac=rope_frac, layer_idx=i, use_xsa=(i >= n_unique - xsa_last_n)) - for i in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." 
in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - 
device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 
0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - rope_frac=args.rope_frac, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not 
None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if 
args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = 
{name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - - # Late QAT: apply fake quantization via STE when lr_scale < threshold - qat_active = args.qat_threshold > 0 and scale < args.qat_threshold - if qat_active: - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - cat = _classify_param(name) - clip = 15 if cat == "mlp" else 31 # match int5/int6 - with torch.no_grad(): - t32 = param.float() - row_max = t32.abs().amax(dim=1).clamp_min(1e-12) - s = row_max / clip - q = torch.clamp(torch.round(t32 / s[:, None]), -(clip + 1), clip) - dq = (q * s[:, None]).to(param.dtype) - # STE: replace param data but keep grad flowing - param.data.copy_(dq) - - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = 
micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # EMA: exponential moving average of weights every step - if args.ema_enabled and ema_state is not None: - decay = args.ema_decay - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - 
f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply EMA weights - if args.ema_enabled and ema_state is not None: - log0(f"ema:applying decay={args.ema_decay}") - current_state = base_model.state_dict() - ema_load = { - name: tensor.to(dtype=current_state[name].dtype) - for name, tensor in ema_state.items() - } - base_model.load_state_dict(ema_load, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - 
act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # GPTQ: collect Hessians 
for second-order quantization - gptq_hessians: dict[str, Tensor] = {} - if args.use_gptq: - log0(f"gptq:collecting Hessians (calib_batches={args.gptq_calib_batches})") - gptq_hooks = [] - gptq_inp_cache: dict[str, list[Tensor]] = {} - - def make_gptq_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - x_2d = x.detach().float().reshape(-1, x.shape[-1]) - if name not in gptq_inp_cache: - gptq_inp_cache[name] = [] - gptq_inp_cache[name].append(x_2d.cpu()) - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and module.weight.ndim == 2 and module.weight.numel() > 65536: - gptq_hooks.append(module.register_forward_hook(make_gptq_hook(name))) - - base_model.eval() - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(args.gptq_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in gptq_hooks: - h.remove() - - # Build Hessians: H = X^T X / n_samples - for name, inp_list in gptq_inp_cache.items(): - X = torch.cat(inp_list, dim=0) - n = X.shape[0] - H = (X.T @ X) / n - gptq_hessians[name] = H - del gptq_inp_cache - log0(f"gptq:collected Hessians for {len(gptq_hessians)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, 
quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}, gptq_hessians=gptq_hessians, gptq_damp=args.gptq_damp) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window 
eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/phase3/exp07_parallel-muon_from-exp06/Inference.ipynb b/records/phase3/exp07_parallel-muon_from-exp06/Inference.ipynb deleted file mode 100644 index 749d7a1c22..0000000000 --- a/records/phase3/exp07_parallel-muon_from-exp06/Inference.ipynb +++ /dev/null @@ -1,238 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Parameter Golf — Model Inference\n", - "\n", - "Load a trained model checkpoint and generate text.\n", - "\n", - "**Prerequisites**:\n", - "- A trained model (run `train_gpt.py` first)\n", - "- The experiment's `train_gpt.py` (for model class definitions)\n", - "- Tokenizer files in `data/tokenizers/`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "import sys\nimport os\nimport json\nimport io\nimport math\nimport torch\nimport 
torch.nn.functional as F\nimport numpy as np\n\nEXPERIMENT_DIR = \"records/phase3/exp07_parallel-muon_from-exp06\"\nMODEL_PATH = \"final_model.pt\"\nTOKENIZER_PATH = \"data/tokenizers/fineweb_1024_bpe.model\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\nprint(f\"Device: {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Import model classes from training script" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the training script to get model architecture\n", - "sys.path.insert(0, EXPERIMENT_DIR)\n", - "import train_gpt as tg\n", - "\n", - "# Load hyperparameters\n", - "args = tg.Hyperparameters()\n", - "print(f\"Model config:\")\n", - "print(f\" vocab_size: {args.vocab_size}\")\n", - "print(f\" num_layers: {args.num_layers} (unique: {args.unique_layers})\")\n", - "print(f\" model_dim: {args.model_dim}\")\n", - "print(f\" num_heads: {args.num_heads}, num_kv_heads: {args.num_kv_heads}\")\n", - "print(f\" mlp_mult: {args.mlp_mult}\")\n", - "print(f\" seq_len: {args.train_seq_len}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Build model and load weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "# Build model\nmodel = tg.GPT(\n vocab_size=args.vocab_size,\n num_layers=args.num_layers,\n model_dim=args.model_dim,\n num_heads=args.num_heads,\n num_kv_heads=args.num_kv_heads,\n mlp_mult=args.mlp_mult,\n tie_embeddings=args.tie_embeddings,\n tied_embed_init_std=args.tied_embed_init_std,\n logit_softcap=args.logit_softcap,\n rope_base=args.rope_base,\n qk_gain_init=args.qk_gain_init,\n bigram_vocab_size=args.bigram_vocab_size,\n bigram_dim=args.bigram_dim,\n unique_layers=args.unique_layers,\n rope_frac=args.rope_frac,\n xsa_last_n=args.xsa_last_n,\n)\n\n# Load state dict\nstate_dict = torch.load(MODEL_PATH, map_location=\"cpu\")\nmodel.load_state_dict(state_dict, strict=True)\nmodel = model.to(DEVICE).eval()\n\nn_params = sum(p.numel() for p in model.parameters())\nprint(f\"Model loaded: {n_params:,} parameters on {DEVICE}\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2b. (Alternative) Load from quantized artifact\n", - "\n", - "Use this if you only have the quantized `.ptz` file (the submission artifact)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "## Load from quantized artifact (.ptz)\n## Make sure you've run cell 2 (model build) first — we need the model structure as template\n\nimport zstandard # pip install zstandard\n\nQUANT_PATH = \"final_model.int8.ptz\"\n\nwith open(QUANT_PATH, \"rb\") as f:\n quant_blob = f.read()\n\n# Decompress\ndecompressed = zstandard.ZstdDecompressor().decompress(quant_blob)\nquant_state = torch.load(io.BytesIO(decompressed), map_location=\"cpu\", weights_only=False)\n\n# Get a template state dict for shape/dtype info\ntemplate_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}\n\n# Dequantize\ndeq_state = tg.dequantize_mixed_int6(quant_state[\"w\"], quant_state[\"m\"], template_sd)\n\n# Handle AWQ inverse scales — undo the column scaling applied before quantization\nawq_inv_keys = [k for k in deq_state if k.startswith(\"_awq_inv_scale.\")]\nprint(f\"AWQ inverse scale keys found: {len(awq_inv_keys)}\")\nfor inv_key in awq_inv_keys:\n inv_scale = deq_state.pop(inv_key).float()\n layer_name = inv_key.replace(\"_awq_inv_scale.\", \"\")\n weight_key = layer_name + \".weight\"\n if weight_key in deq_state:\n deq_state[weight_key] = (deq_state[weight_key].float() * inv_scale.unsqueeze(0)).to(template_sd[weight_key].dtype)\n print(f\" unscaled {weight_key}\")\n\n# Remove any remaining AWQ metadata keys\ndeq_state = {k: v for k, v in deq_state.items() if not k.startswith(\"_awq_\")}\n\n# Load into model\nmodel.load_state_dict(deq_state, strict=True)\nmodel = model.to(DEVICE).eval()\nprint(f\"\\nLoaded quantized model from {QUANT_PATH}\")\nprint(f\"Artifact size: {len(quant_blob):,} bytes ({len(quant_blob)/1024/1024:.2f} MB)\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sentencepiece as spm\n", - "\n", - "sp = spm.SentencePieceProcessor()\n", - "sp.Load(TOKENIZER_PATH)\n", - "\n", - "print(f\"Tokenizer loaded: vocab_size={sp.GetPieceSize()}\")\n", - "print(f\"Example: 'Hello world' -> {sp.Encode('Hello world')}\")\n", - "print(f\"Decoded back: '{sp.Decode(sp.Encode('Hello world'))}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Generation utilities" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "@torch.no_grad()\ndef generate(\n model,\n tokenizer,\n prompt: str,\n max_new_tokens: int = 200,\n temperature: float = 0.8,\n top_k: int = 50,\n top_p: float = 0.9,\n device: str = \"cuda\",\n seq_len: int = 2048,\n) -> str:\n \"\"\"Generate text from a prompt using the trained model.\"\"\"\n model.eval()\n \n # Tokenize prompt\n tokens = tokenizer.Encode(prompt)\n input_ids = torch.tensor([tokens], dtype=torch.long, device=device)\n \n generated = list(tokens)\n \n for _ in range(max_new_tokens):\n # Truncate to seq_len if needed\n if input_ids.shape[1] > seq_len:\n input_ids = input_ids[:, -seq_len:]\n \n # Forward pass — forward_logits already applies softcap\n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids) # [1, seq_len, vocab]\n \n # Get logits for last position (softcap already applied)\n next_logits = logits[0, -1, :].float()\n \n # Temperature\n if temperature > 0:\n next_logits = next_logits / temperature\n \n # Top-k filtering\n if top_k > 0:\n top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))\n next_logits[next_logits < top_k_vals[-1]] = float('-inf')\n \n # Top-p (nucleus) filtering\n if top_p < 1.0:\n sorted_logits, sorted_indices = torch.sort(next_logits, 
descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n sorted_indices_to_remove = cumulative_probs > top_p\n sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()\n sorted_indices_to_remove[0] = False\n indices_to_remove = sorted_indices[sorted_indices_to_remove]\n next_logits[indices_to_remove] = float('-inf')\n \n # Sample\n probs = F.softmax(next_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1).item()\n \n generated.append(next_token)\n input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)\n \n return tokenizer.Decode(generated)\n\n\n@torch.no_grad()\ndef compute_perplexity(model, tokenizer, text: str, device: str = \"cuda\", seq_len: int = 2048) -> float:\n \"\"\"Compute perplexity of the model on a given text.\"\"\"\n model.eval()\n tokens = tokenizer.Encode(text)\n if len(tokens) < 2:\n return float('inf')\n \n input_ids = torch.tensor([tokens[:seq_len]], dtype=torch.long, device=device)\n target_ids = torch.tensor([tokens[1:seq_len+1]], dtype=torch.long, device=device)\n \n # Trim to same length\n min_len = min(input_ids.shape[1], target_ids.shape[1])\n input_ids = input_ids[:, :min_len]\n target_ids = target_ids[:, :min_len]\n \n with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16, enabled=(device != 'cpu')):\n logits = model.forward_logits(input_ids)\n \n loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1))\n return math.exp(loss.item())\n\n\nprint(\"Generation utilities loaded.\")" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate from a prompt\n", - "prompt = \"The history of artificial intelligence began\"\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=200,\n", - " temperature=0.8,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"\\nGenerated:\\n{output}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Try different prompts\n", - "prompts = [\n", - " \"In a small village by the sea, there lived\",\n", - " \"The most important scientific discovery of the 21st century\",\n", - " \"def fibonacci(n):\\n\",\n", - " \"Once upon a time\",\n", - "]\n", - "\n", - "for p in prompts:\n", - " out = generate(model, sp, p, max_new_tokens=100, temperature=0.7, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Prompt: {p}\")\n", - " print(f\"Output: {out}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Compute perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_texts = [\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"Machine learning is a subset of artificial intelligence that focuses on building systems that learn from data.\",\n", - " \"asdf jkl qwerty zxcv random gibberish text that should have high perplexity\",\n", - "]\n", - "\n", - "for text in test_texts:\n", - " ppl = compute_perplexity(model, sp, text, device=DEVICE, seq_len=args.train_seq_len)\n", - " print(f\"Perplexity: {ppl:.2f} | Text: {text[:60]}...\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Interactive generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive — change the prompt and re-run\n", - "prompt = \"Today I learned that\" # <-- Edit this!\n", - "temperature = 0.8 # Lower = more deterministic, higher = more creative\n", - "max_tokens = 150\n", - "\n", - "output = generate(\n", - " model, sp, prompt,\n", - " max_new_tokens=max_tokens,\n", - " temperature=temperature,\n", - " top_k=50,\n", - " top_p=0.9,\n", - " device=DEVICE,\n", - " seq_len=args.train_seq_len,\n", - ")\n", - "print(output)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/records/phase3/exp07_parallel-muon_from-exp06/README.md b/records/phase3/exp07_parallel-muon_from-exp06/README.md deleted file mode 100644 index 95ee7dda54..0000000000 --- a/records/phase3/exp07_parallel-muon_from-exp06/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# Parallel Muon (Batched Newton-Schulz) - -## Score: val_bpb = TBD - -## Hypothesis - -Grouping weight matrices by shape and running batched Newton-Schulz orthogonalization reduces per-step overhead. Community achieves 83ms/step with parameter banking. On 1×GPU, this is a minor speedup; on 8×H100, it's critical (more steps in 10 minutes = lower BPB). - -## Changes from exp06 - -- Added `zeropower_via_newtonschulz5_batched()` for 3D tensor [batch, rows, cols] -- Muon optimizer now groups params by shape, batches same-shape matrices for NS5 -- Momentum applied first to all grads, then batched NS5, then all-reduce -- Correctness: should produce identical val_bpb to exp06 on 1×GPU (same math, different execution order) - -## Architecture - -Inherits from exp06 (all features) + Parallel Muon optimizer. 
- -## Expected Impact - -Same val_bpb as exp06 on 1×GPU. Faster ms/step → more training steps within wallclock on multi-GPU. - -## Results - -TBD — awaiting A100 run. diff --git a/records/phase3/exp07_parallel-muon_from-exp06/logs.txt b/records/phase3/exp07_parallel-muon_from-exp06/logs.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/phase3/exp07_parallel-muon_from-exp06/run.sh b/records/phase3/exp07_parallel-muon_from-exp06/run.sh deleted file mode 100755 index 673f205692..0000000000 --- a/records/phase3/exp07_parallel-muon_from-exp06/run.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -# ============================================================ -# Exp07: Parallel Muon (batched Newton-Schulz) — from Exp06 -# Groups weight matrices by shape for batched NS5 orthogonalization. -# Critical for 8xH100 throughput (83ms/step target). -# On 1xGPU: validates correctness, minor speedup from batching. -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="exp07_parallel-muon_from-exp06" -cd /workspace/parameter-golf -export EVAL_STRIDE=0 -export SEED="${SEED:-42}" -export ITERATIONS="${ITERATIONS:-20000}" -export WARMUP_STEPS="${WARMUP_STEPS:-20}" -export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" -export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" -export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-100}" -export VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-262144}" -export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-25}" -export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MLP_ACTIVATION=leaky_relu_sq -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 -export ROPE_FRAC=0.25 -export EMA_ENABLED=1 -export EMA_DECAY=0.997 -export 
XSA_LAST_N=4 -export QAT_THRESHOLD=0.1 -export USE_GPTQ=1 -export GPTQ_DAMP=0.01 -export GPTQ_CALIB_BATCHES=8 -export RUN_ID="${EXP_NAME}" -echo "=== ${EXP_NAME} ===" -python3 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${SCRIPT_DIR}/logs.txt" -echo "=== ${EXP_NAME} COMPLETE ===" diff --git a/records/phase3/exp07_parallel-muon_from-exp06/save_model.py b/records/phase3/exp07_parallel-muon_from-exp06/save_model.py deleted file mode 100644 index f8b8d261e2..0000000000 --- a/records/phase3/exp07_parallel-muon_from-exp06/save_model.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python3 -"""Save a trained model checkpoint for exp07_parallel-muon_from-exp06.""" - -import argparse -import json -import os -import sys -import shutil - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-pt", type=str, default="final_model.pt") - parser.add_argument("--output-dir", type=str, default="model_checkpoint") - args = parser.parse_args() - - os.makedirs(args.output_dir, exist_ok=True) - - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - import train_gpt as tg - - hp = tg.Hyperparameters() - config = { - "vocab_size": hp.vocab_size, - "num_layers": hp.num_layers, - "model_dim": hp.model_dim, - "num_heads": hp.num_heads, - "num_kv_heads": hp.num_kv_heads, - "mlp_mult": hp.mlp_mult, - "tie_embeddings": hp.tie_embeddings, - "tied_embed_init_std": hp.tied_embed_init_std, - "logit_softcap": hp.logit_softcap, - "rope_base": hp.rope_base, - "qk_gain_init": hp.qk_gain_init, - "bigram_vocab_size": hp.bigram_vocab_size, - "bigram_dim": hp.bigram_dim, - "unique_layers": hp.unique_layers, - "rope_frac": hp.rope_frac, - "ema_decay": hp.ema_decay, - "xsa_last_n": hp.xsa_last_n, - "qat_threshold": hp.qat_threshold, - "use_gptq": hp.use_gptq, - "train_seq_len": hp.train_seq_len, - } - - with open(os.path.join(args.output_dir, "config.json"), "w") as f: - json.dump(config, f, indent=2) - - if os.path.exists(args.model_pt): - shutil.copy2(args.model_pt, 
os.path.join(args.output_dir, "model.pt")) - quant_path = args.model_pt.replace(".pt", ".int8.ptz") - if os.path.exists(quant_path): - shutil.copy2(quant_path, os.path.join(args.output_dir, "model_quant.ptz")) - - print(f"Checkpoint saved to {args.output_dir}/") - -if __name__ == "__main__": - main() diff --git a/records/phase3/exp07_parallel-muon_from-exp06/train_gpt.py b/records/phase3/exp07_parallel-muon_from-exp06/train_gpt.py deleted file mode 100644 index ba77372e88..0000000000 --- a/records/phase3/exp07_parallel-muon_from-exp06/train_gpt.py +++ /dev/null @@ -1,1542 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
-""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - rope_frac = float(os.environ.get("ROPE_FRAC", 0.25)) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) - qat_threshold = float(os.environ.get("QAT_THRESHOLD", 0.1)) - use_gptq = bool(int(os.environ.get("USE_GPTQ", "1"))) - gptq_damp = float(os.environ.get("GPTQ_DAMP", 0.01)) - gptq_calib_batches = 
int(os.environ.get("GPTQ_CALIB_BATCHES", 8)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - 
    # --- trailing Hyperparameters fields (class header is above this chunk) ---
    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240))  # 0 disables the bigram hash embedding
    bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))

    ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1")))
    ema_decay = float(os.environ.get("EMA_DECAY", 0.997))

# -----------------------------
# MUON OPTIMIZER (Parallel / Batched)
# -----------------------------

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Quintic Newton-Schulz iteration on a single 2D matrix (Muon-style).

    Runs in bfloat16; normalizes G, iterates X <- aX + (bA + cA^2)X with
    A = XX^T, and transposes first when rows > cols so the iteration works
    on the smaller Gram matrix. Returns a matrix of G's shape.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps  # eps guards against division by zero for an all-zero G
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


def zeropower_via_newtonschulz5_batched(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Batched Newton-Schulz for 3D tensor [batch, rows, cols]."""
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Normalize each matrix in the batch
    norms = X.flatten(1).norm(dim=1)[:, None, None].clamp_min(eps)
    X = X / norms
    transposed = X.size(1) > X.size(2)
    if transposed:
        X = X.transpose(1, 2)
    for _ in range(steps):
        A = X @ X.transpose(1, 2)
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.transpose(1, 2) if transposed else X


class Muon(torch.optim.Optimizer):
    """Muon: SGD-momentum whose update direction is orthogonalized per matrix.

    2D parameters only. Under torch.distributed, parameters are partitioned
    across ranks round-robin (idx % world_size == rank); each rank
    orthogonalizes its share (batched by shape) and the flattened updates are
    summed with all_reduce so every rank applies the full update set.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay),
        )
        # Group parameters by shape for batched Newton-Schulz
        self._shape_groups: dict[tuple[int, int], list[int]] | None = None

    def _build_shape_groups(self, params: list) -> dict[tuple[int, int], list[int]]:
        # Map (rows, cols) -> indices into `params`, so same-shape matrices can
        # be stacked into one 3D tensor for the batched iteration.
        groups: dict[tuple[int, int], list[int]] = {}
        for i, p in enumerate(params):
            shape = (p.shape[0], p.shape[1])
            groups.setdefault(shape, []).append(i)
        return groups

    @torch.no_grad()
    def step(self, closure=None):
        """Single optimization step; returns closure() loss if one is given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # Build shape groups on first call
            if self._shape_groups is None:
                self._shape_groups = self._build_shape_groups(params)

            # Apply momentum to all grads first
            nesterov_grads: list[Tensor | None] = [None] * len(params)
            for i, p in enumerate(params):
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(g)
                if nesterov:
                    # Nesterov look-ahead: grad + momentum * buffer
                    nesterov_grads[i] = g.add(buf, alpha=momentum)
                else:
                    nesterov_grads[i] = buf.clone()

            # Batched Newton-Schulz per shape group
            updates: list[Tensor | None] = [None] * len(params)
            for shape, indices in self._shape_groups.items():
                # Filter to indices assigned to this rank (for distributed) and with grads
                my_indices = [idx for idx in indices if idx % world_size == rank and nesterov_grads[idx] is not None]
                if not my_indices:
                    continue
                if len(my_indices) > 1:
                    # Batch: stack into 3D and process together
                    batch = torch.stack([nesterov_grads[idx] for idx in my_indices])
                    batch_out = zeropower_via_newtonschulz5_batched(batch, steps=backend_steps)
                    # Scale tall matrices by sqrt(rows/cols), as in standard Muon
                    scale = max(1, shape[0] / shape[1]) ** 0.5
                    for j, idx in enumerate(my_indices):
                        updates[idx] = batch_out[j] * scale
                else:
                    # Single matrix: use original function
                    idx = my_indices[0]
                    g = zeropower_via_newtonschulz5(nesterov_grads[idx], steps=backend_steps)
                    g *= max(1, shape[0] / shape[1]) ** 0.5
                    updates[idx] = g

            # All-reduce updates: each rank filled only its slots, the rest are
            # zero, so a SUM reduce reassembles the complete update vector.
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
            curr = 0
            for i, p in enumerate(params):
                if updates[i] is not None:
                    updates_flat[curr : curr + p.numel()] = updates[i].reshape(-1).bfloat16()
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            wd = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                if wd > 0:
                    p.data.mul_(1.0 - lr * wd)  # decoupled weight decay
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION
# -----------------------------

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables for bytes-per-token accounting.

    Returns (base_bytes, has_leading_space, is_boundary_token) tensors on
    `device`, sized max(sp vocab, model vocab). base_bytes counts the UTF-8
    bytes of each piece (1 for byte tokens, leading "▁" stripped);
    has_leading_space marks pieces beginning with "▁"; boundary tokens are
    control/unknown/unused ids.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("\u2581"):  # SentencePiece word-boundary marker "▁"
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all shards matching `pattern`, trimmed to a whole number of
    seq_len sequences plus one extra token for next-token targets."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Non-overlapping validation pass; returns (mean_nll, bits_per_byte).

    Sequences are sharded contiguously across ranks; loss/token/byte counts
    are all-reduced before the final division. bits_per_byte =
    (nll / ln 2) * (tokens / bytes), with bytes from the SentencePiece LUTs.
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1  # +1 for shifted targets
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            # A leading "▁" counts as one extra byte unless the previous token
            # was a boundary (control/unknown) token.
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


# -----------------------------
# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed)
# -----------------------------

# Small "control" tensors (gates/scales/mixes) are kept in float, never quantized.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale",
    ).split(",")
    if pattern
)
FP16_KEEP_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984  # clip extreme outliers before int8 rounding
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Storage size of `t` in bytes (numel * element_size)."""
    return int(t.numel()) * int(t.element_size())

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Symmetric int8 quantization with percentile clipping.

    2D tensors get a per-row scale (fp16); other ranks get one scalar scale
    (fp32). Returns (int8 values, scale).
    """
    t32 = t.float()
    if t32.ndim == 2:
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale


def _classify_param(name: str) -> str:
    """Bucket a state-dict key into embed / mlp / bigram / attn / other."""
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    if "bigram" in name:
        return "bigram"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"

def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
    """Per-row symmetric intN quantization (default clip_range=31 -> int6:
    values in [-(clip_range+1), clip_range]). Scalar scale for non-2D input.
    Returns (int8-stored values, fp16 scale)."""
    t32 = t.float()
    if t32.ndim == 2:
        row_max = t32.abs().amax(dim=1)
        scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16)
        scale = scale.clamp_min(torch.finfo(torch.float16).tiny)  # avoid fp16 underflow to 0
        q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8)
        return q, scale
    amax = t32.abs().max().item()
    scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8)
    return q, scale


def gptq_quantize_layer(W: Tensor, H: Tensor, clip_range: int = 31, damp: float = 0.01) -> tuple[Tensor, Tensor]:
    """GPTQ: second-order weight quantization using Hessian information.

    W: [out_features, in_features] weight matrix
    H: [in_features, in_features] Hessian (X^T X from calibration)
    Returns (q_int8, scale_fp16) same as quantize_intN_per_row.
    """
    W = W.float()
    out_dim, in_dim = W.shape
    # Per-row scale (same as naive)
    row_max = W.abs().amax(dim=1).clamp_min(1e-12)
    scale = (row_max / clip_range).to(torch.float16)
    scale = scale.clamp_min(torch.finfo(torch.float16).tiny)
    scale_f = scale.float()

    # Dampen Hessian diagonal for numerical stability
    diag = torch.diag(H)
    H = H + damp * diag.mean() * torch.eye(in_dim, device=H.device, dtype=H.dtype)

    # Cholesky decomposition; Hinv = (L L^T)^-1 via L^-1
    try:
        L = torch.linalg.cholesky(H)
        Linv = torch.linalg.inv(L)
        Hinv = Linv.T @ Linv
    except RuntimeError:
        # Fallback to naive if Cholesky fails
        return quantize_intN_per_row(W, clip_range)

    Q = torch.zeros_like(W)
    E = torch.zeros_like(W)
    W_copy = W.clone()

    # Process columns sequentially (GPTQ algorithm): quantize column j, then
    # propagate the scaled error into all not-yet-quantized columns.
    for j in range(in_dim):
        w_col = W_copy[:, j]
        d = Hinv[j, j].clamp_min(1e-12)
        q_col = torch.clamp(torch.round(w_col / scale_f), -(clip_range + 1), clip_range)
        Q[:, j] = q_col
        err = (w_col - q_col * scale_f) / d
        E[:, j] = err
        # Update remaining columns
        if j + 1 < in_dim:
            W_copy[:, j + 1:] -= err[:, None] * Hinv[j, j + 1:][None, :]

    return Q.to(torch.int8), scale

def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str],
                        gptq_hessians: dict[str, Tensor] | None = None, gptq_damp: float = 0.01):
    """Quantize a state dict: categories in `int6_cats` get int6 (int5 for
    MLP), everything else int8; small/control/FP16-keep tensors pass through.
    Returns (packed tensors dict with .q/.scale keys, per-name meta dict)."""
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    # Build a mapping from state_dict weight keys to module names for GPTQ lookup
    # e.g. "blocks.0.attn.c_q.weight" -> "blocks.0.attn.c_q"
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = _classify_param(name)
        if not t.is_floating_point() or t.numel() <= 8192:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue
        if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS):
            result[name] = t.to(dtype=torch.float16).contiguous()
            meta[name] = "passthrough_fp16"
            continue
        if cat in int6_cats and t.ndim >= 1:
            clip = 15 if cat == "mlp" else 31  # int5 for MLP, int6 for attention
            # Try GPTQ if Hessian available
            module_name = name.replace(".weight", "") if name.endswith(".weight") else None
            if gptq_hessians and module_name and module_name in gptq_hessians and t.ndim == 2:
                H = gptq_hessians[module_name]
                q, s = gptq_quantize_layer(t, H, clip_range=clip, damp=gptq_damp)
            else:
                q, s = quantize_intN_per_row(t, clip_range=clip)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"}
        else:
            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}
    return result, meta

def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """Inverse of mixed_quantize_int6: rebuild a state dict with the dtypes of
    `template_sd`, expanding .q/.scale pairs back to dense tensors."""
    out: dict[str, Tensor] = {}
    for name, orig in template_sd.items():
        info = meta[name]
        orig_dtype = orig.dtype
        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
            t = result[name]
            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
                t = t.to(orig_dtype)
            out[name] = t
            continue
        q, s = result[name + ".q"], result[name + ".scale"]
        if s.ndim > 0:
            # Per-row scale: broadcast over trailing dims.
            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
            # (scalar-scale else-branch and return continue on the next source line)
else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else 
None  # (continuation of CastedLinear.forward: bias = self.bias.to(...) if self.bias is not None else None)
        return F.linear(x, w, bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    """Force 1D/control parameters back to float32 after a module-wide
    .bfloat16() cast, so small gates/scales keep full precision."""
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    """Rotary position embeddings with a cos/sin cache keyed on seq_len/device."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        # Rebuild the cache when seq_len or device changes; dtype is handled
        # by the .to(...) at return so the cache stays in float32.
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate the two halves of the last dim by (cos, sin) — half-split RoPE."""
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    """GQA causal attention with QK RMS-norm, partial RoPE, per-head query
    gain, and optional XSA (self-value subtraction)."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, use_xsa: bool = False):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        self.use_xsa = use_xsa
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        self.rope_dims = max(2, int(self.head_dim * rope_frac) // 2 * 2)  # even, at least 2
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # output proj starts at zero (residual-friendly init)
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.rope_dims, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        # Partial RoPE: apply rotary only to first rope_dims dimensions
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:]
        k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:]
        q_rope = apply_rotary_emb(q_rope, cos, sin)
        k_rope = apply_rotary_emb(k_rope, cos, sin)
        q = torch.cat([q_rope, q_pass], dim=-1)
        k = torch.cat([k_rope, k_pass], dim=-1)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        # XSA: subtract self-value to force reliance on cross-token information
        if self.use_xsa:
            # Expand v for GQA if needed
            if self.num_kv_heads != self.num_heads:
                rep = self.num_heads // self.num_kv_heads
                v_expanded = v.repeat_interleave(rep, dim=1)
            else:
                v_expanded = v
            # Self-attention weight for diagonal is 1/seq_len in expectation,
            # but we subtract the actual self-value contribution: v_i / scale
            # Simpler approach: subtract v_i projected orthogonally
            y = y - v_expanded / seqlen
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    """Feed-forward block: leaky-ReLU then square (≈ ReLU² activation)."""

    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True

    def forward(self, x: Tensor) -> Tensor:
        x = F.leaky_relu(self.fc(x), negative_slope=0.5)
        return self.proj(x.square())


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""
    def __init__(self, dim: int):
        super().__init__()
        # Zero init -> sigmoid(0)=0.5 blend at start of training.
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev


class BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table."""
    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)  # contributes nothing until trained
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        # Multiplicative-XOR hash of (prev, cur) pairs into [0, vocab-1);
        # position 0 has no predecessor and maps to the reserved id `mod`.
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class Block(nn.Module):
    """Pre-norm transformer block with learnable residual scales and a
    per-block mix of the current stream with the embedding stream x0."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_frac: float = 0.25, layer_idx: int = 0, use_xsa: bool = False):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_frac=rope_frac, use_xsa=use_xsa)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] scales the running stream, resid_mix[1] the x0 stream.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
        self.ln_scale = 1.0 / math.sqrt(layer_idx + 1)  # depth-dependent residual damping

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.ln_scale * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.ln_scale * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    """Decoder-only transformer with U-Net style encoder/decoder skip
    connections, optional weight-shared (virtual) layers, smear gate, bigram
    hash embedding, tied embeddings, and tanh logit soft-capping."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
        rope_frac: float = 0.25,
        xsa_last_n: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init,
                      rope_frac=rope_frac, layer_idx=i, use_xsa=(i >= n_unique - xsa_last_n))
                for i in range(n_unique)
            ]
        )
        # Mapping from virtual layer index → physical block index
        # Last virtual layer shares with the last physical block
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # Zero-init for flagged projections, orthogonal init elsewhere; output
        # projections are additionally scaled down by sqrt(2 * depth).
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                    if ".proj." in name or name.endswith(".proj"):
                        with torch.no_grad():
                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return mean cross-entropy of soft-capped logits against target_ids."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x  # embedding stream, re-injected by each Block's resid_mix
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                # U-Net skip: latest encoder output feeds the matching decoder layer
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Same pass as forward() but returns soft-capped logits (no loss)."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        x0 = x
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        x = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)


def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window validation: overlapping seq_len windows advanced by
    `stride`; only each window's last `stride` tokens are scored (full window
    for the first). Windows are sharded across ranks and counts all-reduced.
    Returns (mean_nll, bits_per_byte)."""
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    # Keep windows that contribute at least `stride` scored tokens (or start at 0).
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
                nll = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)).float(),
                    y_batch.reshape(-1),
                    reduction="none",
                ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                s = 0 if ws == 0 else max(wlen - stride, 0)  # score only the fresh tail
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] &
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 
0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - rope_frac=args.rope_frac, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not 
None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if 
args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = 
{name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - - # Late QAT: apply fake quantization via STE when lr_scale < threshold - qat_active = args.qat_threshold > 0 and scale < args.qat_threshold - if qat_active: - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - cat = _classify_param(name) - clip = 15 if cat == "mlp" else 31 # match int5/int6 - with torch.no_grad(): - t32 = param.float() - row_max = t32.abs().amax(dim=1).clamp_min(1e-12) - s = row_max / clip - q = torch.clamp(torch.round(t32 / s[:, None]), -(clip + 1), clip) - dq = (q * s[:, None]).to(param.dtype) - # STE: replace param data but keep grad flowing - param.data.copy_(dq) - - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = 
micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # EMA: exponential moving average of weights every step - if args.ema_enabled and ema_state is not None: - decay = args.ema_decay - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1 - decay) - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - 
f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply EMA weights - if args.ema_enabled and ema_state is not None: - log0(f"ema:applying decay={args.ema_decay}") - current_state = base_model.state_dict() - ema_load = { - name: tensor.to(dtype=current_state[name].dtype) - for name, tensor in ema_state.items() - } - base_model.load_state_dict(ema_load, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - 
act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # GPTQ: collect Hessians 
for second-order quantization - gptq_hessians: dict[str, Tensor] = {} - if args.use_gptq: - log0(f"gptq:collecting Hessians (calib_batches={args.gptq_calib_batches})") - gptq_hooks = [] - gptq_inp_cache: dict[str, list[Tensor]] = {} - - def make_gptq_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - x_2d = x.detach().float().reshape(-1, x.shape[-1]) - if name not in gptq_inp_cache: - gptq_inp_cache[name] = [] - gptq_inp_cache[name].append(x_2d.cpu()) - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and module.weight.ndim == 2 and module.weight.numel() > 65536: - gptq_hooks.append(module.register_forward_hook(make_gptq_hook(name))) - - base_model.eval() - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(args.gptq_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in gptq_hooks: - h.remove() - - # Build Hessians: H = X^T X / n_samples - for name, inp_list in gptq_inp_cache.items(): - X = torch.cat(inp_list, dim=0) - n = X.shape[0] - H = (X.T @ X) / n - gptq_hessians[name] = H - del gptq_inp_cache - log0(f"gptq:collected Hessians for {len(gptq_hessians)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, 
quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}, gptq_hessians=gptq_hessians, gptq_damp=args.gptq_damp) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window 
eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned From 9083b57f306b52c44ac9b4ad1a4465225c5ad0b0 Mon Sep 17 00:00:00 2001 From: SPThole Date: Mon, 13 Apr 2026 23:48:01 +0530 Subject: [PATCH 08/10] updt --- .../README.md | 91 - ...-cyclic-relusq-11Lshared_8xH100_seed42.txt | 1520 ---------------- ...-cyclic-relusq-11Lshared_8xH100_seed43.txt | 1522 ----------------- ...-cyclic-relusq-11Lshared_8xH100_seed44.txt | 1521 ---------------- .../2025-03-24_AWQ_CyclMom_11L_shared/run.sh | 79 - .../setup_and_run.sh | 61 - .../submission.json | 11 - .../train_gpt.py | 1340 --------------- 8 files changed, 6145 deletions(-) delete mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/README.md delete mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed42.txt delete mode 100644 
records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed43.txt delete mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed44.txt delete mode 100755 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/run.sh delete mode 100755 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/setup_and_run.sh delete mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/submission.json delete mode 100644 records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/train_gpt.py diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/README.md b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/README.md deleted file mode 100644 index b0f2cff824..0000000000 --- a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/README.md +++ /dev/null @@ -1,91 +0,0 @@ -# AWQ + Cyclic Momentum + ReLU² + 11L Shared — val_bpb ≈ 1.1507 - -**Track:** 10min / 16MB -**Hardware:** 8×H100 SXM, 600s wallclock -**Model size:** ~15.4 MB (int5 MLP / int6 attn + zstd) -**val_bpb:** 1.1507 ± 0.0016 (3-seed mean, seeds 42, 43, 44) - -## Key Innovations - -Starting from the community SOTA baseline (thwu1), we introduce four techniques: - -| Technique | Description | Impact | -|-----------|-------------|--------| -| **AWQ (Activation-Aware Weight Quantization)** | Scale weight columns by activation importance (alpha=0.5) before quantization. Folds compensation into preceding LayerNorm. Reduces quantization error on high-activation channels. | Quant gap 0.027 → 0.010 bpb | -| **Cyclic Muon Momentum** | Triangle wave between 0.85–0.95 with period=50 steps, replacing fixed 0.99 after warmup. Prevents optimizer from settling into sharp minima. 
| −0.0045 bpb on 1×H100 | - -## Results (8×H100) - -| Seed | Steps | Raw val_bpb | Quantized val_bpb | Model Size | Total Size | -|------|-------|-------------|-------------------|------------|------------| -| 42 | 5999 | 1.1605 | 1.1502 | 15.39 MB | 15.45 MB | -| 43 | 6005 | 1.1598 | 1.1494 | 15.46 MB | 15.52 MB | -| 44 | 6000 | 1.1619 | 1.1526 | 15.37 MB | 15.43 MB | -| **Mean** | **6001** | **1.1607** | **1.1507** | **15.41 MB** | | -| **Std** | | **0.0011** | **0.0016** | | | - -## Architecture - -| Parameter | Value | -|-----------|-------| -| num_layers | 11 (10 unique, last shared) | -| model_dim | 512 | -| num_heads | 8 | -| num_kv_heads | 4 (GQA) | -| mlp_mult | 3.0 (hidden=1536) | -| mlp_activation | relu_sq | -| vocab_size | 1024 | -| train_seq_len | 2048 | -| tie_embeddings | yes | -| logit_softcap | 30.0 | -| rope_base | 10000 | -| rope_dims | 64 (full) | -| bigram_vocab_size | 10240 | -| bigram_dim | 128 | -| skip_connections | U-Net (5 encoder, 6 decoder) | - -## Training - -| Parameter | Value | -|-----------|-------| -| train_batch_tokens | 786,432 | -| optimizer (matrices) | Muon, lr=0.025, momentum=cyclic 0.85–0.95 | -| optimizer (embeds/scalars) | AdamW, lr=0.035/0.025 | -| warmup_steps | 20 | -| warmdown_iters | 3500 | -| weight_decay | 0.04 | -| grad_clip_norm | 0.3 | -| muon_momentum_warmup | 0.92 → cyclic over 1500 steps | -| SWA | start_frac=0.2, every=50 steps | - -## Quantization - -| Component | Precision | -|-----------|-----------| -| MLP weights | int5 per-row | -| Attention weights | int6 per-row | -| Bigram embeddings | int6 per-row | -| Token embeddings | int8 per-row | -| Control tensors | fp32 passthrough | -| Compression | zstd | - -**AWQ:** Before quantization, run 8 calibration batches through the model. For each Linear layer, compute per-channel activation magnitude `s = act.abs().mean(dim=(0,1))`. Scale weight columns by `s^0.5`, fold inverse into preceding LayerNorm. 
This protects high-activation channels from quantization error. - -## Evaluation - -Sliding window evaluation with stride=64, batch_seqs=64. - -```bash -torchrun --nproc_per_node=8 train_gpt.py -``` - -## Development Journey - -This submission emerged from 21 experiments on 1×H100 and 1×A40, systematically testing: -- Multi-token prediction (MTP) — marginal gains, size overhead -- Curriculum learning — incompatible with torch.compile -- Test-time training (TTT) — promising with partial RoPE, but eval time too long -- Various quantization strategies (GPTQ-lite, layer-aware, EMA) — AWQ was the clear winner -- Architectural variations (wider, value embeddings, partial RoPE) — diminishing returns - -The final recipe: simple architecture + smart optimization (cyclic momentum) + smart quantization (AWQ). diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed42.txt b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed42.txt deleted file mode 100644 index b81a4caced..0000000000 --- a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed42.txt +++ /dev/null @@ -1,1520 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
-""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = 
int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- 
-# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 
0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - 
val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and 
dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, 
scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" 
- continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = 
world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or 
self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, 
q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - 
return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - 
self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." 
in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - 
device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 
0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) 
- - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: 
float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == 
args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + 
(args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - 
if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - 
base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob 
= zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Tue Mar 24 11:24:37 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | -| N/A 42C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | -| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | -| N/A 32C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 41C P0 130W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | -| N/A 42C P0 124W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | -| N/A 34C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | -| N/A 41C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - 
-+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| No running processes found | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:25517137 -world_size:8 grad_accum_steps:1 -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9323 val_bpb:4.1057 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9334 train_time:148ms step_avg:148.19ms -step:2/20000 train_loss:8.7695 train_time:229ms step_avg:114.65ms -step:3/20000 train_loss:8.0097 train_time:326ms step_avg:108.71ms -step:4/20000 train_loss:7.2373 train_time:423ms step_avg:105.66ms -step:5/20000 train_loss:7.1032 train_time:521ms step_avg:104.13ms -step:6/20000 train_loss:6.9653 train_time:620ms step_avg:103.34ms -step:7/20000 train_loss:6.8161 train_time:718ms step_avg:102.55ms -step:8/20000 train_loss:6.8162 
train_time:816ms step_avg:101.98ms -step:9/20000 train_loss:6.5146 train_time:913ms step_avg:101.41ms -step:10/20000 train_loss:6.1109 train_time:1011ms step_avg:101.06ms -step:100/20000 train_loss:3.1632 train_time:9982ms step_avg:99.82ms -step:200/20000 train_loss:2.3797 train_time:19987ms step_avg:99.94ms -step:300/20000 train_loss:2.5390 train_time:30012ms step_avg:100.04ms -step:400/20000 train_loss:2.4119 train_time:40058ms step_avg:100.15ms -step:500/20000 train_loss:2.3950 train_time:50068ms step_avg:100.14ms -step:500/20000 val_loss:2.3560 val_bpb:1.3954 train_time:50096ms step_avg:100.19ms -step:600/20000 train_loss:2.3334 train_time:60135ms step_avg:100.23ms -step:700/20000 train_loss:2.3465 train_time:70200ms step_avg:100.29ms -step:800/20000 train_loss:2.2394 train_time:80264ms step_avg:100.33ms -step:900/20000 train_loss:2.1327 train_time:90309ms step_avg:100.34ms -step:1000/20000 train_loss:2.2778 train_time:100312ms step_avg:100.31ms -step:1000/20000 val_loss:2.2303 val_bpb:1.3209 train_time:100339ms step_avg:100.34ms -step:1100/20000 train_loss:2.3252 train_time:110366ms step_avg:100.33ms -step:1200/20000 train_loss:2.3585 train_time:120420ms step_avg:100.35ms -step:1300/20000 train_loss:2.1034 train_time:130470ms step_avg:100.36ms -step:1400/20000 train_loss:2.1910 train_time:140517ms step_avg:100.37ms -step:1500/20000 train_loss:2.2249 train_time:150502ms step_avg:100.33ms -step:1500/20000 val_loss:2.1887 val_bpb:1.2962 train_time:150528ms step_avg:100.35ms -step:1600/20000 train_loss:2.0341 train_time:160544ms step_avg:100.34ms -step:1700/20000 train_loss:2.1021 train_time:170584ms step_avg:100.34ms -step:1800/20000 train_loss:2.1251 train_time:180628ms step_avg:100.35ms -step:1900/20000 train_loss:2.1003 train_time:190601ms step_avg:100.32ms -step:2000/20000 train_loss:2.0483 train_time:200615ms step_avg:100.31ms -step:2000/20000 val_loss:2.1119 val_bpb:1.2508 train_time:200641ms step_avg:100.32ms -step:2100/20000 train_loss:2.0367 
train_time:210626ms step_avg:100.30ms -step:2200/20000 train_loss:2.1896 train_time:220638ms step_avg:100.29ms -step:2300/20000 train_loss:2.1112 train_time:230668ms step_avg:100.29ms -step:2400/20000 train_loss:2.0746 train_time:240605ms step_avg:100.25ms -step:2500/20000 train_loss:2.1813 train_time:250618ms step_avg:100.25ms -step:2500/20000 val_loss:2.1205 val_bpb:1.2559 train_time:250645ms step_avg:100.26ms -step:2600/20000 train_loss:2.1164 train_time:260617ms step_avg:100.24ms -step:2700/20000 train_loss:2.1148 train_time:270603ms step_avg:100.22ms -step:2800/20000 train_loss:2.1692 train_time:280673ms step_avg:100.24ms -step:2900/20000 train_loss:2.0438 train_time:290623ms step_avg:100.21ms -step:3000/20000 train_loss:2.1738 train_time:300599ms step_avg:100.20ms -step:3000/20000 val_loss:2.1053 val_bpb:1.2469 train_time:300625ms step_avg:100.21ms -step:3100/20000 train_loss:2.0536 train_time:310605ms step_avg:100.20ms -step:3200/20000 train_loss:2.1821 train_time:320597ms step_avg:100.19ms -step:3300/20000 train_loss:2.0807 train_time:330522ms step_avg:100.16ms -step:3400/20000 train_loss:2.0249 train_time:340504ms step_avg:100.15ms -step:3500/20000 train_loss:2.1864 train_time:350488ms step_avg:100.14ms -step:3500/20000 val_loss:2.0867 val_bpb:1.2358 train_time:350515ms step_avg:100.15ms -step:3600/20000 train_loss:2.0971 train_time:360481ms step_avg:100.13ms -step:3700/20000 train_loss:2.0972 train_time:370467ms step_avg:100.13ms -step:3800/20000 train_loss:2.0745 train_time:380402ms step_avg:100.11ms -step:3900/20000 train_loss:2.0771 train_time:390409ms step_avg:100.10ms -step:4000/20000 train_loss:1.9767 train_time:400381ms step_avg:100.10ms -step:4000/20000 val_loss:2.0649 val_bpb:1.2230 train_time:400407ms step_avg:100.10ms -step:4100/20000 train_loss:2.0111 train_time:410379ms step_avg:100.09ms -step:4200/20000 train_loss:2.1512 train_time:420355ms step_avg:100.08ms -step:4300/20000 train_loss:2.0497 train_time:430284ms step_avg:100.07ms 
-step:4400/20000 train_loss:2.0278 train_time:440260ms step_avg:100.06ms -step:4500/20000 train_loss:2.1155 train_time:450253ms step_avg:100.06ms -step:4500/20000 val_loss:2.0376 val_bpb:1.2068 train_time:450280ms step_avg:100.06ms -step:4600/20000 train_loss:1.8352 train_time:460233ms step_avg:100.05ms -step:4700/20000 train_loss:2.2260 train_time:470153ms step_avg:100.03ms -step:4800/20000 train_loss:2.4265 train_time:480133ms step_avg:100.03ms -step:4900/20000 train_loss:2.0406 train_time:490112ms step_avg:100.02ms -step:5000/20000 train_loss:2.0948 train_time:500100ms step_avg:100.02ms -step:5000/20000 val_loss:2.0129 val_bpb:1.1922 train_time:500126ms step_avg:100.03ms -step:5100/20000 train_loss:2.1123 train_time:510082ms step_avg:100.02ms -step:5200/20000 train_loss:2.0305 train_time:520004ms step_avg:100.00ms -step:5300/20000 train_loss:1.9923 train_time:529971ms step_avg:99.99ms -swa:start step:5350 -step:5400/20000 train_loss:2.0323 train_time:540022ms step_avg:100.00ms -step:5500/20000 train_loss:2.0009 train_time:550034ms step_avg:100.01ms -step:5500/20000 val_loss:1.9842 val_bpb:1.1752 train_time:550089ms step_avg:100.02ms -step:5600/20000 train_loss:1.9365 train_time:560089ms step_avg:100.02ms -step:5700/20000 train_loss:1.9912 train_time:570062ms step_avg:100.01ms -step:5800/20000 train_loss:1.9716 train_time:580099ms step_avg:100.02ms -step:5900/20000 train_loss:1.8795 train_time:590129ms step_avg:100.02ms -step:5999/20000 val_loss:1.9594 val_bpb:1.1605 train_time:600082ms step_avg:100.03ms -stopping_early: wallclock_cap train_time:600082ms step:5999/20000 -peak memory allocated: 20841 MiB reserved: 21060 MiB -swa:applying averaged 13 checkpoints -Serialized model: 98437419 bytes -Code size: 58616 bytes -Total submission size: 98496035 bytes -awq:calibrating alpha=0.5 -awq:scaled 61 layers -Serialized model int6+zstd: 15394174 bytes -Total submission size int8+zlib: 15452790 bytes -awq:unscaled 61 layers after dequant -final_eval_mode:sliding_window 
stride:64 batch_seqs:64 -final_int8_zlib_roundtrip val_loss:1.9420 val_bpb:1.1502 eval_time:180059ms -final_int8_zlib_roundtrip_exact val_loss:1.94198177 val_bpb:1.15015403 diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed43.txt b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed43.txt deleted file mode 100644 index b52ee75a72..0000000000 --- a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed43.txt +++ /dev/null @@ -1,1522 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
-""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = 
int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- 
-# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 
0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - 
val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and 
dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, 
scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" 
- continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = 
world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or 
self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, 
q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - 
return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - 
self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." 
in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - 
device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 
0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) 
- - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: 
float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == 
args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + 
(args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - 
if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - 
base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob 
= zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Tue Mar 24 11:39:20 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | -| N/A 47C P0 129W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | -| N/A 36C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | -| N/A 34C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 47C P0 134W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | -| N/A 48C P0 131W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | -| N/A 38C P0 128W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | -| N/A 46C P0 127W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 36C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - 
-+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| No running processes found | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:25517137 -world_size:8 grad_accum_steps:1 -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:43 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9285 val_bpb:4.1035 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9296 train_time:149ms step_avg:148.74ms -step:2/20000 train_loss:8.4791 train_time:221ms step_avg:110.74ms -step:3/20000 train_loss:7.8364 train_time:318ms step_avg:105.92ms -step:4/20000 train_loss:7.2856 train_time:415ms step_avg:103.85ms -step:5/20000 train_loss:7.0369 train_time:514ms step_avg:102.84ms -step:6/20000 train_loss:6.8438 train_time:612ms step_avg:102.03ms -step:7/20000 train_loss:6.7871 train_time:711ms step_avg:101.54ms -step:8/20000 train_loss:6.8742 
train_time:810ms step_avg:101.19ms -step:9/20000 train_loss:6.5695 train_time:908ms step_avg:100.86ms -step:10/20000 train_loss:6.2076 train_time:1005ms step_avg:100.52ms -step:100/20000 train_loss:3.1589 train_time:10032ms step_avg:100.32ms -step:200/20000 train_loss:2.3736 train_time:20030ms step_avg:100.15ms -step:300/20000 train_loss:2.5403 train_time:30058ms step_avg:100.19ms -step:400/20000 train_loss:2.4110 train_time:40100ms step_avg:100.25ms -step:500/20000 train_loss:2.3962 train_time:50106ms step_avg:100.21ms -step:500/20000 val_loss:2.3542 val_bpb:1.3943 train_time:50133ms step_avg:100.27ms -step:600/20000 train_loss:2.3361 train_time:60176ms step_avg:100.29ms -step:700/20000 train_loss:2.3427 train_time:70239ms step_avg:100.34ms -step:800/20000 train_loss:2.2381 train_time:80296ms step_avg:100.37ms -step:900/20000 train_loss:2.1292 train_time:90346ms step_avg:100.38ms -step:1000/20000 train_loss:2.2797 train_time:100338ms step_avg:100.34ms -step:1000/20000 val_loss:2.2304 val_bpb:1.3210 train_time:100365ms step_avg:100.36ms -step:1100/20000 train_loss:2.3257 train_time:110393ms step_avg:100.36ms -step:1200/20000 train_loss:2.3582 train_time:120427ms step_avg:100.36ms -step:1300/20000 train_loss:2.1011 train_time:130471ms step_avg:100.36ms -step:1400/20000 train_loss:2.1876 train_time:140507ms step_avg:100.36ms -step:1500/20000 train_loss:2.2261 train_time:150479ms step_avg:100.32ms -step:1500/20000 val_loss:2.1915 val_bpb:1.2979 train_time:150505ms step_avg:100.34ms -step:1600/20000 train_loss:2.0294 train_time:160501ms step_avg:100.31ms -step:1700/20000 train_loss:2.1036 train_time:170525ms step_avg:100.31ms -step:1800/20000 train_loss:2.1260 train_time:180535ms step_avg:100.30ms -step:1900/20000 train_loss:2.0983 train_time:190498ms step_avg:100.26ms -step:2000/20000 train_loss:2.0471 train_time:200512ms step_avg:100.26ms -step:2000/20000 val_loss:2.1136 val_bpb:1.2518 train_time:200539ms step_avg:100.27ms -step:2100/20000 train_loss:2.0326 
train_time:210518ms step_avg:100.25ms -step:2200/20000 train_loss:2.1266 train_time:220516ms step_avg:100.23ms -step:2300/20000 train_loss:2.1105 train_time:230534ms step_avg:100.23ms -step:2400/20000 train_loss:2.0723 train_time:240465ms step_avg:100.19ms -step:2500/20000 train_loss:2.1760 train_time:250447ms step_avg:100.18ms -step:2500/20000 val_loss:2.1204 val_bpb:1.2558 train_time:250473ms step_avg:100.19ms -step:2600/20000 train_loss:2.1216 train_time:260431ms step_avg:100.17ms -step:2700/20000 train_loss:2.1178 train_time:270423ms step_avg:100.16ms -step:2800/20000 train_loss:2.1704 train_time:280417ms step_avg:100.15ms -step:2900/20000 train_loss:2.0422 train_time:290367ms step_avg:100.13ms -step:3000/20000 train_loss:2.1734 train_time:300362ms step_avg:100.12ms -step:3000/20000 val_loss:2.1066 val_bpb:1.2476 train_time:300388ms step_avg:100.13ms -step:3100/20000 train_loss:2.0504 train_time:310354ms step_avg:100.11ms -step:3200/20000 train_loss:2.1826 train_time:320337ms step_avg:100.11ms -step:3300/20000 train_loss:2.0797 train_time:330258ms step_avg:100.08ms -step:3400/20000 train_loss:2.0276 train_time:340246ms step_avg:100.07ms -step:3500/20000 train_loss:2.1849 train_time:350219ms step_avg:100.06ms -step:3500/20000 val_loss:2.0866 val_bpb:1.2358 train_time:350244ms step_avg:100.07ms -step:3600/20000 train_loss:2.0964 train_time:360200ms step_avg:100.06ms -step:3700/20000 train_loss:2.0972 train_time:370186ms step_avg:100.05ms -step:3800/20000 train_loss:2.0760 train_time:380100ms step_avg:100.03ms -step:3900/20000 train_loss:2.0754 train_time:390084ms step_avg:100.02ms -step:4000/20000 train_loss:1.9785 train_time:400040ms step_avg:100.01ms -step:4000/20000 val_loss:2.0639 val_bpb:1.2224 train_time:400067ms step_avg:100.02ms -step:4100/20000 train_loss:2.0115 train_time:410015ms step_avg:100.00ms -step:4200/20000 train_loss:2.1511 train_time:419979ms step_avg:100.00ms -step:4300/20000 train_loss:2.0566 train_time:429895ms step_avg:99.98ms 
-step:4400/20000 train_loss:2.0267 train_time:439861ms step_avg:99.97ms -step:4500/20000 train_loss:2.1164 train_time:449828ms step_avg:99.96ms -step:4500/20000 val_loss:2.0388 val_bpb:1.2075 train_time:449853ms step_avg:99.97ms -step:4600/20000 train_loss:1.8345 train_time:459799ms step_avg:99.96ms -step:4700/20000 train_loss:2.2178 train_time:469706ms step_avg:99.94ms -step:4800/20000 train_loss:2.4194 train_time:479665ms step_avg:99.93ms -step:4900/20000 train_loss:2.0402 train_time:489639ms step_avg:99.93ms -step:5000/20000 train_loss:2.0891 train_time:499610ms step_avg:99.92ms -step:5000/20000 val_loss:2.0117 val_bpb:1.1915 train_time:499637ms step_avg:99.93ms -step:5100/20000 train_loss:2.1118 train_time:509574ms step_avg:99.92ms -step:5200/20000 train_loss:2.0281 train_time:519474ms step_avg:99.90ms -step:5300/20000 train_loss:1.9924 train_time:529438ms step_avg:99.89ms -swa:start step:5350 -step:5400/20000 train_loss:2.0311 train_time:539475ms step_avg:99.90ms -step:5500/20000 train_loss:2.0004 train_time:549484ms step_avg:99.91ms -step:5500/20000 val_loss:1.9834 val_bpb:1.1747 train_time:549536ms step_avg:99.92ms -step:5600/20000 train_loss:1.9374 train_time:559498ms step_avg:99.91ms -step:5700/20000 train_loss:1.9898 train_time:569462ms step_avg:99.91ms -step:5800/20000 train_loss:1.9719 train_time:579486ms step_avg:99.91ms -step:5900/20000 train_loss:1.8830 train_time:589542ms step_avg:99.92ms -step:6000/20000 train_loss:1.9237 train_time:599555ms step_avg:99.93ms -step:6000/20000 val_loss:1.9583 val_bpb:1.1598 train_time:599622ms step_avg:99.94ms -step:6005/20000 val_loss:1.9583 val_bpb:1.1598 train_time:600115ms step_avg:99.94ms -stopping_early: wallclock_cap train_time:600115ms step:6005/20000 -peak memory allocated: 20841 MiB reserved: 21060 MiB -swa:applying averaged 14 checkpoints -Serialized model: 98437419 bytes -Code size: 58616 bytes -Total submission size: 98496035 bytes -awq:calibrating alpha=0.5 -awq:scaled 61 layers -Serialized model 
int6+zstd: 15458563 bytes -Total submission size int8+zlib: 15517179 bytes -awq:unscaled 61 layers after dequant -final_eval_mode:sliding_window stride:64 batch_seqs:64 -final_int8_zlib_roundtrip val_loss:1.9408 val_bpb:1.1494 eval_time:180159ms -final_int8_zlib_roundtrip_exact val_loss:1.94077623 val_bpb:1.14944004 diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed44.txt b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed44.txt deleted file mode 100644 index 69e4a3c2e0..0000000000 --- a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/logs/submission_exp18_awq-cyclic-relusq-11Lshared_8xH100_seed44.txt +++ /dev/null @@ -1,1521 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
-""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = 
int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- 
-# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 
0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - 
val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and 
dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, 
scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" 
def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
    """Quantize a state dict with per-category bit widths.

    Small or non-float tensors pass through (floats as fp16); tensors whose
    names match the control/fp16-keep pattern lists also pass through.
    Categories listed in *int6_cats* get per-row int6 codes (int5 for MLP
    weights); everything else falls back to int8 via quantize_float_tensor.

    Returns:
        (result, meta): packed tensors (".q"/".scale" pairs for quantized
        entries) plus a per-name metadata map used by the dequantizer.
    """
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    for name, value in state_dict.items():
        host = value.detach().cpu().contiguous()
        bucket = _classify_param(name)
        if not host.is_floating_point() or host.numel() <= 8192:
            # Too small to be worth quantizing (or integral already).
            result[name] = host.to(torch.float16) if host.is_floating_point() else host
            meta[name] = "passthrough"
            continue
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            # Control tensors stay full precision.
            result[name] = host.float()
            meta[name] = "passthrough_ctrl"
            continue
        if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS):
            result[name] = host.to(dtype=torch.float16).contiguous()
            meta[name] = "passthrough_fp16"
            continue
        if bucket in int6_cats and host.ndim >= 1:
            clip = 15 if bucket == "mlp" else 31  # int5 for MLP, int6 for attention
            codes, scales = quantize_intN_per_row(host, clip_range=clip)
            result[name + ".q"] = codes
            result[name + ".scale"] = scales
            meta[name] = {"type": f"int{5 if bucket == 'mlp' else 6}"}
        else:
            codes, scales = quantize_float_tensor(host)
            result[name + ".q"] = codes
            result[name + ".scale"] = scales
            meta[name] = {"type": "int8"}
    return result, meta

def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """Invert mixed_quantize_int6, restoring each template dtype.

    Passthrough entries are returned as-is (fp16 widened back to the
    template's fp32/bf16); quantized entries are rebuilt as codes * scale,
    broadcasting a per-row scale vector over trailing dims.
    """
    out: dict[str, Tensor] = {}
    for name, template in template_sd.items():
        info = meta[name]
        want_dtype = template.dtype
        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
            stored = result[name]
            widen = stored.dtype == torch.float16 and want_dtype in (torch.float32, torch.bfloat16)
            out[name] = stored.to(want_dtype) if widen else stored
            continue
        codes = result[name + ".q"]
        scales = result[name + ".scale"]
        if scales.ndim > 0:
            # Per-row scale: reshape to (rows, 1, ..., 1) for broadcasting.
            shaped = scales.float().view(codes.shape[0], *([1] * (codes.ndim - 1)))
            out[name] = (codes.float() * shaped).to(want_dtype)
        else:
            out[name] = (codes.float() * float(scales.item())).to(want_dtype)
    return out
world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or 
class Rotary(nn.Module):
    """Rotary position embedding with a simple (seq_len, device) cache."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        stale = (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        )
        if stale:
            positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            angles = torch.outer(positions, self.inv_freq.to(device))
            # Shape (1, 1, seq, dim/2) so it broadcasts over (batch, heads).
            self._cos_cached = angles.cos()[None, None, :, :]
            self._sin_cached = angles.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate the two halves of the head dimension by (cos, sin)."""
    half = x.size(-1) // 2
    lo = x[..., :half]
    hi = x[..., half:]
    rotated_lo = lo * cos + hi * sin
    rotated_hi = hi * cos - lo * sin
    return torch.cat((rotated_lo, rotated_hi), dim=-1)


class CausalSelfAttention(nn.Module):
    """Causal GQA attention with RoPE, QK RMS-norm and per-head Q gains."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # residual branch starts as identity
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        b, s, d = x.shape
        q = self.c_q(x).reshape(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(b, s, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(b, s, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Normalize queries/keys before RoPE for scale-invariant logits.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(s, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        out = out.transpose(1, 2).contiguous().reshape(b, s, d)
        return self.proj(out)


class MLP(nn.Module):
    """Feed-forward block with a squared-ReLU activation."""

    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True  # residual branch starts as identity

    def forward(self, x: Tensor) -> Tensor:
        h = torch.relu(self.fc(x))
        return self.proj(h * h)


class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding."""

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        mix = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        shifted = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - mix) * x + mix * shifted
return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - unique_layers: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - 
self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - # Layer sharing: create unique_layers physical blocks, map num_layers virtual layers to them - n_unique = unique_layers if unique_layers > 0 else num_layers - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(n_unique) - ] - ) - # Mapping from virtual layer index → physical block index - # Last virtual layer shares with the last physical block - self._layer_map = list(range(min(n_unique, num_layers))) - while len(self._layer_map) < num_layers: - self._layer_map.append(n_unique - 1) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self._layer_map) # virtual depth, not physical blocks - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." 
in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[self._layer_map[i]](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - 
device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 
0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} 
train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - unique_layers=args.unique_layers, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) 
- - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: 
float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == 
args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - if step < args.muon_momentum_warmup_steps: - # Linear warmup phase - frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - elif args.momentum_cyclic: - # Cyclic momentum: triangle wave between momentum_min and momentum_max - period = args.momentum_cycle_period * 2 - pos = (step % period) / period - if pos < 0.5: - muon_momentum = args.momentum_min + 
(args.momentum_max - args.momentum_min) * (pos * 2) - else: - muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2) - else: - muon_momentum = args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - 
if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # AWQ: Activation-aware weight scaling before quantization - awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5)) - awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1"))) - if awq_enabled: - log0(f"awq:calibrating alpha={awq_alpha}") - # Step 1: Collect per-channel activation magnitudes via hooks - act_stats: dict[str, Tensor] = {} - hooks = [] - def make_hook(name): - def hook_fn(module, input, output): - x = input[0] if isinstance(input, tuple) else input - if x.ndim >= 2: - # Mean absolute activation per channel (last dim) - s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1))) - if name in act_stats: - act_stats[name] = act_stats[name] + s - else: - act_stats[name] = s - return hook_fn - - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)): - hooks.append(module.register_forward_hook(make_hook(name))) - - # Step 2: Run calibration batches (use val data, 4 batches) - 
base_model.eval() - n_calib_batches = 4 - calib_seq_len = args.train_seq_len - calib_batch_seqs = 16 - calib_tokens_per_batch = calib_batch_seqs * calib_seq_len - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for cb in range(n_calib_batches): - start = cb * calib_tokens_per_batch - end = start + calib_tokens_per_batch + 1 - if end > val_tokens.numel(): - break - local = val_tokens[start:end].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, calib_seq_len) - base_model.forward_logits(x) - - for h in hooks: - h.remove() - - # Normalize activation stats - for name in act_stats: - act_stats[name] = act_stats[name] / n_calib_batches - - # Step 3: Scale weight columns by s^alpha - awq_scales = {} - with torch.no_grad(): - for name, module in base_model.named_modules(): - if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats: - s = act_stats[name].to(module.weight.device) - s = s.clamp_min(1e-6) - scale = s.pow(awq_alpha) - # Scale weight columns (input channels) - if module.weight.shape[1] == scale.shape[0]: - module.weight.data = module.weight.data * scale.unsqueeze(0) - awq_scales[name] = scale.cpu() - # Store inverse scale to apply to inputs at inference - # We'll fold this into the state dict - - log0(f"awq:scaled {len(awq_scales)} layers") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - # Store AWQ inverse scales in the state dict for inference compensation - if awq_enabled and awq_scales: - for lname, scale in awq_scales.items(): - sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16) - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob 
= zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - # AWQ: undo column scaling after dequantization - if awq_enabled: - awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")] - for inv_key in awq_inv_keys: - inv_scale = deq_state.pop(inv_key).float() - layer_name = inv_key.replace("_awq_inv_scale.", "") - weight_key = layer_name + ".weight" - if weight_key in deq_state: - # Undo: W_orig = W_scaled * inv_scale (per column) - deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0) - deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype) - log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant") - # Remove any remaining AWQ keys before loading - deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")} - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Tue Mar 24 11:54:04 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | -| N/A 47C P0 128W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | -| N/A 37C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | -| N/A 34C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 47C P0 134W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | -| N/A 48C P0 131W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | -| N/A 38C P0 128W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | -| N/A 47C P0 128W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 36C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - 
-+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| No running processes found | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:25517137 -world_size:8 grad_accum_steps:1 -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:44 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9280 val_bpb:4.1031 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9287 train_time:149ms step_avg:149.48ms -step:2/20000 train_loss:8.4639 train_time:222ms step_avg:111.22ms -step:3/20000 train_loss:7.8194 train_time:320ms step_avg:106.51ms -step:4/20000 train_loss:7.2740 train_time:418ms step_avg:104.46ms -step:5/20000 train_loss:6.9708 train_time:515ms step_avg:102.91ms -step:6/20000 train_loss:6.8442 train_time:613ms step_avg:102.11ms -step:7/20000 train_loss:6.7332 train_time:710ms step_avg:101.45ms -step:8/20000 train_loss:6.6899 
train_time:808ms step_avg:100.97ms -step:9/20000 train_loss:6.4768 train_time:904ms step_avg:100.45ms -step:10/20000 train_loss:6.1476 train_time:1001ms step_avg:100.06ms -step:100/20000 train_loss:3.1518 train_time:9962ms step_avg:99.62ms -step:200/20000 train_loss:2.3792 train_time:19971ms step_avg:99.86ms -step:300/20000 train_loss:2.5375 train_time:30002ms step_avg:100.01ms -step:400/20000 train_loss:2.4066 train_time:40063ms step_avg:100.16ms -step:500/20000 train_loss:2.3937 train_time:50067ms step_avg:100.13ms -step:500/20000 val_loss:2.3569 val_bpb:1.3959 train_time:50091ms step_avg:100.18ms -step:600/20000 train_loss:2.3337 train_time:60134ms step_avg:100.22ms -step:700/20000 train_loss:2.3503 train_time:70193ms step_avg:100.28ms -step:800/20000 train_loss:2.2417 train_time:80256ms step_avg:100.32ms -step:900/20000 train_loss:2.1287 train_time:90313ms step_avg:100.35ms -step:1000/20000 train_loss:2.2758 train_time:100310ms step_avg:100.31ms -step:1000/20000 val_loss:2.2311 val_bpb:1.3214 train_time:100336ms step_avg:100.34ms -step:1100/20000 train_loss:2.3322 train_time:110366ms step_avg:100.33ms -step:1200/20000 train_loss:2.3606 train_time:120403ms step_avg:100.34ms -step:1300/20000 train_loss:2.1112 train_time:130456ms step_avg:100.35ms -step:1400/20000 train_loss:2.1882 train_time:140478ms step_avg:100.34ms -step:1500/20000 train_loss:2.2263 train_time:150441ms step_avg:100.29ms -step:1500/20000 val_loss:2.1920 val_bpb:1.2982 train_time:150469ms step_avg:100.31ms -step:1600/20000 train_loss:2.0377 train_time:160472ms step_avg:100.30ms -step:1700/20000 train_loss:2.1060 train_time:170501ms step_avg:100.29ms -step:1800/20000 train_loss:2.1298 train_time:180522ms step_avg:100.29ms -step:1900/20000 train_loss:2.1008 train_time:190493ms step_avg:100.26ms -step:2000/20000 train_loss:2.0515 train_time:200504ms step_avg:100.25ms -step:2000/20000 val_loss:2.1173 val_bpb:1.2540 train_time:200531ms step_avg:100.27ms -step:2100/20000 train_loss:2.0373 
train_time:210546ms step_avg:100.26ms -step:2200/20000 train_loss:2.1344 train_time:220560ms step_avg:100.25ms -step:2300/20000 train_loss:2.1168 train_time:230569ms step_avg:100.25ms -step:2400/20000 train_loss:2.0723 train_time:240545ms step_avg:100.23ms -step:2500/20000 train_loss:2.1878 train_time:250547ms step_avg:100.22ms -step:2500/20000 val_loss:2.1235 val_bpb:1.2577 train_time:250575ms step_avg:100.23ms -step:2600/20000 train_loss:2.1224 train_time:260544ms step_avg:100.21ms -step:2700/20000 train_loss:2.1155 train_time:270517ms step_avg:100.19ms -step:2800/20000 train_loss:2.1707 train_time:280521ms step_avg:100.19ms -step:2900/20000 train_loss:2.0416 train_time:290459ms step_avg:100.16ms -step:3000/20000 train_loss:2.1755 train_time:300453ms step_avg:100.15ms -step:3000/20000 val_loss:2.1071 val_bpb:1.2479 train_time:300479ms step_avg:100.16ms -step:3100/20000 train_loss:2.0540 train_time:310500ms step_avg:100.16ms -step:3200/20000 train_loss:2.1885 train_time:320480ms step_avg:100.15ms -step:3300/20000 train_loss:2.0829 train_time:330417ms step_avg:100.13ms -step:3400/20000 train_loss:2.0306 train_time:340422ms step_avg:100.12ms -step:3500/20000 train_loss:2.1900 train_time:350415ms step_avg:100.12ms -step:3500/20000 val_loss:2.0899 val_bpb:1.2377 train_time:350441ms step_avg:100.13ms -step:3600/20000 train_loss:2.0978 train_time:360415ms step_avg:100.12ms -step:3700/20000 train_loss:2.1007 train_time:370408ms step_avg:100.11ms -step:3800/20000 train_loss:2.0762 train_time:380332ms step_avg:100.09ms -step:3900/20000 train_loss:2.0818 train_time:390308ms step_avg:100.08ms -step:4000/20000 train_loss:1.9747 train_time:400290ms step_avg:100.07ms -step:4000/20000 val_loss:2.0678 val_bpb:1.2247 train_time:400315ms step_avg:100.08ms -step:4100/20000 train_loss:2.0182 train_time:410281ms step_avg:100.07ms -step:4200/20000 train_loss:2.1542 train_time:420258ms step_avg:100.06ms -step:4300/20000 train_loss:2.0556 train_time:430192ms step_avg:100.04ms 
-step:4400/20000 train_loss:2.0332 train_time:440177ms step_avg:100.04ms -step:4500/20000 train_loss:2.1193 train_time:450138ms step_avg:100.03ms -step:4500/20000 val_loss:2.0431 val_bpb:1.2101 train_time:450164ms step_avg:100.04ms -step:4600/20000 train_loss:1.8407 train_time:460123ms step_avg:100.03ms -step:4700/20000 train_loss:2.2264 train_time:470035ms step_avg:100.01ms -step:4800/20000 train_loss:2.4262 train_time:480013ms step_avg:100.00ms -step:4900/20000 train_loss:2.0455 train_time:489990ms step_avg:100.00ms -step:5000/20000 train_loss:2.0949 train_time:499981ms step_avg:100.00ms -step:5000/20000 val_loss:2.0153 val_bpb:1.1936 train_time:500007ms step_avg:100.00ms -step:5100/20000 train_loss:2.1149 train_time:509962ms step_avg:99.99ms -step:5200/20000 train_loss:2.0314 train_time:519881ms step_avg:99.98ms -step:5300/20000 train_loss:1.9922 train_time:529849ms step_avg:99.97ms -swa:start step:5350 -step:5400/20000 train_loss:2.0348 train_time:539907ms step_avg:99.98ms -step:5500/20000 train_loss:2.0003 train_time:549943ms step_avg:99.99ms -step:5500/20000 val_loss:1.9868 val_bpb:1.1767 train_time:550013ms step_avg:100.00ms -step:5600/20000 train_loss:1.9387 train_time:559996ms step_avg:100.00ms -step:5700/20000 train_loss:1.9929 train_time:569977ms step_avg:100.00ms -step:5800/20000 train_loss:1.9749 train_time:580010ms step_avg:100.00ms -step:5900/20000 train_loss:1.8845 train_time:590024ms step_avg:100.00ms -step:6000/20000 train_loss:1.9269 train_time:600029ms step_avg:100.00ms -step:6000/20000 val_loss:1.9619 val_bpb:1.1619 train_time:600099ms step_avg:100.02ms -stopping_early: wallclock_cap train_time:600099ms step:6000/20000 -peak memory allocated: 20841 MiB reserved: 21060 MiB -swa:applying averaged 14 checkpoints -Serialized model: 98437419 bytes -Code size: 58616 bytes -Total submission size: 98496035 bytes -awq:calibrating alpha=0.5 -awq:scaled 61 layers -Serialized model int6+zstd: 15367341 bytes -Total submission size int8+zlib: 15425957 bytes 
-awq:unscaled 61 layers after dequant -final_eval_mode:sliding_window stride:64 batch_seqs:64 -final_int8_zlib_roundtrip val_loss:1.9460 val_bpb:1.1526 eval_time:179947ms -final_int8_zlib_roundtrip_exact val_loss:1.94603275 val_bpb:1.15255326 diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/run.sh b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/run.sh deleted file mode 100755 index 8f44293924..0000000000 --- a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/run.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/bin/bash -# ============================================================ -# SUBMISSION: exp18 AWQ + cyclic momentum + relu_sq + 11L shared -# 8×H100 SXM, 3 seeds, full sliding window eval -# ============================================================ -set -euo pipefail - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -EXP_NAME="submission_exp18_awq-cyclic-relusq-11Lshared_8xH100" -LOG_DIR="records/h100_experiments/${EXP_NAME}/logs" - -cd /workspace/parameter-golf -mkdir -p "${LOG_DIR}" - -# --- Architecture --- -export NUM_LAYERS=11 -export UNIQUE_LAYERS=10 -export MODEL_DIM=512 -export NUM_HEADS=8 -export NUM_KV_HEADS=4 -export MLP_MULT=3.0 -export MLP_ACTIVATION=relu_sq -export VOCAB_SIZE=1024 -export TIE_EMBEDDINGS=1 -export LOGIT_SOFTCAP=30.0 -export BIGRAM_VOCAB_SIZE=10240 -export BIGRAM_DIM=128 - -# --- Training (8×H100 scale) --- -export ITERATIONS=20000 -export WARMUP_STEPS=20 -export WARMDOWN_ITERS=3500 -export TRAIN_BATCH_TOKENS=786432 -export TRAIN_SEQ_LEN=2048 -export MAX_WALLCLOCK_SECONDS=600 - -# --- Optimizer --- -export MATRIX_LR=0.025 -export SCALAR_LR=0.025 -export TIED_EMBED_LR=0.035 -export MUON_MOMENTUM=0.99 -export MUON_MOMENTUM_WARMUP_START=0.92 -export MUON_MOMENTUM_WARMUP_STEPS=1500 -export MOMENTUM_CYCLIC=1 -export MOMENTUM_MIN=0.85 -export MOMENTUM_MAX=0.95 -export MOMENTUM_CYCLE_PERIOD=50 -export GRAD_CLIP_NORM=0.3 -export WEIGHT_DECAY=0.04 - -# --- SWA --- -export SWA_ENABLED=1 
-export SWA_START_FRAC=0.2 -export SWA_EVERY=50 - -# --- Validation & Eval --- -export VAL_LOSS_EVERY=500 -export VAL_BATCH_SIZE=524288 -export TRAIN_LOG_EVERY=100 -export EVAL_STRIDE=64 -export EVAL_BATCH_SEQS=64 - -# --- AWQ --- -export AWQ_ENABLED=1 -export AWQ_ALPHA=0.5 - -# --- Run 3 seeds --- -for SEED in 42 43 44; do - export SEED - export RUN_ID="${EXP_NAME}_seed${SEED}" - echo "============================================" - echo "=== SEED ${SEED} ===" - echo "============================================" - torchrun --nproc_per_node=8 "${SCRIPT_DIR}/train_gpt.py" 2>&1 | tee "${LOG_DIR}/${EXP_NAME}_seed${SEED}.log" - echo "=== SEED ${SEED} COMPLETE ===" - echo "" -done - -echo "=== ALL 3 SEEDS COMPLETE ===" diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/setup_and_run.sh b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/setup_and_run.sh deleted file mode 100755 index 3a7fa0e4e1..0000000000 --- a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/setup_and_run.sh +++ /dev/null @@ -1,61 +0,0 @@ -#!/bin/bash -# ============================================================ -# FULL SETUP + RUN for 8×H100 submission -# Usage: Pass SSH host and port as arguments -# ./setup_and_run.sh -# ============================================================ -set -euo pipefail - -HOST="${1:?Usage: $0 }" -PORT="${2:?Usage: $0 }" -SSH="ssh -o StrictHostKeyChecking=no -p ${PORT} root@${HOST}" -SCP="scp -P ${PORT}" -LOCAL_BASE="/Users/sidhantthole/Documents/llama_index_exp/openaigolf/parameter-golf" -REMOTE_BASE="/workspace/parameter-golf" -EXP="submission_exp18_awq-cyclic-relusq-11Lshared_8xH100" - -echo "=== Step 1: Test connection ===" -$SSH "echo 'Connected!' 
&& nvidia-smi --query-gpu=name,memory.total --format=csv,noheader" - -echo "=== Step 2: Clone repo + install deps ===" -$SSH "cd /workspace && rm -rf parameter-golf && git clone https://github.com/openai/parameter-golf.git && pip install --break-system-packages -q zstandard" - -echo "=== Step 3: Start data download in background ===" -$SSH "cd ${REMOTE_BASE} && nohup python3 data/cached_challenge_fineweb.py --variant sp1024 > /tmp/data_download.out 2>&1 &" - -echo "=== Step 4: Copy experiment files while data downloads ===" -$SSH "mkdir -p ${REMOTE_BASE}/records/h100_experiments" -$SCP -r "${LOCAL_BASE}/records/h100_experiments/${EXP}" "root@${HOST}:${REMOTE_BASE}/records/h100_experiments/" -$SSH "chmod +x ${REMOTE_BASE}/records/h100_experiments/${EXP}/run.sh" - -echo "=== Step 5: Wait for data download ===" -while true; do - COUNT=$($SSH "ls ${REMOTE_BASE}/data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l" 2>/dev/null || echo "0") - echo " Data shards: ${COUNT}/80" - if [ "$COUNT" -ge 80 ]; then - break - fi - sleep 10 -done -echo "Data download complete!" - -echo "=== Step 6: Launch submission (3 seeds) ===" -$SSH "nohup bash ${REMOTE_BASE}/records/h100_experiments/${EXP}/run.sh > /tmp/submission.out 2>&1 &" -echo "Submission launched!" 
- -echo "=== Step 7: Syncing records every 10 seconds ===" -while true; do - rsync -az -e "ssh -p ${PORT}" "root@${HOST}:${REMOTE_BASE}/records/h100_experiments/${EXP}/" "${LOCAL_BASE}/records/h100_experiments/${EXP}/" 2>/dev/null - rsync -az -e "ssh -p ${PORT}" "root@${HOST}:${REMOTE_BASE}/logs/" "${LOCAL_BASE}/logs/" 2>/dev/null - - # Check if still running - RUNNING=$($SSH "ps aux | grep train_gpt | grep -v grep | wc -l" 2>/dev/null || echo "0") - if [ "$RUNNING" -eq 0 ]; then - # Final sync - rsync -az -e "ssh -p ${PORT}" "root@${HOST}:${REMOTE_BASE}/records/h100_experiments/${EXP}/" "${LOCAL_BASE}/records/h100_experiments/${EXP}/" 2>/dev/null - rsync -az -e "ssh -p ${PORT}" "root@${HOST}:${REMOTE_BASE}/logs/" "${LOCAL_BASE}/logs/" 2>/dev/null - echo "=== ALL SEEDS COMPLETE ===" - break - fi - sleep 10 -done diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/submission.json b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/submission.json deleted file mode 100644 index 08b8e425cc..0000000000 --- a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "SPThole", - "github_id": "SPThole", - "name": "AWQ + Cyclic Momentum + ReLU² + 11L Shared", - "blurb": "Activation-Aware Weight Quantization (AWQ) closes the int5/int6 quant gap from 0.027 to 0.010 bpb. Cyclic Muon momentum (0.85-0.95 triangle wave) escapes sharp minima. ReLU² for sparser MLPs. 11 layers with 10 unique weights. 
Emerged from 21+ systematic experiments on 1×H100/A40 before scaling to 8×H100.", - "date": "2026-03-25T00:00:00Z", - "val_loss": 1.94293025, - "val_bpb": 1.15071578, - "bytes_total": 15465308, - "bytes_code": 58616 -} diff --git a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/train_gpt.py b/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/train_gpt.py deleted file mode 100644 index 5645ab6700..0000000000 --- a/records/track_10min_16mb/2025-03-24_AWQ_CyclMom_11L_shared/train_gpt.py +++ /dev/null @@ -1,1340 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = 
int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - mlp_activation = os.environ.get("MLP_ACTIVATION", "relu_sq") - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - momentum_cyclic = 
bool(int(os.environ.get("MOMENTUM_CYCLIC", "1"))) - momentum_min = float(os.environ.get("MOMENTUM_MIN", 0.85)) - momentum_max = float(os.environ.get("MOMENTUM_MAX", 0.95)) - momentum_cycle_period = int(os.environ.get("MOMENTUM_CYCLE_PERIOD", 50)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 
0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if 
piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in 
range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) 
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
# Tensors at or below this element count are kept in float instead of int8.
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
# Near-max percentile used to clip outliers before int8 rounding.
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Return the raw storage size of *t* in bytes (numel * element size)."""
    return int(t.numel()) * int(t.element_size())

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Symmetric int8 quantization of a float tensor.

    2-D tensors get a per-row, percentile-clipped scale (scale stored fp16);
    all other shapes get one global fp32 scale. Returns ``(q_int8, scale)``.
    """
    t32 = t.float()
    if t32.ndim == 2:
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            # BUGFIX: was torch.empty(...), which returns *uninitialized*
            # memory and made the empty-tensor path nondeterministic.
            else torch.zeros((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        # clamp_min keeps the scale representable even for all-zero rows.
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale


def _classify_param(name: str) -> str:
    """Bucket a state-dict key into embed / mlp / bigram / attn / other."""
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    if "bigram" in name:
        return "bigram"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"

def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
    """Per-row symmetric quantization into [-(clip_range+1), clip_range].

    ``clip_range=31`` yields int6 codes, ``clip_range=15`` int5. Scales are
    stored in fp16; non-2D tensors fall back to a single global scale.
    """
    t32 = t.float()
    if t32.ndim == 2:
        row_max = t32.abs().amax(dim=1)
        scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16)
        # Guard against fp16 underflow to exactly zero for all-zero rows.
        scale = scale.clamp_min(torch.finfo(torch.float16).tiny)
        q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range + 1), clip_range).to(torch.int8)
        return q, scale
    amax = t32.abs().max().item()
    scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range + 1), clip_range).to(torch.int8)
    return q, scale

def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
    """Quantize a state dict with mixed precision.

    Categories in *int6_cats* are quantized per row (int5 for MLP weights,
    int6 otherwise); everything else large enough gets int8. Small tensors,
    control tensors, and explicitly-kept fp16 tensors pass through unquantized.
    Returns ``(result, meta)`` where *meta* records how each key was encoded.
    """
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = _classify_param(name)
        if not t.is_floating_point() or t.numel() <= 8192:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue
        if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS):
            result[name] = t.to(dtype=torch.float16).contiguous()
            meta[name] = "passthrough_fp16"
            continue
        if cat in int6_cats and t.ndim >= 1:
            clip = 15 if cat == "mlp" else 31  # int5 for MLP, int6 for attention
            q, s = quantize_intN_per_row(t, clip_range=clip)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"}
        else:
            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}
    return result, meta

def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """Invert :func:`mixed_quantize_int6` using *template_sd* for dtypes/keys."""
    out: dict[str, Tensor] = {}
    for name, orig in template_sd.items():
        info = meta[name]
        orig_dtype = orig.dtype
        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
            t = result[name]
            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
                t = t.to(orig_dtype)
            out[name] = t
            continue
        q, s = result[name + ".q"], result[name + ".scale"]
        if s.ndim > 0:
            # Per-row scale: broadcast over all trailing dimensions.
            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
        else:
            out[name] = (q.float() * float(s.item())).to(orig_dtype)
    return out
name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# 
----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - 
def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def 
class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's embedding.

    The learned per-channel gate starts at zero, i.e. an initial 50/50 blend
    (sigmoid(0) = 0.5). Position 0 blends with zeros.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        shifted = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * shifted


class BigramHashEmbedding(nn.Module):
    """Hash consecutive token pairs into a learned embedding table.

    Both the table and the optional projection start at zero, so the module
    initially contributes nothing to the residual stream.
    """

    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
        super().__init__()
        self.bigram_vocab_size = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
        nn.init.zeros_(self.embed.weight)
        # Projection only needed when the table width differs from model_dim.
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))

    def bigram_hash(self, tokens: Tensor) -> Tensor:
        """Map (prev, cur) token pairs into [0, vocab-1); slot vocab-1 marks position 0."""
        t = tokens.to(torch.int32)
        mod = self.bigram_vocab_size - 1
        out = torch.empty_like(t)
        out[..., 0] = mod
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        return out.long()

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(self.bigram_hash(token_ids))
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class Block(nn.Module):
    """Pre-norm transformer block with learned residual mixing and branch scales."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # Row 0 weights the running residual x, row 1 weights the input x0.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    """U-net-style GPT: encoder half feeds skip connections into the decoder half,
    with optional physical-layer sharing across virtual layers."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: float,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        bigram_vocab_size: int = 0,
        bigram_dim: int = 128,
        unique_layers: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.smear = SmearGate(model_dim)
        # Layer sharing: create unique_layers physical blocks and map
        # num_layers virtual layers onto them.
        n_unique = unique_layers if unique_layers > 0 else num_layers
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
                for _ in range(n_unique)
            ]
        )
        # Virtual layer index -> physical block index; extra virtual layers
        # all share the last physical block.
        self._layer_map = list(range(min(n_unique, num_layers)))
        while len(self._layer_map) < num_layers:
            self._layer_map.append(n_unique - 1)
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        num_layers = len(self._layer_map)  # virtual depth, not physical blocks
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                # Scale down residual-output projections by sqrt(2 * depth).
                if ".proj." in name or name.endswith(".proj"):
                    with torch.no_grad():
                        module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def _embed(self, input_ids: Tensor) -> Tensor:
        """Token (+ bigram) embedding, normalized and smeared."""
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        return self.smear(x)

    def _backbone(self, x: Tensor) -> Tensor:
        """Run the encoder/decoder stack with U-net skip connections."""
        x0 = x
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            x = self.blocks[self._layer_map[i]](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips:
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0)
        return x

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        x = self._backbone(self._embed(input_ids))
        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh soft-capping keeps logits in (-logit_softcap, logit_softcap).
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Return soft-capped logits of shape (batch, seq, vocab)."""
        x = self.final_norm(self._backbone(self._embed(input_ids)))
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)


def eval_val_sliding(
    args: "Hyperparameters",
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    stride: int,
    batch_seqs: int = 32,
) -> tuple[float, float]:
    """Sliding-window validation: score only the last *stride* tokens of each
    overlapping window (all tokens of the first), sharding windows across ranks.
    Returns ``(mean_nll, bits_per_byte)``.
    """
    seq_len = args.train_seq_len
    total_tokens = val_tokens.numel() - 1
    # Keep windows that contribute at least `stride` scored tokens (plus window 0).
    window_starts = [
        ws for ws in range(0, total_tokens, stride)
        if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0
    ]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.forward_logits(x_batch)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score only the fresh suffix of each window (whole first window).
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()
            if rank == 0 and (bi // batch_seqs) % 50 == 0:
                done = min(bi + batch_seqs, len(my_windows))
                pct = done / len(my_windows) * 100
                running_bpb = 0.0
                if token_count.item() > 0:
                    rl = (loss_sum / token_count).item()
                    running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item())
                print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True)

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    base_model.train()
    return val_loss, bits_per_token * tokens_per_byte
# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    """Entry point: distributed setup, warmup, timed training loop, SWA,
    magnitude pruning, AWQ scaling, int5/int6 quantized export, and a final
    round-trip evaluation on the dequantized weights."""
    global zeropower_via_newtonschulz5

    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # ---- Distributed / device setup ----
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    # Global batch is fixed for 8 GPUs; fewer ranks accumulate more micro-steps.
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)

    # ---- Logging (rank 0 only) ----
    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        # Append to the logfile on rank 0; optionally echo to stdout.
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)

    # ---- Seeding ----
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # ---- Tokenizer + data ----
    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device
    )
    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")

    # MODEL + OPTIMIZER SETUP
    base_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        bigram_vocab_size=args.bigram_vocab_size,
        bigram_dim=args.bigram_dim,
        unique_layers=args.unique_layers,
    ).to(device).bfloat16()
    # Linears stay fp32 (cast per-forward); small control tensors return to fp32.
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # Split block parameters: 2-D weights go to Muon, the rest to AdamW.
    block_named_params = list(base_model.blocks.named_parameters())
    matrix_params = [
        p for name, p in block_named_params
        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    scalar_params = [
        p for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    scalar_params.append(base_model.smear.gate)
    if base_model.bigram is not None:
        scalar_params.append(base_model.bigram.scale)

    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
    if base_model.bigram is not None:
        tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr})
        if base_model.bigram.proj is not None:
            matrix_params.append(base_model.bigram.proj.weight)

    optimizer_tok = torch.optim.AdamW(
        tok_params,
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        weight_decay=args.weight_decay,
        fused=True,
    )
    optimizer_muon = Muon(
        matrix_params,
        lr=args.matrix_lr,
        momentum=args.muon_momentum,
        backend_steps=args.muon_backend_steps,
        weight_decay=0.04,
    )
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.AdamW(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        weight_decay=args.weight_decay,
        fused=True,
    )
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if base_model.lm_head is not None:
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2),
            eps=args.adam_eps,
            fused=True,
        )
        optimizers.insert(1, optimizer_head)

    n_params = sum(p.numel() for p in base_model.parameters())
    log0(f"model_params:{n_params}")
    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(
        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
    )
    log0(
        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log0(f"seed:{args.seed}")

    # DATA LOADER & MODEL WARMUP
    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step: int, elapsed_ms: float) -> float:
        # LR multiplier: 1.0 during training, linear warmdown near the end
        # (either by step count or by projected wallclock budget).
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    # Warmup: run a few optimizer steps to trigger compilation, then restore
    # the untouched model/optimizer state so timing starts clean.
    if args.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                if distributed:
                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        if distributed:
            model.require_backward_grad_sync = True
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    # MAIN TRAINING LOOP
    training_time_ms = 0.0
    stop_after_step: int | None = None
    swa_state: dict[str, Tensor] | None = None
    swa_count = 0
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)

        # Periodic validation (timing clock is paused around it).
        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args, model, rank, world_size, device, grad_accum_steps,
                val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break

        # ---- One training step (with gradient accumulation) ----
        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)
        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            if distributed:
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps

        # Muon momentum schedule: linear warmup, then optional cyclic triangle wave.
        if step < args.muon_momentum_warmup_steps:
            frac = step / args.muon_momentum_warmup_steps if args.muon_momentum_warmup_steps > 0 else 1.0
            muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        elif args.momentum_cyclic:
            # Triangle wave between momentum_min and momentum_max.
            period = args.momentum_cycle_period * 2
            pos = (step % period) / period
            if pos < 0.5:
                muon_momentum = args.momentum_min + (args.momentum_max - args.momentum_min) * (pos * 2)
            else:
                muon_momentum = args.momentum_max - (args.momentum_max - args.momentum_min) * ((pos - 0.5) * 2)
        else:
            muon_momentum = args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale

        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        zero_grad_all()

        step += 1
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)

        # SWA: accumulate checkpoints once the LR warmdown has begun.
        if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0:
            if swa_state is None:
                swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
                swa_count = 1
                log0(f"swa:start step:{step}")
            else:
                for name, t in base_model.state_dict().items():
                    swa_state[name] += t.detach().cpu()
                swa_count += 1

        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
            )

        # Wallclock cap: all ranks must agree before scheduling the stop.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step

    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )

    # Apply SWA if checkpoints were collected.
    if args.swa_enabled and swa_state is not None and swa_count > 1:
        log0(f"swa:applying averaged {swa_count} checkpoints")
        current_state = base_model.state_dict()
        avg_state = {
            name: (tensor / swa_count).to(dtype=current_state[name].dtype)
            for name, tensor in swa_state.items()
        }
        base_model.load_state_dict(avg_state, strict=True)

    # SERIALIZATION + ROUNDTRIP VALIDATION
    if master_process:
        torch.save(base_model.state_dict(), "final_model.pt")
        model_bytes = os.path.getsize("final_model.pt")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model: {model_bytes} bytes")
        log0(f"Code size: {code_bytes} bytes")
        log0(f"Total submission size: {model_bytes + code_bytes} bytes")

    # Magnitude pruning: zero the smallest 3% of large-matrix weights to
    # improve downstream compression.
    with torch.no_grad():
        for name, param in base_model.named_parameters():
            if param.ndim == 2 and param.numel() > 65536:
                threshold = torch.quantile(param.abs().float().flatten(), 0.03)
                mask = param.abs() < threshold
                param.masked_fill_(mask, 0.0)

    # AWQ: activation-aware weight scaling before quantization.
    awq_alpha = float(os.environ.get("AWQ_ALPHA", 0.5))
    awq_enabled = bool(int(os.environ.get("AWQ_ENABLED", "1")))
    if awq_enabled:
        log0(f"awq:calibrating alpha={awq_alpha}")
        # Step 1: collect per-input-channel activation magnitudes via hooks.
        act_stats: dict[str, Tensor] = {}
        hooks = []

        def make_hook(name):
            def hook_fn(module, input, output):
                x = input[0] if isinstance(input, tuple) else input
                if x.ndim >= 2:
                    # Mean absolute activation per channel (last dim).
                    s = x.detach().float().abs().mean(dim=tuple(range(x.ndim - 1)))
                    if name in act_stats:
                        act_stats[name] = act_stats[name] + s
                    else:
                        act_stats[name] = s
            return hook_fn

        for name, module in base_model.named_modules():
            if isinstance(module, (nn.Linear, CastedLinear)):
                hooks.append(module.register_forward_hook(make_hook(name)))

        # Step 2: run calibration batches over validation data.
        base_model.eval()
        n_calib_batches = 4
        calib_seq_len = args.train_seq_len
        calib_batch_seqs = 16
        calib_tokens_per_batch = calib_batch_seqs * calib_seq_len
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            for cb in range(n_calib_batches):
                start = cb * calib_tokens_per_batch
                end = start + calib_tokens_per_batch + 1
                if end > val_tokens.numel():
                    break
                local = val_tokens[start:end].to(device=device, dtype=torch.int64)
                x = local[:-1].reshape(-1, calib_seq_len)
                base_model.forward_logits(x)

        for h in hooks:
            h.remove()

        # Normalize accumulated stats. NOTE(review): divides by n_calib_batches
        # even if the loop broke early -- presumably acceptable bias; confirm.
        for name in act_stats:
            act_stats[name] = act_stats[name] / n_calib_batches

        # Step 3: scale weight columns (input channels) by s^alpha; the inverse
        # scales are stored in the state dict and undone after dequantization.
        awq_scales = {}
        with torch.no_grad():
            for name, module in base_model.named_modules():
                if isinstance(module, (nn.Linear, CastedLinear)) and name in act_stats:
                    s = act_stats[name].to(module.weight.device)
                    s = s.clamp_min(1e-6)
                    scale = s.pow(awq_alpha)
                    if module.weight.shape[1] == scale.shape[0]:
                        module.weight.data = module.weight.data * scale.unsqueeze(0)
                        awq_scales[name] = scale.cpu()

        log0(f"awq:scaled {len(awq_scales)} layers")

    # INT6 mixed quantization + zstd/zlib export.
    sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()}
    if awq_enabled and awq_scales:
        # Ship the inverse AWQ scales so inference can undo the column scaling.
        for lname, scale in awq_scales.items():
            sd_cpu[f"_awq_inv_scale.{lname}"] = (1.0 / scale).to(torch.float16)
    quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"})
    quant_buf = io.BytesIO()
    torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
    quant_raw = quant_buf.getvalue()
    if _COMPRESSOR == "zstd":
        quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw)
    else:
        quant_blob = zlib.compress(quant_raw, 9)
    if master_process:
        with open("final_model.int8.ptz", "wb") as f:
            f.write(quant_blob)
        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes")
        log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes")

    # Round-trip: reload the compressed artifact and evaluate the dequantized model.
    if distributed:
        dist.barrier()
    with open("final_model.int8.ptz", "rb") as f:
        quant_blob_disk = f.read()
    if _COMPRESSOR == "zstd":
        decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk)
    else:
        decompressed = zlib.decompress(quant_blob_disk)
    quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
    deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu)
    if awq_enabled:
        # Undo AWQ column scaling: W_orig = W_scaled * inv_scale (per column).
        awq_inv_keys = [k for k in deq_state if k.startswith("_awq_inv_scale.")]
        for inv_key in awq_inv_keys:
            inv_scale = deq_state.pop(inv_key).float()
            layer_name = inv_key.replace("_awq_inv_scale.", "")
            weight_key = layer_name + ".weight"
            if weight_key in deq_state:
                deq_state[weight_key] = deq_state[weight_key].float() * inv_scale.unsqueeze(0)
                deq_state[weight_key] = deq_state[weight_key].to(dtype=sd_cpu[weight_key].dtype)
        log0(f"awq:unscaled {len(awq_inv_keys)} layers after dequant")
        # Drop any remaining AWQ bookkeeping keys before the strict load.
        deq_state = {k: v for k, v in deq_state.items() if not k.startswith("_awq_")}
    base_model.load_state_dict(deq_state, strict=True)

    # Final eval on the int6-roundtripped weights.
    torch.cuda.synchronize()
    t_qeval = time.perf_counter()
    if args.eval_stride > 0 and args.eval_stride < args.train_seq_len:
        log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}")
        q_val_loss, q_val_bpb = eval_val_sliding(
            args, base_model, rank, world_size, device,
            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
            stride=args.eval_stride, batch_seqs=args.eval_batch_seqs,
        )
    else:
        log0("final_eval_mode:standard")
        q_val_loss, q_val_bpb = eval_val(
            args, model, rank, world_size, device, grad_accum_steps,
            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
        )
    torch.cuda.synchronize()
    log0(
        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
    )
    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")

    if distributed:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
# fixes applied
# tuned
mode 100644 records/track_non_record_16mb/parameter-golf-experimentations-summary/STRUCTURED_EXPSUM.md create mode 100644 records/track_non_record_16mb/parameter-golf-experimentations-summary/index.html diff --git a/records/track_non_record_16mb/parameter-golf-experimentations-summary/STRUCTURED_EXPSUM.md b/records/track_non_record_16mb/parameter-golf-experimentations-summary/STRUCTURED_EXPSUM.md new file mode 100644 index 0000000000..a87a65f8e4 --- /dev/null +++ b/records/track_non_record_16mb/parameter-golf-experimentations-summary/STRUCTURED_EXPSUM.md @@ -0,0 +1,1750 @@ +# Structured Experiment Summary + +> **Competition**: OpenAI Parameter Golf +> **Objective**: Minimize validation loss (bits-per-byte, bpb) under a 16MB artifact constraint within 10-minute training on 8×H100 +> **Total experiments**: 119+ +> **Date range**: Early 2026 — 2026-04-13 +> **Best result**: **1.0744 legal_ttt bpb** (ImprovedParallelResiduals, community PR #1523, 8×H100) + +--- + +# Table of Contents + +1. [Global Leaderboard](#1-global-leaderboard) +2. [Phase 3a: Baseline Experiments (exp00–exp18)](#2-phase-3a) +3. [Phase 3b-Part1: Systematic Ablations (exp27b–exp33b)](#3-phase-3b-part1) +4. [Phase 3b-Part2: LR Fix Era (exp34b–exp48b)](#4-phase-3b-part2) +5. [Phase 3b-Part3: Simplification + XSA (exp53b–clean_54b)](#5-phase-3b-part3) +6. [Phase 3.5: 8×H100 Simulation (exp60–exp80)](#6-phase-35) +7. [Phase 3.6: Diagnostic-Driven Era (exp83–exp87)](#7-phase-36) +8. [Phase 3b-Muon: Parallel Muon Optimizer (exp70_parallel_muon–exp91)](#8-phase-3b-muon) +9. [Phase 3c: Architecture Rewrite + Meta-TTT (exp92–exp109)](#9-phase-3c) +10. [Phase 3c-Community: Community SOTA (SP8192+)](#10-community-sota) +11. [Phase 3c-Frontier: Pushing Past Community (exp110–exp119)](#11-frontier) +12. [Misc: Co-occurrence QK Init](#12-misc) +13. [Known Issues](#13-known-issues) +14. [Key Learnings by Phase](#14-key-learnings) +15. [TLDR: Top 20 Learnings](#15-top-20) +16. 
[Appendix: Full Lineage Trees](#16-appendix) + +--- + +# 1. Global Leaderboard + +## 1.1 All-Time Best (by legal_ttt bpb) + +| Rank | Experiment | Date | legal_ttt | val_bpb | int6_bpb | Artifact | Hardware | Source | +|:----:|-----------|:----:|:---------:|:-------:|:--------:|:--------:|:--------:|:------:| +| **1** | ImprovedParallelResiduals | 2026-04-11 | **1.0744** | — | — | 15.96 MB | 8×H100 | Community PR #1523 | +| 2 | WiderEmb_TapInV6_TTT | 2026-04-10 | 1.0788 | 1.0813 | 1.0980 | ~16 MB | 8×H100 | Community | +| 3 | SP8192_3LayerRecur | 2026-04-09 | 1.0808 | 1.0873 | 1.0997 | ~16 MB | 8×H100 | Community | +| 4 | exp101 | 2026-04 | 1.11588 | 1.1352 | 1.13930 | ~16 MB | 8×H100 | Our work | +| 5 | exp95 | 2026-03 | 1.1169 | 1.1363 | — | ~16 MB | 8×H100 | Our work | +| 6 | exp74 | 2026-03 | — | 1.1539 | 1.1685 | 15.86 MB | 1×H100 (sim) | Our work | +| 7 | exp54b | 2026-03 | — | 1.2642 | 1.2708 | 15.54 MB | 1×H100 | Our work | + +## 1.2 Milestone Timeline + +| Date | Best legal_ttt | Experiment | Key Innovation | +|------|:--------------:|-----------|----------------| +| Early | 1.3389 (quant) | exp00 baseline | Starting point | +| Early | 1.3145 (quant) | exp09 | Step count + loss masking | +| Phase 3b | 1.2708 (quant) | exp54b | LR fix + simplification | +| 2026-03 | 1.1456 (sliding) | exp74 | Partial RoPE + diagnostics | +| 2026-03 | 1.1169 | exp95 | Meta-TTT + size optimization | +| 2026-04-04 | 1.11588 | exp101 | Position-conditional bigram | +| 2026-04-09 | 1.0808 | SP8192_3LayerRecur (community) | SP8192 + depth recurrence | +| 2026-04-10 | 1.0788 | WiderEmb_TapInV6 (community) | Wider loop + Tap-In V6 | +| **2026-04-11** | **1.0744** | **ImprovedParallelResiduals (community, PR #1523)** | **Cross-lane parallel residuals** | + +--- + +# 2. 
Phase 3a: Baseline Experiments (exp00–exp18) + +> **Hardware**: 1×H100 (or A100), 600s wallclock +> **Base**: exp27 (modded-nanogpt reference) +> **Best result**: exp09/exp13 — quant bpb **1.3145**, artifact 14.5 MB + +## 2.1 Config Constants + +| Parameter | Value | +|-----------|-------| +| Model dim | 512 | +| Num layers | 11 (10 unique, layer sharing) | +| Attention | GQA, 8 Q-heads, 4 KV-heads, head_dim=64 | +| MLP | LeakyReLU², mlp_mult=3.0 (hidden=1536) | +| Vocab size | 1024 (SentencePiece BPE) | +| Seq length | 2048 | +| Softcap | 30 | +| Optimizer | Muon (matrices) + Adam (scalars) | +| Momentum | Cyclic 0.85–0.95, period=50 | +| Grad accum | 2 | +| SWA | Start at 20% training, every 100 steps | +| AWQ alpha | 0.6 | +| Quantization | int6 + zstd | +| Wallclock | 600s (10 min) | +| Total params | ~25.5M | + +## 2.2 Leaderboard + +| Rank | Exp | Name | Quant BPB | Raw BPB | Artifact | Under 16MB? | +|:----:|:---:|------|:---------:|:-------:|:--------:|:-----------:| +| **1** | **54b** | xsa-zstd-ckfix | **1.2708** | 1.2642 | 15.54 MB | **Yes** | +| 2 | 53b | lean-combo (v5) | 1.2720 | 1.2640 | 15.19 MB | Yes | +| 3 | community | SOTA leaderboard (1xH100) | 1.2825 | 1.2501 | 13.06 MB | Yes | +| 4 | 48b | 10blocks-depth | 1.2930 | 1.2870 | 14.59 MB | Yes | +| 5 | 42b | revive-block9 | 1.2969 | 1.2867 | 14.01 MB | Yes | +| 6 | 34b | lr-schedule-fix | 1.2990 | 1.2891 | 15.13 MB | Yes | +| 7 | 39b | swa-tuning | 1.2942 | 1.2875 | 14.55 MB | Yes | +| 8 | 30b | combo | 1.3156 | 1.2983 | 15.05 MB | Yes | +| 9 | 29b | lossweight-typemb | 1.3176 | 1.3007 | 15.75 MB | Yes | +| 10 | 27b | resid-norm | 1.3197 | 1.3000 | 15.30 MB | Yes | +| 11 | 09 | padignore-wordboost | 1.3145 | 1.2974 | 14.5 MB | Yes | +| 11 | 13 | multihead-gate-bigram | 1.3145 | 1.2974 | 14.5 MB | Yes | +| 13 | 10 | trigram-unigram | 1.3151 | 1.2956 | 15.6 MB | Yes | +| 14 | 06 | swa-awq-accum2 | 1.3161 | 1.2982 | 15.7 MB | Yes | +| 15 | 07 | tighter-swa-awq | 1.3164 | 1.2978 | 15.5 MB | Yes | +| 
16 | 05 | grad-accum4 | 1.3181 | 1.3001 | 15.8 MB | Yes | +| 17 | 12 | trigram64-awq06 | 1.3222 | 1.2969 | 15.1 MB | Yes | +| 18 | 08 | ctx-freq-bias | 1.3225 | 1.3014 | 15.0 MB | Yes | +| 19 | 18 | separate-trigram64 | 1.3247 | 1.2995 | 15.0 MB | Yes | +| 20 | 11 | trigram-slim-awq07 | 1.3259 | 1.2994 | 14.6 MB | Yes | +| 21 | 15 | engram-3order | 1.3260 | 1.2995 | 14.6 MB | Yes | +| 22 | 14 | engram-multiorder | 1.3338 | 1.3056 | 15.0 MB | Yes | +| 23 | 00 | baseline-rerun | 1.3389 | 1.3166 | 14.7 MB | Yes | +| 24 | 02 | speed-bigramfp16-awq | 1.3429 | 1.3200 | 17.3 MB | **No** | +| 25 | 04 | no-cyclic-momentum | 1.3489 | 1.3230 | 15.8 MB | Yes | +| 26 | 16 | jepa-aux | 1.3526 | 1.3194 | 14.5 MB | Yes | +| 27 | 17 | byte-engram | 1.3527 | 1.3201 | 14.6 MB | Yes | +| 28 | 01d | xsa-only | 1.3568 | 1.3294 | 14.8 MB | Yes | +| 29 | 03 | qat-ste | 1.3684 | 1.3365 | 15.7 MB | Yes | +| — | 01b | ln-scale-only | — | — | — | Not run | +| — | 01c | ema-only | — | — | — | Not run | + +## 2.3 Detailed Experiment Cards + +### Exp00 — Baseline Rerun (exp27) + +| Field | Value | +|-------|-------| +| **Folder** | `exp00_baseline-rerun_exp27/` | +| **Based on** | exp27 (modded-nanogpt reference) | +| **Status** | Done | +| **Steps** | 1,084 | +| **Raw BPB** | 1.3166 | +| **Quant BPB** | 1.3389 | +| **Artifact** | 14.7 MB | + +**Details**: Control experiment. Runs the exp27 reference config on A100 with 600s wallclock, grad_accum=8, GQA 8Q/4KV heads, 25.5M params, int6+zstd quantization. + +**Observations**: ~553ms/step. Quant gap = 1.7%. Bigram.proj has worst quantization error. Still improving at end — more steps would help. + +--- + +### Exp01b — LN Scale Only (ablation) + +| Field | Value | +|-------|-------| +| **Folder** | `exp01b_ln-scale-only_from-exp27/` | +| **Based on** | exp27 | +| **Status** | Not run (empty log) | + +**Details**: Isolated test of `1/√(layer+1)` layer-norm damping without partial RoPE. 
+ +--- + +### Exp01c — EMA Only (ablation) + +| Field | Value | +|-------|-------| +| **Folder** | `exp01c_ema-only_from-exp27/` | +| **Based on** | exp27 | +| **Status** | Not run (empty log) | + +**Details**: Isolated test of EMA weight averaging without other changes. + +--- + +### Exp01d — XSA Only (ablation) + +| Field | Value | +|-------|-------| +| **Folder** | `exp01d_xsa-only_from-exp27/` | +| **Based on** | exp27 | +| **Status** | Done | +| **Steps** | 1,017 | +| **Raw BPB** | 1.3294 | +| **Quant BPB** | 1.3568 | +| **Artifact** | 14.8 MB | + +**Details**: Cross-sequence attention on last 4 layers, keeping SWA and full RoPE. Slower per step (~590ms vs 553ms baseline), fewer steps completed. + +**Observations**: XSA alone hurts — slower steps + no benefit = regression. The speed cost outweighs any representational gain. + +--- + +### Exp02 — Speed + Bigram FP16 + Better AWQ + +| Field | Value | +|-------|-------| +| **Folder** | `exp02_speed-bigramfp16-awq_from-exp00/` | +| **Based on** | exp00 | +| **Status** | Done | +| **Steps** | 1,080 | +| **Raw BPB** | 1.3200 | +| **Quant BPB** | 1.3429 | +| **Artifact** | 17.3 MB | + +**Details**: Three changes combined: (1) muon_backend_steps=4 (was 5), val_loss_every=200 for speed; (2) bigram.proj kept in FP16 instead of quantized; (3) per-category AWQ alphas (bigram=0.75, attn=0.6, mlp=0.5) with 16 calibration batches. + +**Observations**: FP16 bigram blows artifact to 17.3MB — over budget. Per-category AWQ alpha interesting but gains washed out. + +--- + +### Exp03 — QAT-STE (Quantization-Aware Training) + +| Field | Value | +|-------|-------| +| **Folder** | `exp03_qat-ste_from-exp00/` | +| **Based on** | exp00 | +| **Status** | Done | +| **Steps** | 1,035 | +| **Raw BPB** | 1.3365 | +| **Quant BPB** | 1.3684 | +| **Artifact** | 15.7 MB | + +**Details**: Fake-quantize weights during warmdown phase using Straight-Through Estimator. + +**Observations**: Worst result of the batch. 
**Observations**: **Major breakthrough** — first sub-1.32 quant bpb. 122 more steps than baseline.
Sweet spot is SWA_EVERY=100. + +--- + +### Exp08 — Context Frequency Bias + +| Field | Value | +|-------|-------| +| **Folder** | `exp08_ctx-freq-bias_from-exp05/` | +| **Based on** | exp05 | +| **Status** | Done | +| **Steps** | 1,196 | +| **Raw BPB** | 1.3014 | +| **Quant BPB** | 1.3225 | +| **Artifact** | 15.0 MB | + +**Details**: Learned scalar bias `ctx_freq_bias * log(1 + count_in_context)` on logits. Exploits 77.7% token burstiness. +1 parameter. + +**Observations**: Small improvement over exp05 but redundant with what attention learns. Smallest artifact at 15.0MB. + +--- + +### Exp09 — Pad Ignore + Word-Start Boost ⭐ BEST + +| Field | Value | +|-------|-------| +| **Folder** | `exp09_padignore-wordboost_from-exp06/` | +| **Based on** | exp06 | +| **Status** | Done | +| **Steps** | 1,203 | +| **Raw BPB** | 1.2974 | +| **Quant BPB** | **1.3145** | +| **Artifact** | **14.5 MB** | + +**Details**: (1) `ignore_index=0` — skip pad tokens that waste gradient; (2) learned `word_start_boost` scalar that scales bigram at word boundaries (`▁` tokens). +1 parameter. + +**Observations**: **Best result.** Pad-ignore removes ~5-10% wasted compute. Word-start boost helps bigram focus on hardest prediction (word-initial at 5.1 bpb vs 2.9 bpb for repeats). + +--- + +### Exp10 — Trigram + Unigram Bias + +| Field | Value | +|-------|-------| +| **Folder** | `exp10_trigram-unigram_from-exp09/` | +| **Based on** | exp09 | +| **Status** | Done | +| **Steps** | 1,199 | +| **Raw BPB** | 1.2956 | +| **Quant BPB** | 1.3151 | +| **Artifact** | 15.6 MB | + +**Details**: Trigram hash table (10240×128) + learned `unigram_bias * log(freq)`. ~1.3M extra params. + +**Observations**: **Best raw bpb** (1.2956) but quant bpb slightly worse than exp09. Extra params compress less efficiently. 
+ +--- + +### Exp11 — Trigram Slim (dim=48) + AWQ 0.7 + +| Field | Value | +|-------|-------| +| **Folder** | `exp11_trigram-slim-awq07_from-exp10/` | +| **Based on** | exp10 | +| **Status** | Done | +| **Steps** | 1,198 | +| **Raw BPB** | 1.2994 | +| **Quant BPB** | 1.3259 | +| **Artifact** | 14.6 MB | + +**Details**: Trigram embed dim 128→48 with `tri_proj` (48→512). AWQ alpha 0.6→0.7. + +**Observations**: Smaller artifact but significantly worse quant bpb. dim=48 too small, AWQ 0.7 too aggressive. Double regression. + +--- + +### Exp12 — Trigram 64-dim + AWQ 0.6 + +| Field | Value | +|-------|-------| +| **Folder** | `exp12_trigram64-awq06_from-exp10/` | +| **Based on** | exp10 | +| **Status** | Done | +| **Steps** | 1,195 | +| **Raw BPB** | 1.2969 | +| **Quant BPB** | 1.3222 | +| **Artifact** | 15.1 MB | + +**Details**: Middle-ground trigram dim=64, AWQ alpha=0.6. + +**Observations**: Better than exp11 but worse than exp09. Hash collisions in 10240-entry table too frequent for trigrams (1024³ combinations). + +--- + +### Exp13 — Multi-Head Gated Bigram ⭐ TIED BEST + +| Field | Value | +|-------|-------| +| **Folder** | `exp13_multihead-gate-bigram_from-exp09/` | +| **Based on** | exp09 | +| **Status** | Done | +| **Steps** | 1,201 | +| **Raw BPB** | 1.2974 | +| **Quant BPB** | **1.3145** | +| **Artifact** | **14.5 MB** | + +**Details**: K=2 independent hash functions (averaged, reduces collisions). Context gate: `sigmoid(gate_proj(tok_emb) + gate_bias)`. +513 extra params. + +**Observations**: **Tied with exp09 for best quant bpb**. Multi-head reduces collision 10%→1%, but impact neutralized. Base for exp14-26. 
+ +--- + +### Exp14 — Engram Multi-Order (1-5gram) + +| Field | Value | +|-------|-------| +| **Folder** | `exp14_engram-multiorder_from-exp13/` | +| **Based on** | exp13 | +| **Status** | Done | +| **Steps** | 1,196 | +| **Raw BPB** | 1.3056 | +| **Quant BPB** | 1.3338 | +| **Artifact** | 15.0 MB | + +**Details**: 1-5gram with 2 hash heads each = 10 lookups/position. Shared 10240×128 table. 0 new params. + +**Observations**: **Significant regression**. Shared embeddings across n-gram orders = destructive interference. + +--- + +### Exp15 — Engram 3-Order (Orthogonal Subspaces) + +| Field | Value | +|-------|-------| +| **Folder** | `exp15_engram-3order_from-exp14/` | +| **Based on** | exp14 | +| **Status** | Done | +| **Steps** | 1,196 | +| **Raw BPB** | 1.2995 | +| **Quant BPB** | 1.3260 | +| **Artifact** | 14.6 MB | + +**Details**: 1-3gram × 2 heads = 6 lookups. Orthogonal subspace: unigram [0:42], bigram [42:84], trigram [84:128]. + +**Observations**: Better than exp14 but worse than exp09. Each subspace too small (~42 dims). + +--- + +### Exp16 — JEPA Auxiliary Loss + +| Field | Value | +|-------|-------| +| **Folder** | `exp16_jepa-aux_from-exp15/` | +| **Based on** | exp15 | +| **Status** | Done | +| **Steps** | 1,146 | +| **Raw BPB** | 1.3194 | +| **Quant BPB** | 1.3526 | +| **Artifact** | 14.5 MB | + +**Details**: Predictor MLP (512→128) predicts next position's engram embedding. MSE loss λ=0.1. ~65K extra training params. + +**Observations**: **Major regression**. Not true JEPA — uses fixed hash targets. MSE against hash embeddings provides adversarial gradient. ~524ms/step (slower). + +--- + +### Exp17 — Byte-Level Engram + +| Field | Value | +|-------|-------| +| **Folder** | `exp17_byte-engram_from-exp16/` | +| **Based on** | exp16 | +| **Status** | Done | +| **Steps** | 1,146 | +| **Raw BPB** | 1.3201 | +| **Quant BPB** | 1.3527 | +| **Artifact** | 14.6 MB | + +**Details**: ByteBoundaryEmbedding — cross-token byte bigram/trigram from 4-byte window. 
~49K extra params. + +**Observations**: No improvement over exp16. Base (exp15+JEPA) too weak. + +--- + +### Exp18 — Separate Trigram Table (64-dim) + +| Field | Value | +|-------|-------| +| **Folder** | `exp18_separate-trigram64_from-exp13/` | +| **Based on** | exp13 | +| **Status** | Done | +| **Steps** | 1,194 | +| **Raw BPB** | 1.2995 | +| **Quant BPB** | 1.3247 | +| **Artifact** | 15.0 MB | + +**Details**: Separate 64-dim trigram table (10240×64) + 2-head hashing + projection, on top of exp13's 128-dim bigram. ~688K extra params. + +**Observations**: Marginal raw improvement but worse quant bpb (1.3247 vs 1.3145). Extra params don't compress well. + +## 2.4 Evolution Tree + +``` +exp27 (reference) + ├── exp00 (baseline rerun) ──────────────────────── quant 1.3389 + │ ├── exp01b (LN scale ablation) not run + │ ├── exp01c (EMA ablation) not run + │ ├── exp01d (XSA ablation) quant 1.3568 ✗ worse + │ ├── exp02 (speed+FP16 bigram) quant 1.3429 ✗ over 16MB + │ ├── exp03 (QAT-STE) quant 1.3684 ✗ worst + │ ├── exp04 (no cyclic momentum) quant 1.3489 ✗ worse + │ ├── exp05 (grad_accum=4) ────────────────────── quant 1.3181 ✓ breakthrough + │ │ ├── exp08 (context freq bias) quant 1.3225 + │ │ └── exp06 (SWA+AWQ+accum=2) ───────────── quant 1.3161 ✓ improved + │ │ ├── exp07 (tighter SWA/AWQ) quant 1.3164 + │ │ └── exp09 (pad ignore+word boost) ──── quant 1.3145 ⭐ BEST + │ │ ├── exp10 (trigram+unigram) ───── quant 1.3151 + │ │ │ ├── exp11 (trigram slim) quant 1.3259 ✗ + │ │ │ └── exp12 (trigram 64d) quant 1.3222 + │ │ └── exp13 (multihead gate bigram) quant 1.3145 ⭐ TIED BEST + │ │ ├── exp14 (engram 1-5gram) quant 1.3338 ✗ + │ │ │ └── exp15 (engram 3order) quant 1.3260 + │ │ │ └── exp16 (JEPA aux) quant 1.3526 ✗✗ + │ │ │ └── exp17 (byte engram) quant 1.3527 ✗✗ + │ │ ├── exp18 (separate trigram) quant 1.3247 + │ │ └── exp19-26 (phase 3b) ──── in progress +``` + +## 2.5 Lessons Learned (Phase 3a) + +1. 
3. **N-gram tables have diminishing returns**: Bigram valuable (-0.02 bpb). Trigram marginal. Higher-order actively hurt.
Phase 3b-Part1: Systematic Ablations (exp27b–exp33b) + +> **Hardware**: 1×H100, 600s wallclock +> **Base**: exp09 (chosen over exp13 — multi-head bigram added complexity without improving quant bpb) +> **Best result**: exp30b — quant bpb **1.3156**, artifact 15.05 MB + +## 3.1 Leaderboard + +| Rank | Exp | Name | Base | Quant BPB | Raw BPB | Size | Under 16MB? | +|:----:|:---:|------|:----:|:---------:|:-------:|:----:|:-----------:| +| **1** | **30b** | combo (resid-norm + loss-wt + type-emb) | exp09 | **1.3156** | 1.2983 | 15.05 MB | **Yes** | +| 2 | 33b | alternating RoPE + NTK | exp30b | 1.3145 | 1.2971 | 14.94 MB | Yes | +| 3 | 29b | loss-weight + token-type-emb | exp09 | 1.3176 | 1.3005 | 15.75 MB | Yes | +| 4 | 27b | resid-norm | exp09 | 1.3197 | ~1.300 | ~15.3 MB | Yes | +| 5 | 31b | RoPE base 50k | exp30b | 1.3206 | 1.2953 | 15.01 MB | Yes | +| 6 | 09 | padignore-wordboost (baseline) | exp06 | 1.3282 | 1.2974 | 14.5 MB | Yes | +| 7 | 32b | aux word-boundary loss | exp30b | 1.3424 | 1.3153 | 15.68 MB | Yes | +| — | 28b | perlayer quant | exp09 | N/A (analysis only) | — | 16.24 MB | **No** | + +## 3.2 Detailed Experiment Cards + +### Exp27b — Residual Stream Normalization (from exp09) + +| Field | Value | +|-------|-------| +| **Folder** | `exp27b_resid-norm_from-exp09/` | +| **Based on** | exp09 | +| **Steps** | ~1200 | +| **Raw BPB** | ~1.300 | +| **Quant BPB** | **1.3197** | +| **Artifact** | ~15.3 MB | +| **Extra params** | 0 | + +**Change**: Parameterless `F.rms_norm` after each decoder skip-connection. + +**Observations**: Residual norm growth (19.7→89.5) was root cause of poor quantization in later layers. RMSNorm keeps norms bounded → flatter weight distributions → lower quant error. + +**Verdict**: ✅ Validated. 
+ +--- + +### Exp28b — Per-Layer Quantization Bitwidth (from exp09) + +| Field | Value | +|-------|-------| +| **Folder** | `exp28b_perlayer-quant_from-exp09/` | +| **Based on** | exp09 | +| **Status** | Analysis only | +| **MSE improvement** | -16.13% | +| **Artifact** | 16.24 MB | + +**Change**: Higher bitwidths for boundary layers (0,1,8,9): int7 attn, int6 MLP. + +**Observations**: Cuts boundary quant error in half but blows 16MB budget. + +**Verdict**: ❌ Not viable. Size kills it. + +--- + +### Exp29b — Loss Weighting + Token-Type Embedding (from exp09) + +| Field | Value | +|-------|-------| +| **Folder** | `exp29b_lossweight-typemb_from-exp09/` | +| **Based on** | exp09 | +| **Steps** | 1207 | +| **Raw BPB** | 1.3005 | +| **Quant BPB** | **1.3176** | +| **Artifact** | 15.75 MB | +| **Extra params** | +8,305 | + +**Change**: (1) Per-token loss weighting: 1.5x word-start, 0.8x easy suffixes. (2) 7-category token-type embedding: 7×16 + 16×512 proj + learned scale. Zero-initialized. + +**Observations**: Strong win. Loss weighting redistributes gradient to high-opportunity tokens. Token-type gives explicit structural signal. + +**Verdict**: ✅ Validated. + +--- + +### Exp30b — Combo: Resid-Norm + Loss-Weight + Token-Type (from exp09) ⭐ BEST + +| Field | Value | +|-------|-------| +| **Folder** | `exp30b_combo_from-exp09/` | +| **Based on** | exp09 | +| **Steps** | 1200 | +| **Raw BPB** | 1.2983 | +| **Quant BPB** | **1.3156** | +| **Artifact** | **15.05 MB** | + +**Change**: All three validated improvements combined, each togglable via env vars. + +**Observations**: Best quant bpb. Gains sub-additive (expected -0.019, got -0.0126). Quant gap reduced 0.023→0.017. + +**Verdict**: ✅ **Phase 3b-Part1 SOTA. 
Base for all subsequent experiments.** + +--- + +### Exp31b — RoPE Base 50k (from exp30b) + +| Field | Value | +|-------|-------| +| **Folder** | `exp31b_rope-headspec_from-exp30b/` | +| **Based on** | exp30b | +| **Steps** | 1197 | +| **Raw BPB** | **1.2953** | +| **Quant BPB** | 1.3206 | +| **Artifact** | 15.01 MB | + +**Change**: RoPE base 10000→50000. + +**Observations**: Best raw bpb ever but quant gap widens to 0.0253 vs 0.0173 for exp30b. Net NEGATIVE after quantization. + +**Verdict**: ❌ Raw gain eaten by quant degradation. + +--- + +### Exp32b — Auxiliary Word-Boundary Loss (from exp30b) + +| Field | Value | +|-------|-------| +| **Folder** | `exp32b_aux-boundary_from-exp30b/` | +| **Based on** | exp30b | +| **Steps** | 1196 | +| **Raw BPB** | 1.3153 | +| **Quant BPB** | 1.3424 | +| **Artifact** | 15.68 MB | + +**Change**: Auxiliary head (512→1, sigmoid, binary CE, λ=0.15) predicting word-start. + +**Observations**: Significant regression (+0.0268). Aux loss diverts gradient. Token-type already provides structural signal. + +**Verdict**: ❌ Auxiliary losses counterproductive. + +--- + +### Exp33b — Alternating RoPE Bases + NTK Scaling (from exp30b) + +| Field | Value | +|-------|-------| +| **Folder** | `exp33b_swa-attn-ntkrope_from-exp30b/` | +| **Based on** | exp30b | +| **Steps** | 1200 | +| **Raw BPB** | 1.2971 | +| **Quant BPB** | 1.3145 | +| **Artifact** | 14.94 MB | + +**Change**: Even blocks rope_base=50000, odd blocks rope_base=1000. NTK scaling. + +**Observations**: Marginal improvement (-0.0011). Positional loss curve STILL flat after 256 tokens. Improvement is likely noise. + +**Verdict**: ⚠️ Marginal. exp30b preferred for simplicity. + +## 3.3 Deep Analysis Insights (from exp33b checkpoint) + +- **Word-start tokens**: 27.7% of tokens but **44.2% of total loss**. Mean loss 3.72 vs 1.28 for "other". +- **Softcap not saturating**: Max |logit| = 28.2 (94% of cap=30). 
+- **First 16 tokens**: Loss 2.96 vs 2.27 for rest (+0.69 penalty). Cold-start tokens disproportionately hard.
+> **Best result**: exp48b — quant bpb **1.2930** + +All experiments include the **LR schedule fix** (ITERATIONS=1300, WARMDOWN_ITERS=400). + +## 4.1 Leaderboard + +| Rank | Exp | Name | Base | Quant BPB | Raw BPB | Size | Steps | Verdict | +|:----:|:---:|------|:----:|:---------:|:-------:|:----:|:-----:|:-------:| +| **1** | **48b** | 10blocks-depth | 42b | **1.2930** | 1.2824 | 15.22 MB | 1198 | ✅ Best | +| 2 | 45b | awq-alpha07 | 42b | 1.2897* | — | 14.01 MB | — | ✅ Post-train only | +| 3 | 42b | revive-block9 | 34b | 1.2969 | 1.2867 | 14.01 MB | 1196 | ✅ Layer sharing | +| 4 | 39b | swa-tuning | 34b | 1.2969 | 1.2867 | — | ~1196 | ≈ Tie with 42b | +| 5 | 34b | lr-schedule-fix | 30b | 1.2990 | 1.2891 | 15.13 MB | 1186 | ✅ LR fix breakthrough | +| 6 | 46b | full-mha | 42b | 1.2979 | 1.2896 | 15.59 MB | 1164 | ❌ 8KV not worth it | +| 7 | 43b | boundary-boost | 42b | 1.2993 | — | 14.01 MB | ~1196 | ❌ Too sparse | +| 8 | 47b | warmdown200 | 42b | 1.3081 | 1.2983 | 14.07 MB | 1200 | ❌ Too late warmdown | +| 9 | 37b | fused-cap | 34b | 1.3159 | 1.3014 | 15.17 MB | 1197 | ❌ Cap hurts | +| 10 | 35b | focal-loss | 30b | 1.3201 | — | 15.20 MB | 1196 | ❌ γ=2 too aggressive | +| 11 | 36b | cappedact-labelsmooth | 30b | 1.3248 | — | 14.50 MB | ~1180 | ❌ Both hurt | +| 12 | 44b | seqlen-curriculum | 42b | ~1.32 | — | — | ~1000 | ❌ Failed | +| 13 | 38b | speed-opt | 34b | — | — | — | — | ❌ Failed | + +*exp45b is post-training AWQ alpha tuning, not a retrain. + +## 4.2 Key Discoveries + +14. **LR schedule was completely broken** (exp34b): ITERATIONS=20000 with 600s wallclock meant warmdown NEVER fired. Fixing to ITERATIONS=1300 gave -0.0166 quant bpb — single biggest improvement. +15. **Layer sharing revives dead blocks** (exp42b): Block 9 dead at 6.1% effective rank. Sharing block 3 at position 9 revived it to 10.3%. +16. **More depth beats more width** (exp48b vs exp46b): 10th unique block (-0.0039 bpb) > 8 KV heads (+0.001 bpb). +17. 
**Auxiliary losses still fatal** (exp35b, exp43b): Focal loss (γ=2), boundary boost — both hurt. +18. **Activation capping hurts training** (exp36b, exp37b): Warmdown already produces smooth-enough weights. +19. **AWQ alpha is undertested** (exp45b): Sweeping alpha 0.6→0.7 gave -0.007 bpb for free. +20. **Warmdown=400 is optimal** (exp47b): warmdown=200 too late — only 2 SWA checkpoints. +21. **Calibration improved dramatically**: With warmdown, near-perfect calibration (all bins within ±0.003 gap). + +## 4.3 Deep Analysis Insights (from exp48b checkpoint) + +- **Word-start tokens**: 25.5% of tokens but **42.5% of total loss**. Mean loss 3.61 vs 1.20. Top-1 accuracy only 24.7%. +- **Top confusion pairs**: `▁and`→`,` (299), `▁the`→`▁a` (263), `▁and`→`.` (200). Function word disambiguation is the core problem. +- **Positional loss flat after 256**: 2.47→2.06→flat ~2.1-2.2. +- **MLP activation outliers**: Block 1 max=1314, 10-15% of activations >4.0. Root cause of 6.86% MLP quant error. +- **Calibration near-perfect**: All bins within ±0.003. +- **Block 3 (shared) under stress**: Effective rank 57.9% — lowest of all blocks. +- **728 KB headroom**: Tighter than exp42b's 2.09 MB. +- **Sentence boundary gap**: 0.78 bpb (after-boundary 2.93 vs normal 2.15). + +--- + +# 5. Phase 3b-Part3: Simplification + XSA (exp53b–clean_54b) + +> **Hardware**: 1×H100, 600s wallclock +> **Key insight**: Removing features improved results +> **Best result**: exp54b — quant bpb **1.2708** (1×H100 SOTA) + +## 5.1 Experiment Cards + +### exp53b — Lean Combo (from exp48b) + +| Field | Value | +|-------|-------| +| **Quant BPB** | 1.2720 | +| **Raw BPB** | 1.2640 | +| **Steps** | 1218 | +| **ms/step** | 492 | +| **Changes** | Stripped token-type + loss weighting, kept resid-norm. max-autotune-no-cudagraphs. VAL_LOSS_EVERY=800. WARMUP=20. 
| + +--- + +### exp54b — XSA + zstd + c_k Fix (from exp53b) ⭐ 1×H100 BEST + +| Field | Value | +|-------|-------| +| **Quant BPB** | **1.2708** | +| **Raw BPB** | 1.2642 | +| **Steps** | 1235 | +| **ms/step** | 486 | +| **Changes** | XSA on last 2 decoder layers. Fixed block 0 c_k outlier (fp16 keep). Reverted to zstd. | + +--- + +### exp55b — Scaled XSA All Layers (from exp54b) + +| Field | Value | +|-------|-------| +| **Quant BPB** | 1.2717 (marginal regression) | +| **Raw BPB** | 1.2648 | +| **Steps** | 1183 | +| **ms/step** | 507 | +| **Changes** | Learned `xsa_alpha = sigmoid(param)` per layer on ALL 10 blocks. | + +**Finding**: Model learned alpha 0.75-0.99 on ALL layers — XSA universally wanted. But 507ms vs 486ms = 52 fewer steps, erasing benefit on 1×H100. + +--- + +### exp56b — Fast XSA (Cosine Scale) (from exp55b) + +| Field | Value | +|-------|-------| +| **Status** | Killed early — no speed improvement | +| **ms/step** | 508 (same as exp55b) | + +**Finding**: GQA head expansion (`repeat_interleave` 4→8) is the bottleneck, not the XSA math. + +--- + +### Community Model B — Fair 1×H100 Comparison + +| Field | Value | +|-------|-------| +| **Quant BPB** | 1.2825 | +| **Raw BPB** | 1.2501 | +| **Steps** | 1479 | +| **ms/step** | 406 | +| **Key difference** | Faster per-step (partial RoPE, no resid-norm) but 5× worse quant gap (0.032 vs 0.007). | + +--- + +### exp58b vs exp59b vs exp54b — Resid-Norm A/B Test + +| Config | Steps | ms/step | Quant val_bpb | +|--------|-------|---------|:-------------:| +| exp54b (no norm) | 1235 | 486 | **1.2708** | +| exp58b (post-addition norm) | 1216 | 493 | 1.2741 | +| exp59b (pre-skip norm) | ~1216 | ~493 | ~1.274 | + +**Conclusion**: With warmdown active, resid-norm is redundant. The 7ms/step overhead costs more than the quant gap benefit. 
+ +--- + +### clean_54b — Final Architecture (from exp54b) + +| Field | Value | +|-------|-------| +| **Quant BPB** | 1.2723 | +| **Changes** | exp54b + vanilla TTT + named checkpoint save + logging fixes. No architectural changes. | + +--- + +### Failed Experiments (Phase 3b-Part3) + +| Exp | What | Why it failed | +|-----|------|--------------| +| 35b | Focal loss (gamma=2) | Too aggressive — suppressed easy token gradients | +| 36b | Capped act + label smooth | Both hurt independently and together | +| 37b | Fused cap (no label smooth) | Cap=4.0 hurt raw quality more than helped quant | +| 43b | Boundary loss boost | Too sparse (2.5% of positions) | +| 44b | Seq-len curriculum | Speed regression | +| 46b | Full MHA (8 KV heads) | Extra params but slower, no bpb improvement | +| 55b | Scaled XSA all layers | 20ms/step overhead costs 52 steps | +| 56b | Fast cosine XSA | GQA head expansion is the bottleneck | +| 58b | Resid-norm re-enabled | 7ms/step → 19 fewer steps, redundant with warmdown | +| 59b | Pre-norm skip | Same overhead, no quality difference | + +## 5.2 Lessons Learned (Phase 3b-Part3) + +1. **LR warmdown is critical** — Biggest single improvement (0.017 bpb). +2. **Simpler is better** — Stripping token-type and loss weighting HELPED. +3. **Steps > features** — Every ms/step matters. Features that add compute must justify their cost. +4. **Resid-norm is redundant with warmdown** — Weights already smooth. +5. **XSA last 2 is the sweet spot** — Model wants XSA everywhere but overhead makes all-layer too expensive on 1×H100. +6. **zstd > LZMA** — Better for structured quantized weights (15.19 MB vs 15.37 MB). +7. **torch.compile mode matters** — `max-autotune-no-cudagraphs` gives kernel autotuning without tensor overwrite issues. + +## 5.3 TTT (Test-Time Training) Note + +TTT requires matching the torch.compile context used during training for correct inference results. 
+ +## 5.4 Complete Phase 3b Lineage (exp27b → clean_54b) + +``` +exp09 (pad-ignore + word-start boost, quant 1.3282) ← PHASE 3a BEST + │ + ├── exp27b [✅ POSITIVE] resid-norm (quant 1.3197, Δ-0.0085) + ├── exp28b [❌ NEGATIVE] perlayer quant (over 16MB budget) + ├── exp29b [✅ POSITIVE] loss-weight + token-type (quant 1.3176, Δ-0.0106) + │ + └── exp30b [✅ POSITIVE] combo: resid-norm + loss-weight + token-type (quant 1.3156, Δ-0.0126) + │ + ├── exp31b [❌ NEGATIVE] RoPE 50k (better raw 1.2953 but worse quant 1.3206) + ├── exp32b [❌ NEGATIVE] aux boundary loss (quant 1.3424) + ├── exp33b [⚪ NEUTRAL] alternating RoPE + NTK (quant 1.3145, marginal) + │ + └── exp34b [✅✅ MAJOR] LR schedule fix (quant 1.2990, Δ-0.0166 — biggest single win) + │ + ├── exp35b [❌ NEGATIVE] focal loss γ=2 (quant 1.3201) + ├── exp36b [❌ NEGATIVE] capped act + label smooth (quant 1.3472) + ├── exp37b [❌ NEGATIVE] fused cap only (quant 1.3159) + ├── exp38b [⚪ NEUTRAL] speed opt (quant 1.3002) + ├── exp39b [✅ POSITIVE] SWA tuning (quant 1.2985) + │ + └── exp42b [✅ POSITIVE] layer sharing block 3→pos 9 (quant 1.2969) + │ + ├── exp43b [⚪ NEUTRAL] boundary loss boost (quant 1.3003) + ├── exp44b [❌ NEGATIVE] seq-len curriculum (failed) + ├── exp45b [⚪ NEUTRAL] AWQ alpha=0.7 (quant 1.3033) + ├── exp46b [⚪ NEUTRAL] full MHA 8 KV heads (quant 1.2979) + ├── exp47b [❌ NEGATIVE] warmdown=200 (quant 1.3081) + │ + └── exp48b [✅ POSITIVE] 10th unique block (quant 1.2930) + │ + ├── exp49b [⚪ NOT RUN] diffusion GPT + ├── exp50b [⚪ NOT RUN] byte-level JEPA + │ + └── exp53b [✅ POSITIVE] strip overhead (quant 1.2720, Δ-0.0210) + │ + └── exp54b [✅ POSITIVE] XSA last 2 + c_k fix (quant 1.2708) ← 1xH100 BEST + │ + ├── exp55b [⚪ NEUTRAL] scaled XSA all layers (quant 1.2717) + ├── exp56b [❌ NEGATIVE] fast cosine XSA (no speed gain) + ├── exp57b [❌ NEGATIVE] LoRA TTT (failed) + ├── exp58b [❌ NEGATIVE] resid-norm ON (quant 1.2741) + ├── exp59b [❌ NEGATIVE] pre-norm skip (quant 1.2743) + │ + └── clean_54b [✅ POSITIVE] named 
save + TTT (quant 1.2723) + └── clean_54b_v2 [❌ NEGATIVE] bf16 roundtrip (destroyed quality) + +Community B model (fair 1xH100): quant 1.2825 ← WE BEAT BY 0.012 bpb +``` + +## 5.5 Complete Results Table (Phase 3b, sorted by quant bpb) + +| Rank | Exp | Quant BPB | Raw BPB | Steps | ms/step | Tag | Key Change | +|:----:|:---:|:---------:|:-------:|:-----:|:-------:|:---:|------------| +| **1** | **exp54b** | **1.2708** | 1.264 | 1235 | 486 | ✅ | XSA last 2 + c_k fix | +| 2 | exp53b | 1.2720 | 1.264 | 1218 | 492 | ✅ | Strip overhead | +| 3 | clean_54b | 1.2723 | 1.264 | 1205 | 498 | ✅ | Named save | +| 4 | community | 1.2825 | 1.250 | 1479 | 406 | — | Their full arch | +| 5 | exp48b | 1.2930 | 1.282 | 1198 | 501 | ✅ | 10th unique block | +| 6 | exp42b | 1.2969 | 1.287 | 1201 | 502 | ✅ | Layer sharing | +| 7 | exp39b | 1.2985 | — | 1196 | 502 | ✅ | SWA tuning | +| 8 | exp34b | 1.2990 | 1.289 | 1186 | 506 | ✅✅ | LR schedule fix | +| 9 | exp38b | 1.3002 | 1.290 | 1196 | 502 | ⚪ | Speed opt | +| 10 | exp43b | 1.3003 | 1.290 | 1198 | 501 | ⚪ | Boundary boost | +| 11 | exp45b | 1.3033 | — | 1196 | 502 | ⚪ | AWQ α=0.7 | +| 12 | exp47b | 1.3081 | 1.298 | 1200 | 500 | ❌ | Warmdown=200 | +| 13 | exp33b | 1.3145 | — | — | — | ⚪ | Alt RoPE + NTK | +| 14 | exp30b | 1.3156 | 1.298 | 1200 | 500 | ✅ | Combo | +| 15 | exp37b | 1.3159 | 1.301 | 1197 | 501 | ❌ | Fused cap | +| 16 | exp29b | 1.3176 | 1.301 | 1202 | 499 | ✅ | Loss-wt + type-emb | +| 17 | exp27b | 1.3197 | ~1.300 | ~1200 | ~500 | ✅ | Resid-norm | +| 18 | exp35b | 1.3201 | — | 1196 | — | ❌ | Focal loss γ=2 | +| 19 | exp31b | 1.3206 | 1.295 | 1197 | 502 | ❌ | RoPE 50k | +| 20 | exp32b | 1.3424 | 1.315 | 1196 | 502 | ❌ | Aux boundary loss | +| 21 | exp36b | 1.3472 | 1.332 | 1135 | 529 | ❌ | Cap + label smooth | + +## 5.6 Remaining Opportunities (identified at this stage) + +### Zero-cost (no retraining, apply to exp54b checkpoint) +1. AWQ alpha sweep on exp54b (test alpha=0.3-0.8) +2. 
Pruning threshold sweep (0%, 1%, 2%, 5%) + +### Quick retrains (10 min each) +3. Seed sweep (43, 44, 45). Variance: 0.003-0.005 bpb. +4. Weight decay tuning (0.04→0.06) +5. LR tuning (MATRIX_LR 0.025→0.030) +6. EMA with decay=0.995 (replace SWA) + +### Creative submissions +7. Diffusion GPT (exp49b) — Hybrid masked diffusion + AR +8. Byte-level JEPA (exp50b) — Raw byte model + +### For 8×H100 scaling +9. XSA on all layers (alpha=0.75-0.99 everywhere) +10. Partial RoPE (16 dims) — community speed trick +11. Late QAT — quant noise in last 15% +12. Predicted quant val_bpb ~1.12-1.14 + +--- + +# 6. Phase 3.5: 8×H100 Simulation (exp60–exp80) + +> **Hardware**: 1×H100 simulating 8×H100 (6000s wallclock, grad_accum=8, 786K tokens/batch) +> **Base**: exp54b (clean baseline) +> **Best result**: exp74 — sliding bpb **1.1456**, artifact 15.86 MB + +## 6.1 Leaderboard + +| Exp | Description | Pre-quant BPB | Post-quant BPB | Sliding BPB | Artifact | Steps | Status | +|-----|-------------|:-------------:|:--------------:|:-----------:|:--------:|:-----:|:------:| +| **exp74** | pRoPE+qgain+wbigram+LLR | 1.1539 | 1.1685 | **1.1456** | 15.86 MB | 6169 | **Best** | +| exp70 | Speed-optimized from exp69 | ~1.14 | ~1.17 | ~1.15 | ~16 MB | ~7500 | Baseline | +| exp78 | WS loss curriculum | — | — | — | — | — | Better embeddings | +| exp75 | Word pool from exp74 | — | — | — | — | — | Failed (scale→0.002) | +| exp61b | XSA all + warmdown | 1.1504 | 1.1781 | — | ~16.5 MB | ~7000 | Over budget | +| exp63 | Cascade VR + adaptive WD | 1.1377 | 1.1730 | — | 16.45 MB | ~7000 | Over budget | + +## 6.2 Evolution Tree + +``` +exp54b (clean baseline, 1.2708 bpb) + └── exp60 (EMA, flash_attn3, 8×H100 sim) + └── exp61b (XSA all blocks, cosine warmdown → 1.1504 pre-quant) + └── exp63 (cascading V-residual, adaptive warmdown → 1.1377 pre-quant) + ├── exp64 (MLP int6 quant — never ran) + ├── exp65 (quant overhaul — never ran) + │ └── exp66 (MiLe loss + NoPE — failed) + │ ├── exp67 (word-start semantic 
attention — failed) + │ └── exp68 (next-word-start MTP — not run) + └── exp69 (better quant: mlp_proj int6, attn int5, lzma, prune 5%) + └── exp70 (speed: batched NS5, EMA/10, set_to_none → 1.15 bpb) + ├── exp71 (output bias + label smooth — not run) + ├── exp72 (JEPA concept loss — failed) + ├── exp73 (warmdown focal + TTT WS — not run) + ├── exp74 (pRoPE 16/64 + q_gain + word bigram + LLR → **1.1456**) + │ ├── exp75 (word pool injection — failed: scale→0) + │ └── exp76 (dual word attention — failed) + ├── exp77 (progressive batch + seq_len curriculum) + ├── exp78 (WS loss curriculum — improved embeddings) + ├── exp79 (position ramp + late WS boost) + └── exp80 (best stack: pRoPE + bigram fix + pos ramp + outlier clamp) +``` + +## 6.3 Key Findings + +### What Worked +1. **Partial RoPE 16/64** (exp74): 41% less quant error, better word-start attention. Frees 75% of head dims for semantic matching. +2. **Diverse q_gain init** (exp74): Heads specialized faster — sharp (>2.5) for syntax, soft (<1.5) for semantics. +3. **Cascading value residual** (exp63→all): Shallow layers independent (α≈0), deep layers form value highway (α≈0.9). +4. **Better quantization** (exp69): MLP proj→int6 (3.4× less error), attn→int5 (size-neutral), magnitude pruning 5%, lzma. +5. **Speed optimizations** (exp70): Batched NS5 via bmm, EMA every 10 steps, set_to_none=True, deferred .item(). + +### What Failed +1. **MiLe Loss** (exp66): Downweighted easy tokens before consolidation. +2. **JEPA concept loss** (exp72): Added memory/overhead, not enough steps. +3. **Word pool injection** (exp75): Model drove scale to 0.002 — redundant. +4. **Output bias** (exp71): Needs ~500 steps to build momentum — too slow. +5. **Focal loss during training** (exp35b, exp73): Always hurts easy token accuracy. +6. **CUDAGraphs + tied embeddings** (exp66): Incompatible, caused failure. + +### Key Architectural Insights from Analysis +1. 
**Row 78 is a universal outlier**: 9/10 mlp.proj blocks have dim 78 as worst outlier (10-22× ratio). Per-row clamping ±3σ addresses this. +2. **Embedding uses 17% of capacity**: Effective rank 87/512. Word-start tokens (325) in 44 effective dims. +3. **Word-start norms 12% lower**: Tied embeddings structurally bias toward continuations (larger norm → higher logit). +4. **Deep layers form value highway**: VR alphas 7-10 → 0.9+ (strong V inheritance). Layers 2-5 independent. +5. **Block 0 attention barely used**: attn_scale=0.10. MLP-dominant layer. +6. **Block 4 c_q condition number 49,644**: Most quantization-sensitive matrix. + +### Word-Start Problem Analysis +- Word-start tokens: 40.8% of tokens, 5.05 bpb → **66% of total loss** +- Continuation tokens: 48.3%, 1.56 bpb → 24% of total loss +- Root causes: (1) full RoPE starves semantic attention, (2) tied embeddings bias continuations, (3) uniform gradient allocation, (4) bigram token-level not word-level +- Best fix: partial RoPE (architectural, not loss manipulation) + +### Community Gap Analysis (vs 1.1147 bpb) + +| Feature | Community | Ours (exp74) | Gap | +|---------|-----------|-------------|-----| +| Partial RoPE | 16/64 ✓ | 16/64 ✓ | Closed | +| GPTQ quantization | Full Hessian ✓ | Per-row uniform | **Open** | +| Bigram | 3072×112 | 10240×128 (larger) | Ours bigger | +| Warmdown | 4000 iters | 1500 (premature trigger) | **Open** | +| Compression | LZMA-9 ✓ | LZMA ✓ | Closed | +| TTT | Dropped (negative) | Enabled | Different | +| Selective pruning | ±1 reconstruction | 5% magnitude | Different | + +--- + +# 7. 
Phase 3.6: Diagnostic-Driven Era (exp83–exp87) + +> **Hardware**: 1×H100 (simulating 8×H100) +> **Philosophy shift**: Understand the model first, then act +> **Best result**: exp85 — pre-quant **1.1517**, post-quant 1.1697, artifact 15.32 MB + +## 7.1 Leaderboard + +| Exp | Description | Pre-quant BPB | Post-quant BPB | Artifact | Status | +|-----|-------------|:-------------:|:--------------:|:--------:|:------:| +| **exp85** | Community-derived stack | **1.1517** | 1.1697 | **15.32 MB** | **Best pre-quant** | +| **exp74** | pRoPE+qgain+wbigram+LLR | 1.1539 | **1.1685** | 15.86 MB | **Best post-quant** | +| exp87 | Fast convergence (failed) | ~1.17 | — | — | Failed | +| exp84 | Diagnostic-tuned (failed) | ~1.17 | — | — | Failed | +| exp83 | Diagnostics baseline | ~1.15 | 1.1717 | ~16 MB | Diagnostic reference | + +## 7.2 Evolution Tree + +``` +exp70 (speed-optimized baseline) + ├── exp83 (diagnostics: grad norms, VR health, bigram, block0 attention) + │ → Key finding: warmdown triggers at step 2200 (premature) + │ → Key finding: embed/matrix ratio 3.6→7.3× (misleading for Muon) + │ → Key finding: VR highway at layers 8-10, dead at 2-5 + ├── exp84 (diagnostic-tuned: VR_init=0.3, embed_lr=0.015) + │ → FAILED: VR alphas went negative, embed_lr change made ratio worse + │ → Lesson: VR_INIT must be 0.5, embed_lr ratio is misleading + ├── exp85 (community-derived: pRoPE + x0-to-V + LN scale + clip search) + │ → **1.1517 pre-quant** (best), 15.32 MB artifact + │ → VE scale learned: block 8=0.88 (wants identity), block 9=0.08 + │ → VR exploded to 3.26 at layer 6 (LN scale instability) + │ → Row 78 outlier: 4.5 (3× improved from exp70's 14.6) + ├── exp86 (deep-opt: fused QKV + int8 critical + TF32) + │ → Not yet run + └── exp87 (fast convergence: embed preinit + prog unfreeze + block9 AdamW) + → FAILED: embed preinit worse than random, prog unfreeze hurt co-adaptation + → Lesson: don't fight orthogonal init + Muon +``` + +## 7.3 Key Findings + +### What Worked (exp85) +1. 
**Partial RoPE 16/64**: Consistent across exp74 and exp85. Row 78 outlier 3× reduced. +2. **x0-to-V injection**: Block 8 grew ve_scale 0.3→0.88 — model WANTS token identity in deep-layer values. +3. **Clip search quantization**: Percentile-based clip per row. 25% quant error reduction. Zero training cost. +4. **Smaller bigram 5120×64**: 0.97 MB savings, artifact at 15.32 MB. +5. **Late warmdown min_steps=3000**: Delayed trigger from 2200 to 3100. + +### What Failed +1. **CASCADE_VR_INIT < 0.5**: Both 0.1 and 0.3 caused negative VR alphas. +2. **Lowering TIED_EMBED_LR**: 0.035→0.015 made ratio worse (10.4×). Muon normalizes direction differently. +3. **Embedding pre-init from SVD**: val_loss=12.21 at step 0 (vs 6.93 random). Incompatible with orthogonal weights. +4. **Progressive layer unfreezing**: Prevented deep-shallow co-adaptation. VR highway didn't form. +5. **Block 9 QKV → AdamW**: Duplicate parameter issue, inconclusive. +6. **LN Scale 1/√(layer+1)**: VR alpha explosion at layers 6-7 (3.26×). 
+ +### Diagnostic Insights (from exp83) +- Block 0 attention dies by step 2000 (structural, not fixable) +- Block 1 x0_mix amplifies to 1.95× (compensates for dead block 0) +- Bigram scale decays 0.26→0.10 (attention supersedes local patterns) +- Grad clip never fires (threshold 0.3, actual norms 0.05-0.17) +- Loss oscillates ±0.07 during warmdown with 1500 iters (need 3500) + +### Remaining Gap to Community (1.1147 bpb) + +| Feature | Status | Estimated Impact | +|---------|--------|:----------------:| +| Partial RoPE | ✅ Matched | — | +| x0-to-V (vs community VE) | ✅ Novel alternative | Similar | +| Warmdown 3500 | ✅ Matched | — | +| Clip search | ✅ Adopted | -25% quant error | +| **Full Hessian GPTQ** | ❌ Not implemented | ~0.010 bpb | +| **VR alpha clamping** | ❌ Needed | Fix VR explosion | +| **LN Scale fix** | ❌ Needs investigation | TBD | +| Smaller bigram | ✅ Done | -0.97 MB | + +## 7.4 Complete Lineage (exp60–exp87) + +``` +exp54b (clean baseline, quant bpb 1.2708) +│ +├── exp60 (EMA, flash_attn3, 8×H100 sim) 🟡 +│ └── exp61b (XSA all blocks) 🟢 Pre-quant 1.1504 +│ └── exp63 (cascading V-residual) 🟢 Pre-quant 1.1377 +│ │ +│ ├── exp64 (MLP int6 quant) 🟡 Never ran +│ ├── exp65 (quant overhaul) 🟡 Never ran +│ │ └── exp66 (MiLe loss + NoPE) 🔴 MiLe hurt convergence +│ │ ├── exp67 (word-start semantic attention) 🔴 failed +│ │ └── exp68 (next-word-start MTP) 🟡 Never ran +│ │ +│ └── exp69 (better quant) 🟢 Closed gap 0.035→0.015 +│ └── exp70 (speed-optimized) 🟢 BASELINE +│ ├── exp71 (output bias) 🟡 Never ran +│ ├── exp72 (JEPA concept) 🔴 overhead, no improvement +│ ├── exp73 (warmdown focal) 🟡 Never ran +│ ├── exp74 (pRoPE + q_gain + word bigram) 🟢 BEST sliding 1.1456 +│ │ ├── exp75 (word pool) 🔴 scale→0.002 +│ │ └── exp76 (dual attention) 🔴 failed +│ ├── exp77old (late warmdown) 🟡 +│ ├── exp77 (progressive batch) 🟡 Never ran +│ ├── exp78 (WS loss curriculum) 🟢 Best embedding quality +│ │ └── exp81 (pRoPE + WS curriculum) 🟡 failed +│ │ └── exp82 (drop layer 10) 🟡 
Never ran +│ ├── exp79 (position ramp) 🔴 premise wrong +│ ├── exp80 (best stack) 🔴 bigram-after-norm backfired +│ ├── exp83 (diagnostics) 🟢 7 actionable insights +│ ├── exp84 (diagnostic-tuned) 🔴 VR negative, embed_lr worse +│ └── exp85 (community-derived) 🟢 BEST pre-quant 1.1517 +│ └── exp86 (deep-opt) 🟡 Not yet run +│ └── exp87 (fast convergence) 🔴 All 3 changes hurt +``` + +## 7.5 Summary Statistics (exp60–exp87) + +| Outcome | Count | Examples | +|---------|-------|---------| +| 🟢 Positive | 8 | exp61b, exp63, exp69, exp70, exp74, exp78, exp83, exp85 | +| 🟡 Neutral | 9 | exp60, exp64, exp68, exp71, exp73, exp77old, exp77, exp82, exp86 | +| 🔴 Negative | 10 | exp66, exp67, exp72, exp75, exp76, exp79, exp80, exp84, exp87, exp65→66 | + +**Success rate: 29% positive, 36% neutral, 36% negative** + +--- + +# 8. Phase 3b-Muon: Parallel Muon Optimizer (exp70_parallel_muon–exp91) + +> **Base**: exp70_speed-opt_from_exp69 +> **Goal**: Faster training via Parallel Muon optimizer +> **Best result**: val_bpb **1.1440** (exp70_faster_version_parallel_muon, step 7317, 1×H100) + +## 8.1 Lineage + +``` +exp70_speed-opt_from_exp69 (original, DDP, 750ms/step) +├── exp70_faster_version_parallel_muon [🟢 POSITIVE: 12% speed, same final bpb] +│ ├── exp70_faster_vram_optimized [🔴 NEGATIVE: data loading issue] +│ ├── exp70_cuda_graphs_fused [🔴 NEGATIVE: no improvement] +│ ├── exp90_copy_head [🟡 NEUTRAL: concept validated, 40ms overhead] +│ └── reverted_exp70 [🟢 POSITIVE: clean base with all fixes] +│ └── exp91_smooth_v0residual [🟡 NEUTRAL: pending validation] +``` + +## 8.2 Results Table + +| Exp | Name | step_avg | Final BPB | Quant BPB | Size | Tag | +|-----|------|:--------:|:---------:|:---------:|:----:|:---:| +| exp70_parallel_muon | Parallel Muon + Banks | **658ms** | **1.1440** | 1.1715 | 16.3MB | 🟢 | +| exp70_vram_opt | Double-buffer loader | 636ms | — | — | — | 🔴 | +| exp70_cuda_fused | CUDA Graphs + Triton | 662ms | — (higher loss) | — | — | 🔴 | +| exp90_copy | TopicCopyHead 
(hybrid freq+attn) | 698ms | — (partial) | — | — | 🟡 | +| reverted_exp70 | Clean parallel muon base | 656ms | 1.1440 | 1.1715 | 16.3MB | 🟢 | +| exp91_smooth | V0 residual + label smooth | — | — (pending) | — | — | 🟡 | + +## 8.3 Key Findings + +1. **Parallel Muon gives 12% speed** via reduce-scatter/all-gather overlap and bank-native batching +2. **Per-step convergence ~0.002-0.004 bpb worse** — different torch.compile graphs, init RNG ordering +3. **CUDA Graphs incompatible with FA3** — not usable together +4. **GPTQ requires Late QAT** — without QAT-adapted weights, Cholesky error cascades +5. **Adaptive warmdown is fragile** — v1 triggers on noise, v3 never triggers on oscillating loss. Pure time-based is robust. +6. **Copy mechanism validated**: 1.19 bpb copy advantage for repeated tokens, 1.77 bpb for word-start. +7. **Model self-analysis**: word_start_boost=0.017 (dead), cascading VR layers 1-8 ≈ 0 (dead), K_1 kurtosis=33.8 (outlier-heavy), byte tokens 96% cosine similar (confused) + +## 8.4 Lessons Learned + +8. **Double-buffering needs N >= grad_accum_steps buffers** — insufficient buffers cause issues +9. **Custom Triton kernels for elementwise ops rarely help** — torch.compile already fuses them; precision differences compound +10. **AWQ with weight-magnitude proxy is catastrophic** — must use real activation statistics from forward hooks +11. **Selective ±1 pruning (Code 2) > blind magnitude pruning** — targets least-impactful quantized values +12. **Init order matters for reproducibility** — nn.init.orthogonal_ consumes RNG; bank vs module ordering creates different trajectories + +--- + +# 9. 
Phase 3c: Architecture Rewrite + Meta-TTT (exp92–exp109) + +> **Hardware**: 8×H100 +> **Base**: exp70_speed-opt → exp92 (major rewrite) +> **Key finding**: Meta-TTT has an architecture-limited ceiling +> **Best result**: exp101 — legal_ttt **1.11588** + +## 9.1 Lineage + +``` +exp70_speed-opt (1.153 bpb) +└── exp92_banks-asyncmuon-partrope-qat-ve [🟢 1.131 bpb — major rewrite] + └── exp93_meta-ttt-inner-outer [🟢 1.120 legal_ttt] + └── exp95_size-ttt-opt-metattt2x [🟢 1.1169 legal_ttt — SOTA at time] + ├── exp96_warmdown-fix-trigram-sgdttt [🟡 ~1.135] + │ ├── exp98_metattt-randomsplit-momentum [🟡 ~1.135] + │ │ └── exp99_tripleloop-parallelres [🟡 not run] + │ └── exp97_fp8-pipeline [not run] + ├── exp101_poscond-bigram-trigram [🟢 1.11588 legal_ttt — new baseline] + │ ├── exp105a_no-metattt [🟡 ablation: meta-TTT = noise] + │ ├── exp106_metasgd-crosschunk [🟡 ceiling confirmed] + │ │ ├── exp107_sam-inner [🔴 hurts] + │ │ └── exp108_sp8192-brotli [🟡 no results] + │ └── exp109_shared-blocks-softgate [🔴 decoder dead] + └── exp100_half-metattt [not tracked here] +``` + +## 9.2 Results Table + +| Exp | Name | val_bpb | int6_bpb | legal_ttt | Tag | +|-----|------|:-------:|:--------:|:---------:|:---:| +| exp92 | Banks + Async Muon + Partial RoPE + QAT + VE | ~1.131 | — | — | 🟢 | +| exp93 | Meta-TTT inner/outer FOMAML | 1.136 | — | ~1.116 | 🟢 | +| exp95 | Size-opt + meta-TTT 2× | 1.1363 | — | 1.1169 | 🟢 | +| exp96 | Warmdown fix + trigram | ~1.135 | — | — | 🟡 | +| exp98 | Random-split FOMAML + momentum LR match | ~1.135 | — | — | 🟡 | +| exp99 | Triple loop + parallel residuals | — | — | — | 🟡 | +| **exp101** | **Position-conditional bigram hash** | **1.1352** | **1.13930** | **1.11588** | **🟢** | +| exp105a | No meta-TTT (ablation) | 1.1353 | 1.13956 | 1.11624 | 🟡 | +| exp106 | MetaSGD + cross-chunk FOMAML | 1.1377 | 1.14160 | ~1.118 | 🟡 | +| exp107 | SAM inner loop | 1.1384 | 1.1424 | 1.11898 | 🔴 | +| exp108 | SP8192 + Brotli | — | — | — | 🟡 | +| exp109 | Block sharing K=8 + 
SP8192 | 1.1500 | 1.1897 | — | 🔴 | + +## 9.3 Key Findings + +1. **Meta-TTT ceiling is architecture-limited**: 4 experiments (exp101, 105a, 106, 107) show identical TTT delta ~0.023 bpb regardless of optimizer (SGD, MetaSGD, SAM, none). Ceiling set by bank architecture (rank × dim). +2. **Position-conditional bigram hashing** (exp101): Zero-parameter trick — split hash space by token class (word-start vs within-word). +0.001 bpb. +3. **Block sharing fails across encoder/decoder boundary** (exp109): Shared decoder positions → near-zero scales. Soft gates diagnose but can't fix. +4. **SP8192 quant degradation 10× worse than SP1024** (exp109): Large embedding table (8192×512) poorly compressed. + +--- + +# 10. Phase 3c-Community: Community SOTA (SP8192+) + +> **Source**: Community contributions on parameter-golf repository +> **Impact**: Paradigm shift from our 1.1169 to 1.0744 bpb + +## 10.1 The Three Community Breakthroughs + +### 10.1.1 SP8192 + 3-Layer Recurrence (2026-04-09) + +| Metric | Value | +|--------|:-----:| +| **val_bpb** | 1.0873 | +| **int6_bpb** | 1.0997 | +| **legal_ttt** | **1.0808** | +| **Hardware** | 8×H100 | + +**Key innovations**: SP8192 tokenizer + 3-layer depth recurrence (blocks 3-5, 2 extra passes) + parallel residuals + QK_GAIN_INIT=5.25. 17 virtual layers from 11 physical. + +### 10.1.2 WiderEmb + Tap-In V6 + TTT (2026-04-10, community) + +| Metric | Value | +|--------|:-----:| +| **val_bpb** | 1.0813 | +| **int6_bpb** | 1.0980 | +| **legal_ttt** | **1.0788** | +| **3-seed mean** | 1.078825 | + +**Key innovations**: Wider loop (3×3) + per-pass loop embeddings (3×512, zero-init) + Tap-In V6 cross-window n-gram C++ matcher + legal score-first TTT. 
+ +### 10.1.3 ImprovedParallelResiduals (2026-04-11, community PR #1523) — CURRENT BEST + +| Metric | Value | +|--------|:-----:| +| **legal_ttt val_bpb** | **1.07438** (3-seed mean) | +| **val_bpb_std** | 0.00034 | +| **Artifact** | 15,959,005 bytes (71 bytes headroom) | +| **Hardware** | 8×H100 80GB SXM | +| **step_avg_ms** | 124.68 | + +**Key innovations**: Richer parallel residual routing — attn/MLP outputs written into both lanes at block end, decoder skips on lane0 only. CUTLASS EVT fusion for reproducible throughput. + +**Seed results:** +| Seed | val_bpb | post_ema_val_bpb | artifact_bytes | steps | ms/step | +|:----:|:-------:|:----------------:|:--------------:|:-----:|:-------:| +| 1337 | 1.07485 | 1.08286 | 15,958,373 | 4685 | 125.53 | +| 2024 | 1.07428 | 1.08242 | 15,956,287 | 4734 | 124.25 | +| 42 | 1.07403 | 1.08212 | 15,959,005 | 4733 | 124.26 | + +## 10.2 Other Community-Adjacent Experiments + +| Exp | Date | Description | Tag | +|-----|------|-------------|:---:| +| 2026-04-10_RecurStepFiLM_PooledRetrieval | 2026-04-10 | FiLM conditioning + pooled retrieval | 🟡 | +| 2026-04-10_10L_RecurStepFiLM_PooledRetrieval | 2026-04-10 | 10L variant | 🟡 | +| 2026-04-11_ImprovedParallelResiduals copy | 2026-04-11 | Copy/variant | 🟡 | +| 2026-04-11_newSota | 2026-04-11 | Community SOTA integration | 🟢 | +| 2026-04-11_11L_RecurStep3_loopedonly | 2026-04-11 | 11L recurrence step 3, looped-only | 🟡 | +| 2026-04-11_11L_RecurStep3_loops3 | 2026-04-11 | 11L with 3 loops | 🟡 | +| 2026-04-11_11L_RecurStep_StochDepth_ProgLoop | 2026-04-11 | Stochastic depth + progressive loop | 🟡 | +| 2026-04-11_11L_RecurStep_StochDepth_ProgLoop_KVCache | 2026-04-11 | + KV cache for recurrence | 🟡 | +| 2026-04-11_11L_Block10MLPHalf_RecurStepFiLM_PooledRetrieval | 2026-04-11 | Block 10 MLP halved + FiLM | 🟡 | +| loop_in_SP8192_3LayerRecur | 2026-04-13 | Loop detection (timestep embed, re-injection, per-loop RMSNorm) | 🟡 not trained | + +--- + +# 11. 
Phase 3c-Frontier: Pushing Past Community (exp110–exp119) + +> **Base**: ImprovedParallelResiduals (1.0744 legal_ttt) +> **Theme**: Tied embedding bottleneck +> **Result**: No improvement over community baseline + +## 11.1 Results Table + +| Exp | Name | val_bpb | int6_bpb | legal_ttt | Size | Tag | +|-----|------|:-------:|:--------:|:---------:|:----:|:---:| +| exp110 | Per-layer quant + trigram + PARALLEL_START=7 | — | — | — | — | 🟡 | +| exp111 | LoRA TTT (rank=8) + shrunk block 10 MLP + per-layer int5 | — | — | — | — | 🟡 | +| exp112 | Gradient rescaling on weak blocks | — | — | — | — | 🔴 | +| exp113 | Drop L0 MLP + batch schedule + MTP | — | — | — | — | 🟡 | +| exp114 | embed_dim=384 decouple | 1.0950 | — | — | fits | 🔴 | +| exp115 | embed_dim=384 + drop boundary MLPs | — | — | — | — | 🟡 | +| exp116 | embed_dim=384 + no x0 pathway | — | — | — | — | 🔴 | +| exp117 | embed_dim=448 tuned | 1.0877 | 1.0982 | 1.0814 (SW) | **16.28MB** | 🔴 | +| exp118 | embed_dim=416 + parallel_start=7 + clip tuned | 1.0915 | 1.1013 | 1.0850 | **16.44MB** | 🔴 | +| exp119 | Residual low-rank proj (rank=32) | — | — | — | — | 🟡 | + +## 11.2 The Tied Embedding Bottleneck + +The dominant theme: the model uses the same weight matrix for input embeddings and output projection. With SP8192, this (8192×512) matrix dominates the parameter budget and forces boundary blocks (0 and 10) to specialize for embedding space rather than general computation. + +**Attempted fixes:** +- **embed_dim=448** (exp117): Good BPB (1.0877), activates boundary blocks (+50% effective contribution). But **16.28MB — over budget**. +- **embed_dim=416** (exp118): Similar story at **16.44MB**. +- **embed_dim=384** (exp114): Fits budget but loses 655K params → BPB regression. +- **Residual low-rank projection** (exp119): rank-32, zero param loss — theoretically correct fix. Not run to completion. + +**Verdict**: The bottleneck is real. 
embed_dim≠model_dim activates boundary blocks but any dimension-change approach costs either params (regression) or fp16 passthrough overhead (budget overrun). + +--- + +# 12. Misc: Co-occurrence QK Initialization + +> **Date**: 2026-03-24 +> **Hardware**: 1×H100 +> **Separate exploration from main competition track** + +| Metric | Value | +|--------|:-----:| +| **val_bpb** | 1.3525 | +| **Pre-quant val_bpb** | 1.3245 | +| **Artifact** | 15.55 MB | +| **Seeds** | 1 (seed 42) | +| **Steps** | 1099 | +| **Wallclock** | 600.138s | +| **Base PR** | #623 | + +**Approach**: Initialize W_Q and W_K in layer 0 from bigram co-occurrence statistics via SVD: +1. Build 1024×1024 co-occurrence matrix from 2M training tokens (<3s) +2. Project into model_dim via random projection +3. Factorize C_proj = USV^T → Q/K weights where Q·K^T ≈ co-occurrence at step 0 + +Combined with LeakyReLU(0.5)², cyclic momentum (0.85–0.95), SWA over warmdown. + +**Note**: exp87 later tried SVD-based embedding pre-initialization and it regressed. The difference: co-occurrence QK init changes attention *patterns*, while embedding SVD changes *representation space* (conflicts with Muon's orthogonal constraint). + +--- + +# 13. Known Constraints + +- **TTT requires compile-matched inference**: Standalone model loading needs the same torch.compile context as training for correct numerical results. +- **SP8192 quantization sensitivity**: Large embedding table (8192×512) needs GPTQ with SDClip — naive quantization degrades 10× worse than SP1024. +- **CUDA Graphs limited**: Incompatible with FA3 and tied embeddings in `reduce-overhead` mode. + +--- + +# 14. Key Learnings by Phase + +## 14.1 Phase 3a Lessons (exp00–exp18) + +1. **Step count is king**: Reducing grad_accum = biggest single win. +2. **Don't fight the quantizer**: Better convergence naturally produces better-quantizing weights. +3. **N-gram tables have diminishing returns**: Bigram valuable, trigram+ marginal. +4. 
**Hash collision reduction matters less than expected**: Model routes around collisions. +5. **Auxiliary losses are dangerous**: JEPA caused biggest regression. +6. **Simple targeted fixes beat complex architectures**: 0-param + 1-param > 688K params. +7. **Quant gap is the real metric**: Optimize for post-quantization, not raw. + +## 14.2 Phase 3b Lessons (exp27b–clean_54b) + +8. **Residual norm control is high leverage**: RMSNorm after skip connections. +9. **Stacking orthogonal improvements works**: Sub-additive but substantial. +10. **Auxiliary losses fatal in compute-starved regimes**: Every gradient must reduce CE. +11. **RoPE base changes hurt quantization**: Different landscape = harder-to-compress weights. +12. **Long context is a dead end**: Loss flat after 256 tokens. +13. **Size budget matters**: Check BEFORE celebrating. +14. **LR schedule was completely broken**: Biggest single improvement (-0.0166 bpb). +15. **Layer sharing revives dead blocks**: Block 9 dead → shared block 3 revived it. +16. **More depth beats more width**: 10th block > 8 KV heads. +17. **Activation capping hurts**: Warmdown already smooths weights. +18. **AWQ alpha undertested**: Re-sweep for each new best model. +19. **Warmdown=400 optimal**: 4 SWA checkpoints, proper decay. +20. **Simpler is better**: Stripping features HELPED convergence. +21. **Resid-norm redundant with warmdown**: 7ms/step overhead not worth it. +22. **XSA last 2 is sweet spot**: Model wants everywhere but overhead too high on 1×H100. +23. **zstd > LZMA**: Better for structured quantized weights. +24. **torch.compile mode matters**: `max-autotune-no-cudagraphs` gives best tradeoff. + +## 14.3 Phase 3.5–3.6 Lessons (exp60–exp87) + +25. **Partial RoPE 16/64 universally good**: 41% less quant error, head specialization. +26. **Cascading VR creates value highway**: Natural deep-layer pattern. +27. **Diagnostics are invaluable**: exp83 discovered 7 insights informing 4 experiments. +28. 
**The model tells you what it wants**: Listen to learned parameters. +29. **Don't fight Muon's orthogonal constraint**: VR_INIT must be 0.5, embed pre-init fails. +30. **MiLe/focal/JEPA all fail**: Loss reweighting doesn't work in limited steps. +31. **Architectural changes DO work**: Partial RoPE, cascading VR, x0-to-V — all positive. +32. **Quantization improvements are free**: int6 for MLP proj, clip search — zero training cost. + +## 14.4 Phase 3b-Muon Lessons + +33. **Parallel Muon gives 12% speed** but per-step convergence slightly worse. +34. **Double-buffering needs sufficient buffers** for grad accumulation steps. +35. **Custom Triton kernels rarely help** — torch.compile already fuses elementwise ops. +36. **AWQ needs real activation statistics** — weight-magnitude proxy doesn't work. +37. **Init order matters for reproducibility**. + +## 14.5 Phase 3c Lessons (exp92–exp119) + +38. **Meta-TTT ceiling is architecture-limited**: TTT delta invariant at ~0.023 regardless of optimizer. +39. **Block sharing fails at encoder/decoder boundary**: Decoder positions → dead. +40. **Position-conditional bigram hashing**: Zero-parameter +0.001 bpb trick. +41. **Tied embedding bottleneck is real but hard to fix**: embed_dim changes bust budget. + +## 14.6 Top-Level Synthesis + +1. **Loss reweighting doesn't work in 7K steps** (MiLe, focal, JEPA, position ramp — all failed) +2. **Architectural changes DO work** (partial RoPE, cascading VR, x0-to-V, XSA-all — all positive) +3. **Quantization improvements are free bpb** (int6 for MLP proj, clip search — zero training cost) +4. **Don't fight the optimizer** (Muon's orthogonal constraint is a feature; VR_INIT and embed_lr must respect it) +5. **Diagnostics are invaluable** (exp83 discovered 7 insights that informed 4 subsequent experiments) +6. **The model knows what it wants** (block 0 attention dies = structural, VE scale at block 8 grows to 0.88 = model wants identity there) + +--- + +# 15. 
TLDR: Top 20 Learnings Across All Phases

1. **Steps > everything else.** Cutting grad_accum from 8→4, and later to 2, roughly doubled optimizer updates at each halving in the same wallclock — biggest single win in Phase 3a. Every ms/step matters when you only get 600 seconds.

2. **Fix your LR schedule before anything else.** ITERATIONS=20000 with 600s wallclock meant warmdown never fired. Fixing to ITERATIONS=1300 gave -0.017 bpb for free (exp34b). The model was training at max LR for 100% of training.

3. **Depth recurrence is the best parameter-efficiency trick (community).** 3-layer recurrence (blocks 3-5, 2 extra passes) from the community SP8192 baseline gives 17 virtual layers from 11 physical — the single biggest architectural win. Only works within the encoder, NOT across encoder/decoder boundary.

4. **SP8192 tokenizer is transformative (community).** Community's jump from SP1024 to SP8192 unlocked ~0.04 bpb improvement. But the larger embedding table (8192×512) needs GPTQ with SDClip — naive int8+brotli gives 10× worse quant degradation.

5. **Parallel residuals improve quantization for free (community).** GPT-J-style two-lane routing (attn/MLP read same input) from the community baseline collapses the quant gap vs single-lane. Cross-lane accumulation (community ImprovedParallelResiduals, PR #1523) pushed this further to 1.0744.

6. **Meta-TTT has an architecture-limited ceiling.** 4 experiments (exp101, 105a, 106, 107) show identical TTT delta ~0.023 bpb regardless of inner-loop optimizer (SGD, MetaSGD, SAM, none). The ceiling is set by bank architecture, not training.

7. **Auxiliary losses are fatal in compute-starved regimes.** JEPA, focal loss, boundary boost, MTP — every auxiliary objective tested hurt. With 1200-4700 steps, every gradient must directly reduce CE loss.

8. **Don't fight the optimizer.** Muon's orthogonal constraint is a feature. VR_INIT must be 0.5 (lower → negative alphas). Embed LR ratio is misleading because Muon normalizes gradient direction. 
Progressive unfreezing blocks the co-adaptation the layers need (exp87: it hurt).

9. **Quantization improvements are free BPB.** Per-row clip search (-25% quant error), int6 for MLP proj (3.4× less error), GPTQ with SDClip — all zero training cost. Always sweep AWQ alpha for each new best model.

10. **Simpler is better.** Stripping token-type embedding and loss weighting from exp53b actually HELPED. Fewer competing objectives = better convergence in limited steps.

11. **QK_GAIN_INIT=5.25 is a free win (community).** Monotonic improvement from 4.0→5.25 observed in the community SP8192 baseline. Per-head query gain initialization helps attention patterns specialize faster.

12. **Partial RoPE 16/64 is universally good.** Frees 75% of head dims for semantic matching, reduces quantization outliers 3×, and improves word-start attention. Consistent across every experiment it was tested in.

13. **Word-start tokens dominate total loss.** 25-40% of tokens but 42-66% of total loss. Mean loss 3.6-5.1 vs 1.2-1.6 for continuations. The best fix is architectural (partial RoPE), not loss manipulation (focal, weighting).

14. **Layer sharing revives dead blocks.** Block 9 was dead at 6.1% effective rank. Sharing block 3 at position 9 revived it to 10.3%. Fewer unique blocks = smaller artifact = more headroom for params.

15. **Resid-norm is redundant with warmdown.** Adding RMSNorm after skip connections improves quant but costs ~7ms/step (19 fewer training steps). With proper LR warmdown, weights are already smooth enough.

16. **Block sharing fails across encoder/decoder boundary.** Shared blocks at decoder positions converge to near-zero scales — effectively dead. Soft gates correctly diagnose the problem but can't override it (exp109).

17. **The model tells you what it wants.** Block 0 attention dies (structural, MLP-dominant). Block 8 ve_scale grows to 0.88 (wants identity in deep-layer values). Bigram scale decays 0.26→0.10 (attention supersedes local patterns). 
Listen to the learned parameters. + +18. **Co-occurrence QK initialization works.** Initializing W_Q/W_K from bigram SVD gives meaningful step-0 attention patterns instead of random noise. Validated at 1.3525 bpb on 1×H100. + +19. **Warmdown timing is critical.** warmdown=400 steps (start at step 900) gives 4 SWA checkpoints and proper LR decay. Too late (warmdown=200) → only 2 checkpoints. Community uses 3500-4000 iters on longer runs. + +20. **Size budget is a hard constraint — check BEFORE celebrating.** embed_dim=448 achieved great BPB (1.0877) but at 16.28MB — over the 16MB limit. embed_dim=416 similar story at 16.44MB. Multiple experiments wasted on approaches that couldn't fit. + +--- + +# 16. Appendix: Summary Statistics + +## 16.1 Phase 3a (exp00–exp18) + +19 experiments. Best: exp09/exp13 at quant bpb 1.3145. + +## 16.2 Phase 3b (exp27b–clean_54b) + +~30 experiments. Best: exp54b at quant bpb 1.2708. + +## 16.3 Phase 3.5–3.6 (exp60–exp87) + +| Outcome | Count | +|---------|:-----:| +| 🟢 Positive | 8 | +| 🟡 Neutral | 10 | +| 🔴 Negative | 10 | + +**Success rate: 29% positive, 36% neutral, 36% negative** + +## 16.4 Phase 3b-Muon (exp70_parallel_muon–exp91) + +6 experiments. 2 positive, 2 neutral, 2 negative. + +## 16.5 Phase 3c (exp92–exp119 + community) + +| Outcome | Count | Examples | +|---------|:-----:|---------| +| 🟢 Positive | 8 | exp92, exp93, exp95, exp101, SP8192_3LayerRecur, WiderEmb, ImprovedParallelResiduals, CooccurrenceQKInit | +| 🟡 Neutral | 14 | exp96, exp98, exp99, exp105a, exp106, exp108, exp110, exp111, exp113, exp115, exp119, FiLM variants, recurrence variants | +| 🔴 Negative | 7 | exp107, exp109, exp112, exp114, exp116, exp117, exp118 | + +**Success rate: 28% positive, 48% neutral, 24% negative** + +## 16.6 Overall + +~119+ experiments across all phases. Overall positive rate ~28-29%. + +--- + +# 17. Complete Experiment Index + +Every experiment across all phases in one table. 
+ +| # | Experiment | Base | Motivation | Result | Learning | +|:-:|-----------|------|-----------|:------:|---------| +| | **Phase 3a (exp00–exp18)** | | | | | +| 1 | exp00 (baseline-rerun) | exp27 | Establish baseline on A100 | Baseline | Quant bpb 1.3389; bigram.proj has worst quant error | +| 2 | exp01b (ln-scale-only) | exp27 | Test layer-norm damping | Not run | — | +| 3 | exp01c (ema-only) | exp27 | Test EMA weight averaging | Not run | — | +| 4 | exp01d (xsa-only) | exp27 | Test cross-sequence attention | Negative | XSA slows steps without quality gain | +| 5 | exp02 (speed-bigramfp16-awq) | exp00 | FP16 bigram + per-category AWQ | Negative | FP16 bigram blows artifact to 17.3MB | +| 6 | exp03 (qat-ste) | exp00 | Quantization-aware training via STE | Negative | QAT-STE destabilizes training; worst result | +| 7 | exp04 (no-cyclic-momentum) | exp00 | Test fixed momentum=0.95 | Negative | Cyclic momentum is slightly helpful as regularization | +| 8 | exp05 (grad-accum4) | exp00 | Double step count via accum 8->4 | **Positive** | Major breakthrough: 2x more steps = first sub-1.32 quant bpb | +| 9 | exp06 (swa-awq-accum2) | exp05 | Push accum to 2; SWA+AWQ tuning | **Positive** | Raw bpb breaks below 1.30 for first time | +| 10 | exp07 (tighter-swa-awq) | exp06 | SWA_EVERY=150, AWQ=0.7 | Neutral | Smaller artifact, quant bpb identical; sweet spot is SWA_EVERY=100 | +| 11 | exp08 (ctx-freq-bias) | exp05 | Learned token burstiness bias (+1 param) | Neutral | Redundant with attention; smallest artifact at 15.0MB | +| 12 | exp09 (padignore-wordboost) | exp06 | Skip pad tokens + word-start boost | **Positive** | Best quant bpb (1.3145); 0+1 params beat 688K trigram params | +| 13 | exp10 (trigram-unigram) | exp09 | Trigram hash table + unigram bias | Neutral | Best raw bpb but quant bpb regresses — extra params compress poorly | +| 14 | exp11 (trigram-slim-awq07) | exp10 | Slim trigram dim=48, AWQ=0.7 | Negative | dim=48 too small, AWQ too aggressive; double 
regression | +| 15 | exp12 (trigram64-awq06) | exp10 | Middle-ground trigram dim=64 | Neutral | Better than exp11 but worse than exp09; hash collisions too frequent | +| 16 | exp13 (multihead-gate-bigram) | exp09 | K=2 hash heads + context gate | **Positive** | Tied best quant bpb; collision reduction real but impact negligible | +| 17 | exp14 (engram-multiorder) | exp13 | 1-5gram, 10 lookups/position | Negative | Shared n-gram embeddings cause destructive interference | +| 18 | exp15 (engram-3order) | exp14 | 1-3gram with orthogonal subspaces | Neutral | Better isolation but each subspace too small (~42 dims) | +| 19 | exp16 (jepa-aux) | exp15 | JEPA predictor MLP, MSE loss | Negative | Biggest regression; fixed hash targets provide adversarial gradient | +| 20 | exp17 (byte-engram) | exp16 | Byte boundary features | Negative | No gain; base too weak to evaluate | +| 21 | exp18 (separate-trigram64) | exp13 | Separate 64-dim trigram + projection | Neutral | 688K extra params don't survive quantization | +| | **Phase 3b-Part1 (exp27b–exp33b)** | | | | | +| 22 | exp27b (resid-norm) | exp09 | RMSNorm after skip connections | **Positive** | High-leverage: attacks root cause of quant error (norm growth 19.7->89.5) | +| 23 | exp28b (perlayer-quant) | exp09 | Variable bitwidth per layer | Negative | 16% MSE reduction but over 16MB budget | +| 24 | exp29b (lossweight-typemb) | exp09 | 1.5x word-start loss + token-type embed | **Positive** | Gradient redistribution + structural signal both help | +| 25 | exp30b (combo) | exp09 | Stack all 3 validated improvements | **Positive** | Phase 3b-Part1 SOTA (1.3156); sub-additive but substantial | +| 26 | exp31b (rope-50k) | exp30b | RoPE base 10k->50k | Negative | Best raw bpb but quant gap widens; net negative after quantization | +| 27 | exp32b (aux-boundary) | exp30b | Auxiliary word-boundary classifier | Negative | Gradient waste; token-type already provides structural signal | +| 28 | exp33b (alt-rope-ntk) | exp30b | 
Alternating RoPE bases + NTK | Neutral | Marginal; positional loss still flat after 256 tokens | +| | **Phase 3b-Part2 (exp34b–exp48b)** | | | | | +| 29 | exp34b (lr-schedule-fix) | exp30b | Fix ITERATIONS 20000->1300 so warmdown fires | **Positive** | Single biggest improvement (-0.0166 bpb); warmdown was never firing | +| 30 | exp35b (focal-loss) | exp30b | Focal loss gamma=2 | Negative | Too aggressive; suppresses easy token gradients | +| 31 | exp36b (cappedact-labelsmooth) | exp30b | Activation cap + label smoothing | Negative | Both changes hurt independently and together | +| 32 | exp37b (fused-cap) | exp34b | Activation cap only | Negative | Cap hurts raw quality more than it helps quant | +| 33 | exp38b (speed-opt) | exp34b | Speed optimization | Neutral | Failed (OOM) | +| 34 | exp39b (swa-tuning) | exp34b | SWA parameter sweep | **Positive** | SWA_EVERY=100 confirmed optimal | +| 35 | exp42b (revive-block9) | exp34b | Share block 3 at position 9 | **Positive** | Dead block 9 (6.1% rank) revived to 10.3% | +| 36 | exp43b (boundary-boost) | exp42b | Boundary loss boost | Neutral | Too sparse (2.5% of positions) to matter in 1200 steps | +| 37 | exp44b (seqlen-curriculum) | exp42b | Sequence length curriculum | Negative | Speed regression | +| 38 | exp45b (awq-alpha07) | exp42b | AWQ alpha sweep (post-train) | Neutral | Alpha=0.7 gave -0.007 bpb free on exp42b | +| 39 | exp46b (full-mha) | exp42b | 8 KV heads (double from 4) | Neutral | Extra params but slower; depth > width | +| 40 | exp47b (warmdown200) | exp42b | Shorter warmdown=200 | Negative | Too late; only 2 SWA checkpoints vs 4 with warmdown=400 | +| 41 | exp48b (10blocks-depth) | exp42b | Add 10th unique block | **Positive** | Depth > width confirmed; quant bpb 1.2930 | +| | **Phase 3b-Part3 (exp53b–clean_54b)** | | | | | +| 42 | exp53b (lean-combo) | exp48b | Strip token-type + loss weighting | **Positive** | Removing features HELPED; quant bpb 1.2720 (-0.021!) 
| +| 43 | exp54b (xsa-zstd-ckfix) | exp53b | XSA last 2 layers + c_k fix + zstd | **Positive** | 1xH100 SOTA: quant bpb 1.2708 | +| 44 | exp55b (scaled-xsa-all) | exp54b | Learned XSA alpha on all layers | Neutral | Model wants XSA everywhere (alpha=0.75-0.99) but 20ms overhead | +| 45 | exp56b (fast-cosine-xsa) | exp55b | Cosine-scale XSA approximation | Negative | GQA head expansion is bottleneck, not XSA math | +| 46 | exp57b (lora-ttt) | exp54b | LoRA-based TTT | Negative | Failed | +| 47 | exp58b (resid-norm-on) | exp54b | Re-enable resid-norm | Negative | Redundant with warmdown; 7ms/step overhead not worth it | +| 48 | exp59b (pre-norm-skip) | exp54b | Pre-skip normalization | Negative | Same overhead as full resid-norm, no quality difference | +| 49 | clean_54b (final-arch) | exp54b | Clean submission version + TTT | **Positive** | Quant bpb 1.2723; clean baseline | +| 50 | clean_54b_v2 (bf16-roundtrip) | clean_54b | BF16 roundtrip test | Negative | Destroyed quality | +| | **Phase 3.5 (exp60–exp80)** | | | | | +| 51 | exp60 (8xh100-sim) | exp54b | EMA + flash_attn3 + 8xH100 simulation | Neutral | Infrastructure for scaling; not a bpb experiment | +| 52 | exp61b (xsa-all-warmdown) | exp60 | XSA all blocks + cosine warmdown | **Positive** | Pre-quant 1.1504; XSA-all works at scale | +| 53 | exp63 (cascade-vr) | exp61b | Cascading value residual + adaptive warmdown | **Positive** | Pre-quant 1.1377; discovered deep-layer value highway | +| 54 | exp64 (mlp-int6) | exp63 | MLP int6 quantization | Not run | Superseded by exp69 | +| 55 | exp65 (quant-overhaul) | exp63 | Full quantization overhaul | Not run | Ideas flowed into exp69 | +| 56 | exp66 (mile-nope) | exp65 | MiLe loss + partial NoPE | Negative | MiLe hurts early convergence | +| 57 | exp67 (ws-semantic-attn) | exp66 | Word-start semantic attention | Negative | Failed | +| 58 | exp68 (ws-mtp) | exp66 | Next-word-start MTP head | Not run | TTT data leakage concern | +| 59 | exp69 (better-quant) | exp63 | 
MLP proj->int6, attn->int5, LZMA, prune 5% | **Positive** | Closed quant gap 0.035->0.015; free improvements | +| 60 | exp70 (speed-opt) | exp69 | Batched NS5, EMA/10, set_to_none, deferred .item() | **Positive** | ~1.15 bpb; speed-optimized foundation for all subsequent | +| 61 | exp71 (output-bias) | exp70 | Output bias + label smooth + Z-loss | Not run | Needs too many steps to build momentum | +| 62 | exp72 (jepa-concept) | exp70 | JEPA concept loss | Negative | Added overhead, not enough steps even at 7K | +| 63 | exp73 (warmdown-focal) | exp70 | Warmdown focal + TTT weight | Not run | Safe late-training intervention (designed) | +| 64 | exp74 (prope-qgain-wbigram) | exp70 | Partial RoPE 16/64 + diverse q_gain + word bigram | **Positive** | Sliding bpb 1.1456; heads specialized (sharp+soft) | +| 65 | exp75 (word-pool) | exp74 | Inject previous word-start embedding | Negative | Model suppressed it (scale 0.1->0.002); redundant with attention | +| 66 | exp76 (dual-word-attn) | exp74 | Dual token + word attention | Negative | Failed | +| 67 | exp77old (late-warmdown) | exp70 | Late warmdown only | Neutral | Superseded by exp77 | +| 68 | exp77 (progressive-batch) | exp70 | Progressive batch + seq_len curriculum | Not run | Theoretically sound but non-standard | +| 69 | exp78 (ws-loss-curriculum) | exp70 | Word-start loss curriculum 0.1->1.0 | **Positive** | Best embedding quality; WS rank improved | +| 70 | exp79 (position-ramp) | exp70 | Position ramp 1.0->1.2 + late WS boost | Negative | Premise wrong: late positions are EASIER (90% repeats) | +| 71 | exp80 (best-stack) | exp70 | Combine pRoPE + bigram-after-norm + pos ramp + clamp | Negative | Bigram-after-norm destabilized attention | +| | **Phase 3.6 (exp81–exp87)** | | | | | +| 72 | exp81 (prope-ws-curriculum) | exp78 | Partial RoPE + WS curriculum | Neutral | Failed | +| 73 | exp82 (drop-layer10) | exp81 | Drop layer 10 + diverse q_gain | Not run | Designed only | +| 74 | exp83 (diagnostics) | exp70 | Full 
diagnostic run: grad norms, VR health, block analysis | **Positive** | 7 actionable insights; premature warmdown, dead blocks identified | +| 76 | exp84 (diagnostic-tuned) | exp83 | Apply diagnostics: VR_init=0.3, embed_lr=0.015 | Negative | VR went negative; embed_lr ratio misleading with Muon | +| 77 | exp85 (community-derived) | exp83 | pRoPE + x0-to-V + LN scale + clip search + small bigram | **Positive** | Best pre-quant (1.1517); ve_scale revealed model preferences | +| 78 | exp86 (deep-opt) | exp85 | Fused QKV + int8 critical + TF32 | Not run | Designed | +| 79 | exp87 (fast-convergence) | exp85 | Embed preinit SVD + progressive unfreeze + block9 AdamW | Negative | All 3 hurt; don't fight Muon's orthogonal constraint | +| | **Phase 3b-Muon (parallel optimizer)** | | | | | +| 80 | exp70_parallel_muon | exp70 | Parallel Muon via reduce-scatter/all-gather overlap | **Positive** | 12% speed (658ms vs 750ms); same final bpb | +| 81 | exp70_vram_opt | exp70_parallel_muon | Double-buffer data loader | Negative | Insufficient buffers for grad_accum | +| 82 | exp70_cuda_fused | exp70_parallel_muon | CUDA Graphs + Triton fusion | Negative | No improvement | +| 83 | exp90 (copy-head) | exp70_parallel_muon | TopicCopyHead (hybrid freq+attn) | Neutral | Concept validated; 40ms overhead | +| 84 | reverted_exp70 | exp70_parallel_muon | Clean base with all fixes | **Positive** | Clean foundation; 656ms/step | +| 85 | exp91 (smooth-v0residual) | reverted_exp70 | V0 residual + label smoothing | Neutral | Pending validation | +| | **Phase 3c (exp92–exp109)** | | | | | +| 86 | exp92 (banks-asyncmuon) | exp70 | Major rewrite: bank tensors + async Muon + partial RoPE + QAT + VE | **Positive** | ~1.131 bpb; paradigm shift in architecture | +| 87 | exp93 (meta-ttt) | exp92 | Meta-TTT inner/outer FOMAML | **Positive** | Legal_ttt ~1.116; first meta-TTT integration | +| 88 | exp95 (size-opt-metattt2x) | exp93 | Size optimization + meta-TTT 2x | **Positive** | Legal_ttt 1.1169; SOTA 
at the time | +| 89 | exp96 (warmdown-trigram) | exp95 | Warmdown fix + trigram hash | Neutral | ~1.135 bpb; marginal | +| 90 | exp97 (fp8-pipeline) | exp96 | FP8 pipeline + compile | Not run | Designed | +| 91 | exp98 (metattt-randomsplit) | exp96 | Random-split FOMAML + momentum LR match | Neutral | ~1.135 bpb; no improvement | +| 92 | exp99 (tripleloop) | exp98 | Triple loop + parallel residuals | Not run | Community merged first | +| 93 | exp100 (half-metattt) | exp95 | Half meta-TTT variant | Neutral | Not tracked in detail | +| 94 | exp101 (poscond-bigram) | exp95 | Position-conditional bigram hash by token class | **Positive** | Legal_ttt 1.11588; zero-param trick splitting hash by word-start | +| 95 | exp105a (no-metattt ablation) | exp101 | Remove meta-TTT to measure its contribution | Neutral | Meta-TTT = +0.00036 bpb (noise); ceiling is architectural | +| 96 | exp106 (metasgd-crosschunk) | exp101 | MetaSGD + cross-chunk FOMAML | Neutral | TTT delta invariant at ~0.023; ceiling confirmed | +| 97 | exp107 (sam-inner) | exp106 | SAM inner loop for TTT | Negative | SAM hurts; TTT delta still ~0.023 regardless of optimizer | +| 98 | exp108 (sp8192-brotli) | exp106 | SP8192 tokenizer + Brotli compression | Neutral | No stored results | +| 99 | exp109 (shared-blocks-softgate) | exp101 | Block sharing K=8 + soft gates + SP8192 | Negative | Decoder positions dead (near-zero scales); 10x worse quant | +| | **Community SOTA (SP8192+)** | | | | | +| 100 | SP8192_3LayerRecur (community) | Community | SP8192 + 3-layer recurrence (blocks 3-5) + parallel residuals + QK_GAIN=5.25 | **Positive** | Legal_ttt 1.0808; paradigm shift — 17 virtual layers from 11 physical | +| 101 | WiderEmb_TapInV6_TTT (community) | Community | Wider loop (3x3) + per-pass embeddings + Tap-In V6 + legal TTT | **Positive** | Legal_ttt 1.0788 (3-seed mean 1.078825) | +| 102 | ImprovedParallelResiduals (community PR #1523) | Community | Cross-lane attn/MLP accumulation + CUTLASS EVT fusion | 
**Positive** | **Legal_ttt 1.0744** — CURRENT BEST; 71 bytes headroom | +| 103 | RecurStepFiLM_PooledRetrieval (community) | Community | FiLM conditioning + pooled retrieval | Neutral | No improvement over base | +| 104 | 10L_RecurStepFiLM_PooledRetrieval (community) | Community | 10L variant of FiLM+retrieval | Neutral | No improvement | +| 105 | newSota (community) | Community | Community SOTA integration | **Positive** | Integration checkpoint | +| 106 | 11L_RecurStep3_loopedonly | Community | 11L, recurrence step 3, looped-only | Neutral | No improvement over ImprovedParallelResiduals | +| 107 | 11L_RecurStep3_loops3 | Community | 11L with 3 loops | Neutral | No improvement | +| 108 | 11L_RecurStep_StochDepth_ProgLoop | Community | Stochastic depth + progressive loop | Neutral | No improvement | +| 109 | 11L_RecurStep_StochDepth_ProgLoop_KVCache | Community | + KV cache for recurrence | Neutral | No improvement | +| 110 | 11L_Block10MLPHalf_RecurStepFiLM | Community | Block 10 MLP halved + FiLM + retrieval | Neutral | No improvement | +| 111 | loop_in_SP8192_3LayerRecur | Community | Loop detection: timestep embed + re-injection + per-loop RMSNorm | Neutral | Not yet trained | +| | **Frontier (exp110–exp119)** | | | | | +| 112 | exp110 (perlayer-quant-trigram) | ImprovedParallelResiduals | Per-layer quant + trigram + PARALLEL_START=7 | Neutral | No improvement | +| 113 | exp111 (lora-ttt-shrunk) | ImprovedParallelResiduals | LoRA TTT rank=8 + shrunk block 10 MLP | Neutral | No improvement | +| 114 | exp112 (grad-rescaling) | ImprovedParallelResiduals | Gradient rescaling on weak blocks | Negative | Doesn't fix structural tied-embedding bottleneck | +| 115 | exp113 (drop-l0-mtp) | ImprovedParallelResiduals | Drop L0 MLP + batch schedule + MTP | Neutral | Truncated logs | +| 116 | exp114 (embed384-decouple) | ImprovedParallelResiduals | embed_dim=384 to decouple boundary blocks | Negative | 655K param loss -> BPB regression (1.0950) | +| 117 | exp115 
(embed384-asymmetric) | ImprovedParallelResiduals | embed_dim=384 + drop boundary MLPs | Neutral | Truncated | +| 118 | exp116 (embed384-no-x0) | ImprovedParallelResiduals | embed_dim=384 + remove x0 pathway | Negative | No stored results | +| 119 | exp117 (embed448-tuned) | ImprovedParallelResiduals | embed_dim=448 to activate boundary blocks | Negative | Good BPB (1.0877) but 16.28MB — over budget | +| 120 | exp118 (embed416-parstart7) | ImprovedParallelResiduals | embed_dim=416 + parallel_start=7 + tighter clip | Negative | Good BPB (1.0915) but 16.44MB — over budget | +| 121 | exp119 (residual-lowrank-proj) | ImprovedParallelResiduals | Residual low-rank projection (rank=32) | Neutral | Theoretically correct fix; not run to completion | +| | **Misc** | | | | | +| 122 | CooccurrenceQKInit | PR #623 | Init W_Q/W_K from bigram co-occurrence SVD | **Positive** | Val_bpb 1.3525 on 1xH100; meaningful step-0 attention patterns | + +--- + +*Last updated: 2026-04-13* diff --git a/records/track_non_record_16mb/parameter-golf-experimentations-summary/index.html b/records/track_non_record_16mb/parameter-golf-experimentations-summary/index.html new file mode 100644 index 0000000000..e692eb1312 --- /dev/null +++ b/records/track_non_record_16mb/parameter-golf-experimentations-summary/index.html @@ -0,0 +1,1465 @@ + + + + + +Mind Of Experimenter + + + + +
+

Mind Of Experimenter

+
+
+ + +
+
+
All
+
Positive
+
Negative
+
Neutral
+
Architecture
+
Optimization
+
Quantization
+
Evaluation
+
+
+
+ + +
+
+ +
+
+ +
+
+ +
+
+
+
+ +
+
+
+ + +
+
+
+
+
+
+ +
+ +
+
Positive
+
Negative
+
Neutral
+ + + +
+ +
+ +
+ + + +
+ + + + From 01ba47ec37395a4183e99841fe95400f2239455b Mon Sep 17 00:00:00 2001 From: SPThole Date: Tue, 14 Apr 2026 00:21:47 +0530 Subject: [PATCH 10/10] updt in readme --- .../STRUCTURED_EXPSUM.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/records/track_non_record_16mb/parameter-golf-experimentations-summary/STRUCTURED_EXPSUM.md b/records/track_non_record_16mb/parameter-golf-experimentations-summary/STRUCTURED_EXPSUM.md index a87a65f8e4..de93869df3 100644 --- a/records/track_non_record_16mb/parameter-golf-experimentations-summary/STRUCTURED_EXPSUM.md +++ b/records/track_non_record_16mb/parameter-golf-experimentations-summary/STRUCTURED_EXPSUM.md @@ -1,5 +1,19 @@ # Structured Experiment Summary +This is kind of a summary of (maybe not all, lol) the experimentation I did — lots of learning. Thanks to OpenAI for the $525 + a few hundred dollars of my own credits! I’d love to try more ideas I had but couldn’t due to a shortage of credits. + +In this journey, I tried not to get bogged down by leaderboard approaches as much as possible. In a few places, though, when I got stuck, I did take help from the community. My general approach was: train a model → analyze it → try to solve the issues observed in the analysis. This ended up costing me many experiments and dollars. + +GIT REPO TO FIND ALL EXPERIMENTS: +https://github.com/SPThole/parameter-golf-experimentations + +I have also made a cool mind map of all the experimentation — basically the path of what I did and why. I’ve also attached lineages that are relevant from community discussions and leaderboard files. + +I am planning to build on this: +https://github.com/SPThole/bpb_wtf or visit: https://bpb-wtf.vercel.app/ + +I’m also building a broader direction around this (mind map + experiments). If this resonates with anyone or you’d like to collaborate, feel free to reach out — I’d love to explore this further together. 
+ > **Competition**: OpenAI Parameter Golf > **Objective**: Minimize validation loss (bits-per-byte, bpb) under a 16MB artifact constraint within 10-minute training on 8×H100 > **Total experiments**: 119+