From 9d318e7d1983ffecb465738bfafae4abfcb61ca7 Mon Sep 17 00:00:00 2001 From: aquariouseworkman Date: Thu, 19 Mar 2026 03:33:51 -0400 Subject: [PATCH 1/3] Record: Seq4096 + Sliding Window Eval, val_bpb=1.1808 --- .../2026-03-19_Seq4096SlidingWindow/README.md | 76 + .../submission.json | 11 + .../2026-03-19_Seq4096SlidingWindow/train.log | 1627 +++++++++++++++++ .../train_gpt.py | 1274 +++++++++++++ 4 files changed, 2988 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/README.md create mode 100644 records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/submission.json create mode 100644 records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train.log create mode 100644 records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train_gpt.py diff --git a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/README.md b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/README.md new file mode 100644 index 0000000000..a51057cd53 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/README.md @@ -0,0 +1,76 @@ +# Training Opt Seq4096 + Sliding Window Eval (stride=64) + +This record combines two independent improvements over the naive baseline: + +1. **Longer training context (seq_len=4096) with aggressively tuned Muon optimizer** +2. **Sliding window evaluation (stride=64) for near-full context on every scored token** + +## Configuration + +- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` +- Tied output/input embeddings: `TIE_EMBEDDINGS=1` +- Sequence length: `TRAIN_SEQ_LEN=4096` +- Batching: `TRAIN_BATCH_TOKENS=393216` (3/4 batch for more optimizer updates per second) +- Learning rates: `TIED_EMBED_LR=0.030 MATRIX_LR=0.020 SCALAR_LR=0.020` +- Muon optimizer: `MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_STEPS=1500 MUON_MOMENTUM_WARMUP_START=0.92` +- Schedule: `WARMDOWN_ITERS=3000` +- Eval: `EVAL_STRIDE=64` (sliding window, each scored token gets 4032 tokens of context) + +## Command + +```bash +RUN_ID=combined_seq4096_sw64 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +MAX_WALLCLOCK_SECONDS=600 \ +TRAIN_LOG_EVERY=50 \ +VAL_LOSS_EVERY=1000 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Key Metrics + +- Timed training stopped at `11478/20000` steps due to the wallclock cap. +- Pre-quant eval at stop: `val_loss:2.0071`, `val_bpb:1.1887` +- Post-quant roundtrip eval (standard): `val_loss:2.0142`, `val_bpb:1.1929` +- **Post-quant sliding window eval: `val_loss:1.9937`, `val_bpb:1.1808`** +- Exact printed metric: `final_sliding_window_eval_exact stride:64 val_loss:1.99374987 val_bpb:1.18081285` +- Train time: `599949ms` (`step_avg:52.27ms`) +- Sliding window eval time: `278940ms` (under 10-minute eval budget) +- Peak memory: `7655 MiB allocated`, `8156 MiB reserved` +- Serialized model int8+zlib: `15824445 bytes` +- Code size: `54433 bytes` +- Total submission size int8+zlib: `15878878 bytes` + +## Approach + +This submission stacks three orthogonal improvements: + +### 1. Longer training context (seq_len=4096) +Each training sequence sees 4x more context than the 1024-token baseline, giving the autoregressive model much better signal per token. This costs ~52ms/step (vs ~43ms at seq_len=1024), but the quality improvement far outweighs the fewer total steps (11,478 vs ~13,780). + +### 2. Aggressive Muon optimizer tuning +- **Higher momentum (0.99 vs 0.95):** Stronger gradient smoothing for better convergence. +- **Lower learning rates (0.020 vs 0.04):** Dramatically reduces int8 quantization loss (0.004 BPB quant penalty vs 0.007+ at default LR) while maintaining similar pre-quant quality. +- **3/4 batch (393K vs 524K tokens):** More optimizer updates per wallclock second. +- **Extended momentum warmup (1500 steps from 0.92):** Prevents early instability with the higher momentum. +- **Longer warmdown (3000 steps):** Proportionally longer LR decay for the ~11,500-step run. + +### 3. Sliding window evaluation (stride=64) +Standard evaluation chops the validation set into non-overlapping 4096-token blocks, meaning the first tokens in each block get minimal context. Sliding window evaluation shifts the window by only 64 tokens at a time, scoring only the last 64 tokens per window. Each scored token sees 4032 tokens of prior context, dramatically reducing the "cold start" penalty. This improves BPB by ~0.012 over standard post-quant eval (1.1808 vs 1.1929). + +The evaluation uses a `forward_logits()` method that returns logits without computing loss, allowing us to score only the relevant trailing tokens per window. The sliding window runs on 8 GPUs in parallel with batched windows (32 per forward pass) and completes in ~279 seconds. + +## Training Volume + +- Global batch: `393216` tokens/step +- Total train tokens seen: `~4.5B` (11,478 steps × 393,216 tokens) +- Dataset: 20 shards of fineweb10B_sp1024 + +## Included Files + +- `train_gpt.py` (standalone script with all changes baked in) +- `README.md` (this file) +- `submission.json` (leaderboard metadata) +- `train.log` (full training log from the canonical run) diff --git a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/submission.json b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/submission.json new file mode 100644 index 0000000000..9df0d68362 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/submission.json @@ -0,0 +1,11 @@ +{ + "author": "aquariouseworkman", + "github_id": "aquariouseworkman", + "name": "Seq4096 Muon Tuned + Sliding Window Eval", + "blurb": "SP-1024 9x512 KV4 trained at seq_len=4096 with tuned Muon optimizer (momentum 0.99, LR 0.020, 3/4 batch, warmdown 3000). Sliding window eval (stride=64) gives each scored token 4032 tokens of context. Post-quant sliding window val_bpb=1.1808.", + "date": "2026-03-19T07:00:00Z", + "val_loss": 1.99374987, + "val_bpb": 1.18081285, + "bytes_total": 15878878, + "bytes_code": 54433 +} diff --git a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train.log b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train.log new file mode 100644 index 0000000000..31ee9659b8 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train.log @@ -0,0 +1,1627 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + # Sliding window eval: stride controls how many tokens to slide between windows. + # Only the last `eval_stride` tokens per window are scored, giving each scored token + # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 393_216)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 4096)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int = 64, +) -> tuple[float, float]: + """Sliding window evaluation: each scored token gets (seq_len - stride) context. + + Instead of chopping validation into non-overlapping 1024-token blocks (where + the first token in each block gets zero context), we slide a 1024-token window + by `stride` tokens at a time and only score the last `stride` tokens per window. + Every scored token sees 960+ tokens of context, dramatically improving BPB. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 # need 1 extra for final target + + # Number of windows: first window at offset 0, then slide by stride + num_windows = max((total_tokens - seq_len) // stride + 1, 1) + + # Distribute windows across ranks + win_start = (num_windows * rank) // world_size + win_end = (num_windows * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Eval batch: how many windows per forward pass. Tune for memory. + eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) + + base_model.eval() + with torch.inference_mode(): + window_list = list(range(win_start, win_end)) + num_batches = (len(window_list) + eval_batch - 1) // eval_batch + for batch_idx in range(num_batches): + batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] + bsz = len(batch_wins) + + # Build input: each window is val_tokens[w*stride : w*stride + seq_len] + inputs = torch.stack([ + val_tokens[w * stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64) # [B, seq_len] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] + + # Score only the last `stride` positions per window. + # logits[:, j, :] predicts the token AFTER position j in the input. + # So logits[:, -stride:, :] predicts tokens at input positions + # [seq_len - stride + 1, seq_len + 1) — which are the targets. + scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] + + # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] + targets = torch.stack([ + val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") + val_loss_sum += loss.to(torch.float64) + val_token_count += float(targets.numel()) + + # BPB: prev_ids are the input tokens at each scored position + prev_ids = torch.stack([ + val_tokens[w * stride + seq_len - stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + tgt_ids = targets + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if rank == 0 and (batch_idx + 1) % 100 == 0: + print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits without computing loss. Used for sliding window eval.""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval: gives each scored token (seq_len - stride) context. + if args.eval_stride > 0: + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding_window( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window_eval stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0( + f"final_sliding_window_eval_exact stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Mar 19 06:53:44 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:07:00.0 Off | 0 | +| N/A 29C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:08:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:09:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:0C:00.0 Off | 0 | +| N/A 33C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 1923 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 1924 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 1925 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 1926 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 1927 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 1928 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 1929 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 1930 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:20 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:393216 train_seq_len:4096 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.03ms +step:1/20000 train_loss:6.9376 train_time:33ms step_avg:32.90ms +step:2/20000 train_loss:12.1428 train_time:88ms step_avg:43.95ms +step:3/20000 train_loss:7.4539 train_time:147ms step_avg:49.10ms +step:4/20000 train_loss:6.3261 train_time:207ms step_avg:51.67ms +step:5/20000 train_loss:6.9176 train_time:263ms step_avg:52.65ms +step:6/20000 train_loss:6.9130 train_time:325ms step_avg:54.17ms +step:7/20000 train_loss:6.7021 train_time:380ms step_avg:54.27ms +step:8/20000 train_loss:6.6697 train_time:440ms step_avg:54.97ms +step:9/20000 train_loss:6.3152 train_time:497ms step_avg:55.19ms +step:10/20000 train_loss:6.2095 train_time:556ms step_avg:55.56ms +step:50/20000 train_loss:4.1814 train_time:2618ms step_avg:52.36ms +step:100/20000 train_loss:3.3915 train_time:5201ms step_avg:52.01ms +step:150/20000 train_loss:3.0600 train_time:7774ms step_avg:51.83ms +step:200/20000 train_loss:2.7600 train_time:10349ms step_avg:51.75ms +step:250/20000 train_loss:2.6022 train_time:12924ms step_avg:51.70ms +step:300/20000 train_loss:2.6200 train_time:15591ms step_avg:51.97ms +step:350/20000 train_loss:2.6057 train_time:18161ms step_avg:51.89ms +step:400/20000 train_loss:2.4191 train_time:20735ms step_avg:51.84ms +step:450/20000 train_loss:1.9559 train_time:23309ms step_avg:51.80ms +step:500/20000 train_loss:2.4657 train_time:25883ms step_avg:51.77ms +step:550/20000 train_loss:2.4548 train_time:28581ms step_avg:51.97ms +step:600/20000 train_loss:2.2913 train_time:31168ms step_avg:51.95ms +step:650/20000 train_loss:2.3561 train_time:33754ms step_avg:51.93ms +step:700/20000 train_loss:2.3823 train_time:36345ms step_avg:51.92ms +step:750/20000 train_loss:2.3120 train_time:38937ms step_avg:51.92ms +step:800/20000 train_loss:2.3685 train_time:41646ms step_avg:52.06ms +step:850/20000 train_loss:2.3882 train_time:44237ms step_avg:52.04ms +step:900/20000 train_loss:2.1988 train_time:46828ms step_avg:52.03ms +step:950/20000 train_loss:2.2929 train_time:49420ms step_avg:52.02ms +step:1000/20000 train_loss:2.3322 train_time:52014ms step_avg:52.01ms +step:1000/20000 val_loss:2.3041 val_bpb:1.3646 train_time:52038ms step_avg:52.04ms +step:1050/20000 train_loss:2.3307 train_time:54706ms step_avg:52.10ms +step:1100/20000 train_loss:2.3880 train_time:57297ms step_avg:52.09ms +step:1150/20000 train_loss:2.3096 train_time:59889ms step_avg:52.08ms +step:1200/20000 train_loss:2.3520 train_time:62485ms step_avg:52.07ms +step:1250/20000 train_loss:2.2906 train_time:65076ms step_avg:52.06ms +step:1300/20000 train_loss:2.4268 train_time:67768ms step_avg:52.13ms +step:1350/20000 train_loss:2.2162 train_time:70360ms step_avg:52.12ms +step:1400/20000 train_loss:2.2977 train_time:72956ms step_avg:52.11ms +step:1450/20000 train_loss:2.3528 train_time:75550ms step_avg:52.10ms +step:1500/20000 train_loss:2.4259 train_time:78146ms step_avg:52.10ms +step:1550/20000 train_loss:2.3027 train_time:80840ms step_avg:52.15ms +step:1600/20000 train_loss:2.2671 train_time:83436ms step_avg:52.15ms +step:1650/20000 train_loss:2.2047 train_time:86028ms step_avg:52.14ms +step:1700/20000 train_loss:2.2044 train_time:88625ms step_avg:52.13ms +step:1750/20000 train_loss:2.3344 train_time:91217ms step_avg:52.12ms +step:1800/20000 train_loss:2.0434 train_time:93920ms step_avg:52.18ms +step:1850/20000 train_loss:2.2787 train_time:96516ms step_avg:52.17ms +step:1900/20000 train_loss:2.2827 train_time:99110ms step_avg:52.16ms +step:1950/20000 train_loss:2.2277 train_time:101703ms step_avg:52.16ms +step:2000/20000 train_loss:2.3059 train_time:104296ms step_avg:52.15ms +step:2000/20000 val_loss:2.2113 val_bpb:1.3097 train_time:104328ms step_avg:52.16ms +step:2050/20000 train_loss:2.8644 train_time:106995ms step_avg:52.19ms +step:2100/20000 train_loss:2.1551 train_time:109592ms step_avg:52.19ms +step:2150/20000 train_loss:2.2515 train_time:112182ms step_avg:52.18ms +step:2200/20000 train_loss:2.0542 train_time:114772ms step_avg:52.17ms +step:2250/20000 train_loss:2.4315 train_time:117366ms step_avg:52.16ms +step:2300/20000 train_loss:2.1390 train_time:120068ms step_avg:52.20ms +step:2350/20000 train_loss:2.1749 train_time:122661ms step_avg:52.20ms +step:2400/20000 train_loss:2.2065 train_time:125264ms step_avg:52.19ms +step:2450/20000 train_loss:2.3101 train_time:127855ms step_avg:52.19ms +step:2500/20000 train_loss:2.2842 train_time:130447ms step_avg:52.18ms +step:2550/20000 train_loss:2.2999 train_time:133134ms step_avg:52.21ms +step:2600/20000 train_loss:2.3269 train_time:135727ms step_avg:52.20ms +step:2650/20000 train_loss:2.2548 train_time:138321ms step_avg:52.20ms +step:2700/20000 train_loss:2.1140 train_time:140909ms step_avg:52.19ms +step:2750/20000 train_loss:2.1784 train_time:143501ms step_avg:52.18ms +step:2800/20000 train_loss:2.3925 train_time:146200ms step_avg:52.21ms +step:2850/20000 train_loss:2.1333 train_time:148792ms step_avg:52.21ms +step:2900/20000 train_loss:2.1318 train_time:151387ms step_avg:52.20ms +step:2950/20000 train_loss:2.1459 train_time:153977ms step_avg:52.20ms +step:3000/20000 train_loss:2.1693 train_time:156571ms step_avg:52.19ms +step:3000/20000 val_loss:2.1549 val_bpb:1.2763 train_time:156597ms step_avg:52.20ms +step:3050/20000 train_loss:2.0718 train_time:159165ms step_avg:52.19ms +step:3100/20000 train_loss:2.0741 train_time:161860ms step_avg:52.21ms +step:3150/20000 train_loss:2.2964 train_time:164450ms step_avg:52.21ms +step:3200/20000 train_loss:2.1957 train_time:167038ms step_avg:52.20ms +step:3250/20000 train_loss:2.1752 train_time:169631ms step_avg:52.19ms +step:3300/20000 train_loss:2.0950 train_time:172227ms step_avg:52.19ms +step:3350/20000 train_loss:2.2272 train_time:174916ms step_avg:52.21ms +step:3400/20000 train_loss:2.1998 train_time:177509ms step_avg:52.21ms +step:3450/20000 train_loss:2.0994 train_time:180105ms step_avg:52.20ms +step:3500/20000 train_loss:2.1930 train_time:182703ms step_avg:52.20ms +step:3550/20000 train_loss:2.1412 train_time:185300ms step_avg:52.20ms +step:3600/20000 train_loss:2.0727 train_time:188011ms step_avg:52.23ms +step:3650/20000 train_loss:2.1659 train_time:190609ms step_avg:52.22ms +step:3700/20000 train_loss:2.1052 train_time:193207ms step_avg:52.22ms +step:3750/20000 train_loss:2.1134 train_time:195805ms step_avg:52.21ms +step:3800/20000 train_loss:2.0930 train_time:198401ms step_avg:52.21ms +step:3850/20000 train_loss:2.1268 train_time:201119ms step_avg:52.24ms +step:3900/20000 train_loss:2.0784 train_time:203717ms step_avg:52.24ms +step:3950/20000 train_loss:2.1538 train_time:206315ms step_avg:52.23ms +step:4000/20000 train_loss:1.9293 train_time:208915ms step_avg:52.23ms +step:4000/20000 val_loss:2.1208 val_bpb:1.2561 train_time:208940ms step_avg:52.24ms +step:4050/20000 train_loss:2.1241 train_time:211509ms step_avg:52.22ms +step:4100/20000 train_loss:2.1416 train_time:214210ms step_avg:52.25ms +step:4150/20000 train_loss:2.0450 train_time:216807ms step_avg:52.24ms +step:4200/20000 train_loss:2.0048 train_time:219404ms step_avg:52.24ms +step:4250/20000 train_loss:2.1380 train_time:221999ms step_avg:52.23ms +step:4300/20000 train_loss:2.1309 train_time:224595ms step_avg:52.23ms +step:4350/20000 train_loss:2.0464 train_time:227297ms step_avg:52.25ms +step:4400/20000 train_loss:2.1155 train_time:229888ms step_avg:52.25ms +step:4450/20000 train_loss:2.0994 train_time:232479ms step_avg:52.24ms +step:4500/20000 train_loss:2.1875 train_time:235075ms step_avg:52.24ms +step:4550/20000 train_loss:2.1034 train_time:237666ms step_avg:52.23ms +step:4600/20000 train_loss:2.0381 train_time:240368ms step_avg:52.25ms +step:4650/20000 train_loss:2.0448 train_time:242965ms step_avg:52.25ms +step:4700/20000 train_loss:2.1083 train_time:245559ms step_avg:52.25ms +step:4750/20000 train_loss:2.0594 train_time:248155ms step_avg:52.24ms +step:4800/20000 train_loss:2.0705 train_time:250751ms step_avg:52.24ms +step:4850/20000 train_loss:2.1384 train_time:253464ms step_avg:52.26ms +step:4900/20000 train_loss:2.1304 train_time:256063ms step_avg:52.26ms +step:4950/20000 train_loss:2.4399 train_time:258658ms step_avg:52.25ms +step:5000/20000 train_loss:2.1459 train_time:261255ms step_avg:52.25ms +step:5000/20000 val_loss:2.1024 val_bpb:1.2452 train_time:261287ms step_avg:52.26ms +step:5050/20000 train_loss:2.0569 train_time:263854ms step_avg:52.25ms +step:5100/20000 train_loss:2.1764 train_time:266552ms step_avg:52.27ms +step:5150/20000 train_loss:1.9255 train_time:269147ms step_avg:52.26ms +step:5200/20000 train_loss:2.1012 train_time:271751ms step_avg:52.26ms +step:5250/20000 train_loss:2.0996 train_time:274349ms step_avg:52.26ms +step:5300/20000 train_loss:2.0296 train_time:276947ms step_avg:52.25ms +step:5350/20000 train_loss:2.0501 train_time:279653ms step_avg:52.27ms +step:5400/20000 train_loss:2.2050 train_time:282252ms step_avg:52.27ms +step:5450/20000 train_loss:2.0719 train_time:284854ms step_avg:52.27ms +step:5500/20000 train_loss:2.0701 train_time:287454ms step_avg:52.26ms +step:5550/20000 train_loss:2.1761 train_time:290052ms step_avg:52.26ms +step:5600/20000 train_loss:2.0441 train_time:292763ms step_avg:52.28ms +step:5650/20000 train_loss:2.1013 train_time:295371ms step_avg:52.28ms +step:5700/20000 train_loss:2.1353 train_time:297974ms step_avg:52.28ms +step:5750/20000 train_loss:2.1457 train_time:300579ms step_avg:52.27ms +step:5800/20000 train_loss:2.1356 train_time:303181ms step_avg:52.27ms +step:5850/20000 train_loss:2.0896 train_time:305889ms step_avg:52.29ms +step:5900/20000 train_loss:2.2131 train_time:308488ms step_avg:52.29ms +step:5950/20000 train_loss:1.9779 train_time:311089ms step_avg:52.28ms +step:6000/20000 train_loss:2.0487 train_time:313682ms step_avg:52.28ms +step:6000/20000 val_loss:2.0880 val_bpb:1.2366 train_time:313709ms step_avg:52.28ms +step:6050/20000 train_loss:2.0857 train_time:316281ms step_avg:52.28ms +step:6100/20000 train_loss:2.0682 train_time:318877ms step_avg:52.27ms +step:6150/20000 train_loss:2.0564 train_time:321572ms step_avg:52.29ms +step:6200/20000 train_loss:2.0141 train_time:324175ms step_avg:52.29ms +step:6250/20000 train_loss:2.1140 train_time:326773ms step_avg:52.28ms +step:6300/20000 train_loss:2.0920 train_time:329369ms step_avg:52.28ms +step:6350/20000 train_loss:2.1695 train_time:331965ms step_avg:52.28ms +step:6400/20000 train_loss:2.0887 train_time:334656ms step_avg:52.29ms +step:6450/20000 train_loss:2.0925 train_time:337250ms step_avg:52.29ms +step:6500/20000 train_loss:2.1196 train_time:339844ms step_avg:52.28ms +step:6550/20000 train_loss:2.2899 train_time:342438ms step_avg:52.28ms +step:6600/20000 train_loss:2.0384 train_time:345033ms step_avg:52.28ms +step:6650/20000 train_loss:2.0170 train_time:347728ms step_avg:52.29ms +step:6700/20000 train_loss:2.1356 train_time:350323ms step_avg:52.29ms +step:6750/20000 train_loss:1.9176 train_time:352920ms step_avg:52.28ms +step:6800/20000 train_loss:1.9066 train_time:355518ms step_avg:52.28ms +step:6850/20000 train_loss:1.9774 train_time:358115ms step_avg:52.28ms +step:6900/20000 train_loss:2.0591 train_time:360817ms step_avg:52.29ms +step:6950/20000 train_loss:1.9967 train_time:363414ms step_avg:52.29ms +step:7000/20000 train_loss:2.0529 train_time:366003ms step_avg:52.29ms +step:7000/20000 val_loss:2.0760 val_bpb:1.2295 train_time:366035ms step_avg:52.29ms +step:7050/20000 train_loss:2.1090 train_time:368603ms step_avg:52.28ms +step:7100/20000 train_loss:2.1851 train_time:371201ms step_avg:52.28ms +step:7150/20000 train_loss:2.0422 train_time:373893ms step_avg:52.29ms +step:7200/20000 train_loss:2.1037 train_time:376483ms step_avg:52.29ms +step:7250/20000 train_loss:2.0128 train_time:379076ms step_avg:52.29ms +step:7300/20000 train_loss:2.1154 train_time:381668ms step_avg:52.28ms +step:7350/20000 train_loss:2.0027 train_time:384261ms step_avg:52.28ms +step:7400/20000 train_loss:2.1242 train_time:386965ms step_avg:52.29ms +step:7450/20000 train_loss:2.0912 train_time:389554ms step_avg:52.29ms +step:7500/20000 train_loss:2.0547 train_time:392149ms step_avg:52.29ms +step:7550/20000 train_loss:2.0165 train_time:394738ms step_avg:52.28ms +step:7600/20000 train_loss:1.9748 train_time:397328ms step_avg:52.28ms +step:7650/20000 train_loss:2.0099 train_time:400017ms step_avg:52.29ms +step:7700/20000 train_loss:2.0532 train_time:402608ms step_avg:52.29ms +step:7750/20000 train_loss:2.0572 train_time:405196ms step_avg:52.28ms +step:7800/20000 train_loss:1.9216 train_time:407789ms step_avg:52.28ms +step:7850/20000 train_loss:2.0621 train_time:410383ms step_avg:52.28ms +step:7900/20000 train_loss:2.1483 train_time:413081ms step_avg:52.29ms +step:7950/20000 train_loss:1.9810 train_time:415678ms step_avg:52.29ms +step:8000/20000 train_loss:2.0355 train_time:418271ms step_avg:52.28ms +step:8000/20000 val_loss:2.0678 val_bpb:1.2246 train_time:418298ms step_avg:52.29ms +step:8050/20000 train_loss:2.0104 train_time:420857ms step_avg:52.28ms +step:8100/20000 train_loss:2.0138 train_time:423449ms step_avg:52.28ms +step:8150/20000 train_loss:2.0172 train_time:426134ms step_avg:52.29ms +step:8200/20000 train_loss:2.0736 train_time:428721ms step_avg:52.28ms +step:8250/20000 train_loss:2.0940 train_time:431308ms step_avg:52.28ms +step:8300/20000 train_loss:2.1151 train_time:433896ms step_avg:52.28ms +step:8350/20000 train_loss:2.0713 train_time:436485ms step_avg:52.27ms +step:8400/20000 train_loss:2.0671 train_time:439177ms step_avg:52.28ms +step:8450/20000 train_loss:2.0876 train_time:441765ms step_avg:52.28ms +step:8500/20000 train_loss:1.9228 train_time:444353ms step_avg:52.28ms +step:8550/20000 train_loss:2.0541 train_time:446943ms step_avg:52.27ms +step:8600/20000 train_loss:2.0610 train_time:449531ms step_avg:52.27ms +step:8650/20000 train_loss:2.1088 train_time:452222ms step_avg:52.28ms +step:8700/20000 train_loss:1.9899 train_time:454818ms step_avg:52.28ms +step:8750/20000 train_loss:2.1527 train_time:457407ms step_avg:52.28ms +step:8800/20000 train_loss:2.0140 train_time:459998ms step_avg:52.27ms +step:8850/20000 train_loss:2.1463 train_time:462591ms step_avg:52.27ms +step:8900/20000 train_loss:2.0604 train_time:465180ms step_avg:52.27ms +step:8950/20000 train_loss:1.9942 train_time:467862ms step_avg:52.28ms +step:9000/20000 train_loss:1.9742 train_time:470453ms step_avg:52.27ms +step:9000/20000 val_loss:2.0529 val_bpb:1.2158 train_time:470487ms step_avg:52.28ms +step:9050/20000 train_loss:2.0073 train_time:473038ms step_avg:52.27ms +step:9100/20000 train_loss:2.1509 train_time:475627ms step_avg:52.27ms +step:9150/20000 train_loss:1.9811 train_time:478216ms step_avg:52.26ms +step:9200/20000 train_loss:2.0541 train_time:480916ms step_avg:52.27ms +step:9250/20000 train_loss:2.0177 train_time:483526ms step_avg:52.27ms +step:9300/20000 train_loss:2.1091 train_time:486114ms step_avg:52.27ms +step:9350/20000 train_loss:2.1775 train_time:488705ms step_avg:52.27ms +step:9400/20000 train_loss:2.0170 train_time:491295ms step_avg:52.27ms +step:9450/20000 train_loss:2.0597 train_time:493972ms step_avg:52.27ms +step:9500/20000 train_loss:2.0508 train_time:496558ms step_avg:52.27ms +step:9550/20000 train_loss:2.0011 train_time:499146ms step_avg:52.27ms +step:9600/20000 train_loss:1.9192 train_time:501732ms step_avg:52.26ms +step:9650/20000 train_loss:2.0606 train_time:504320ms step_avg:52.26ms +step:9700/20000 train_loss:2.1116 train_time:506988ms step_avg:52.27ms +step:9750/20000 train_loss:2.0981 train_time:509574ms step_avg:52.26ms +step:9800/20000 train_loss:1.8778 train_time:512165ms step_avg:52.26ms +step:9850/20000 train_loss:1.9928 train_time:514755ms step_avg:52.26ms +step:9900/20000 train_loss:2.1629 train_time:517347ms step_avg:52.26ms +step:9950/20000 train_loss:2.0106 train_time:520035ms step_avg:52.26ms +step:10000/20000 train_loss:2.2900 train_time:522641ms step_avg:52.26ms +step:10000/20000 val_loss:2.0325 val_bpb:1.2038 train_time:522672ms step_avg:52.27ms +step:10050/20000 train_loss:2.0670 train_time:525233ms step_avg:52.26ms +step:10100/20000 train_loss:2.0256 train_time:527817ms step_avg:52.26ms +step:10150/20000 train_loss:1.8322 train_time:530407ms step_avg:52.26ms +step:10200/20000 train_loss:1.9949 train_time:533099ms step_avg:52.26ms +step:10250/20000 train_loss:2.0141 train_time:535689ms step_avg:52.26ms +step:10300/20000 train_loss:2.2751 train_time:538284ms step_avg:52.26ms +step:10350/20000 train_loss:2.0744 train_time:540879ms step_avg:52.26ms +step:10400/20000 train_loss:2.1243 train_time:543470ms step_avg:52.26ms +step:10450/20000 train_loss:1.9712 train_time:546158ms step_avg:52.26ms +step:10500/20000 train_loss:2.0482 train_time:548753ms step_avg:52.26ms +step:10550/20000 train_loss:2.0679 train_time:551348ms step_avg:52.26ms +step:10600/20000 train_loss:1.9611 train_time:553938ms step_avg:52.26ms +step:10650/20000 train_loss:2.0545 train_time:556531ms step_avg:52.26ms +step:10700/20000 train_loss:2.1526 train_time:559228ms step_avg:52.26ms +step:10750/20000 train_loss:2.0829 train_time:561829ms step_avg:52.26ms +step:10800/20000 train_loss:2.0368 train_time:564426ms step_avg:52.26ms +step:10850/20000 train_loss:2.1171 train_time:567018ms step_avg:52.26ms +step:10900/20000 train_loss:2.1192 train_time:569611ms step_avg:52.26ms +step:10950/20000 train_loss:2.0715 train_time:572309ms step_avg:52.27ms +step:11000/20000 train_loss:1.9841 train_time:574902ms step_avg:52.26ms +step:11000/20000 val_loss:2.0135 val_bpb:1.1925 train_time:574930ms step_avg:52.27ms +step:11050/20000 train_loss:2.1428 train_time:577501ms step_avg:52.26ms +step:11100/20000 train_loss:1.9385 train_time:580093ms step_avg:52.26ms +step:11150/20000 train_loss:1.9739 train_time:582688ms step_avg:52.26ms +step:11200/20000 train_loss:2.0283 train_time:585381ms step_avg:52.27ms +step:11250/20000 train_loss:1.9992 train_time:587975ms step_avg:52.26ms +step:11300/20000 train_loss:1.9374 train_time:590573ms step_avg:52.26ms +step:11350/20000 train_loss:1.9705 train_time:593170ms step_avg:52.26ms +step:11400/20000 train_loss:2.0143 train_time:595781ms step_avg:52.26ms +step:11450/20000 train_loss:2.1952 train_time:598473ms step_avg:52.27ms +step:11478/20000 val_loss:2.0071 val_bpb:1.1887 train_time:599949ms step_avg:52.27ms +stopping_early: wallclock_cap train_time:599949ms step:11478/20000 +peak memory allocated: 7655 MiB reserved: 8156 MiB +Serialized model: 67224983 bytes +Code size: 54433 bytes +Total submission size: 67279416 bytes +Serialized model int8+zlib: 15824445 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) +Total submission size int8+zlib: 15878878 bytes +final_int8_zlib_roundtrip val_loss:2.0142 val_bpb:1.1929 eval_time:2123ms +final_int8_zlib_roundtrip_exact val_loss:2.01418632 val_bpb:1.19291459 +final_sliding_window_eval stride:64 val_loss:1.9937 val_bpb:1.1808 eval_time:278940ms +final_sliding_window_eval_exact stride:64 val_loss:1.99374987 val_bpb:1.18081285 diff --git a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train_gpt.py b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train_gpt.py new file mode 100644 index 0000000000..421ce35f1f --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train_gpt.py @@ -0,0 +1,1274 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + # Sliding window eval: stride controls how many tokens to slide between windows. + # Only the last `eval_stride` tokens per window are scored, giving each scored token + # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 393_216)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 4096)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int = 64, +) -> tuple[float, float]: + """Sliding window evaluation: each scored token gets (seq_len - stride) context. + + Instead of chopping validation into non-overlapping 1024-token blocks (where + the first token in each block gets zero context), we slide a 1024-token window + by `stride` tokens at a time and only score the last `stride` tokens per window. + Every scored token sees 960+ tokens of context, dramatically improving BPB. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 # need 1 extra for final target + + # Number of windows: first window at offset 0, then slide by stride + num_windows = max((total_tokens - seq_len) // stride + 1, 1) + + # Distribute windows across ranks + win_start = (num_windows * rank) // world_size + win_end = (num_windows * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Eval batch: how many windows per forward pass. Tune for memory. + eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) + + base_model.eval() + with torch.inference_mode(): + window_list = list(range(win_start, win_end)) + num_batches = (len(window_list) + eval_batch - 1) // eval_batch + for batch_idx in range(num_batches): + batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] + bsz = len(batch_wins) + + # Build input: each window is val_tokens[w*stride : w*stride + seq_len] + inputs = torch.stack([ + val_tokens[w * stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64) # [B, seq_len] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] + + # Score only the last `stride` positions per window. + # logits[:, j, :] predicts the token AFTER position j in the input. + # So logits[:, -stride:, :] predicts tokens at input positions + # [seq_len - stride + 1, seq_len + 1) — which are the targets. + scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] + + # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] + targets = torch.stack([ + val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") + val_loss_sum += loss.to(torch.float64) + val_token_count += float(targets.numel()) + + # BPB: prev_ids are the input tokens at each scored position + prev_ids = torch.stack([ + val_tokens[w * stride + seq_len - stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + tgt_ids = targets + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if rank == 0 and (batch_idx + 1) % 100 == 0: + print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits without computing loss. Used for sliding window eval.""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval: gives each scored token (seq_len - stride) context. + if args.eval_stride > 0: + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding_window( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window_eval stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0( + f"final_sliding_window_eval_exact stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file From f85c8379c6b210df2dbaecf9083f7b60adeecb85 Mon Sep 17 00:00:00 2001 From: aquariouseworkman Date: Thu, 19 Mar 2026 07:06:07 -0400 Subject: [PATCH 2/3] Record: Mixed Quant (int6+int8) + Sliding Window, val_bpb=1.1630 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Submission: Mixed Quantization (int6 blocks + int8 embeddings) + Sliding Window Eval **val_bpb: 1.1630** | **Total size: 15,353,490 bytes** (under 16MB) Four orthogonal improvements over the naive baseline: 1. **Wider MLP (MLP_MULT=3)** — 2x→3x expansion (hidden=1536), enabled by aggressive quantization 2. **Mixed-precision quantization** — int6 per-row (31 levels) on STE-protected block weights, int8 per-row (127 levels) on the token embedding which lacks STE fake-quant. Reduces quant penalty from +0.048 to +0.0015 BPB. 3. **Optimized throughput** — seq_len=1024 + batch=524K tokens for 48.4ms/step, ~6.5B total tokens in 10 minutes 4. **Sliding window eval (stride=64)** — each scored token gets 960 tokens of context, ~0.034 BPB improvement, zero artifact cost ### Run command ```bash RUN_ID=v2_int6_qat_mlp3 MAX_WALLCLOCK_SECONDS=600 VAL_LOSS_EVERY=2000 TRAIN_LOG_EVERY=200 \ torchrun --standalone --nproc_per_node=8 train_gpt.py ``` ### Key metrics | Metric | Value | |--------|-------| | Steps (10 min cap) | 12,395 | | int6/int8 sliding val_bpb | **1.1630** | | Quantization penalty | +0.0015 BPB | | Artifact size | 15,353,490 bytes | --- .../README.md | 104 + .../submission.json | 11 + .../train.log | 4464 +++++++++++++++++ .../train_gpt.py | 100 +- .../2026-03-19_Seq4096SlidingWindow/README.md | 76 - .../submission.json | 11 - .../2026-03-19_Seq4096SlidingWindow/train.log | 1627 ------ 7 files changed, 4654 insertions(+), 1739 deletions(-) create mode 100644 records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/README.md create mode 100644 records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/submission.json create mode 100644 records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train.log rename records/track_10min_16mb/{2026-03-19_Seq4096SlidingWindow => 2026-03-19_MixedQuant_Int6Int8_SlidingWindow}/train_gpt.py (93%) delete mode 100644 records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/README.md delete mode 100644 records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/submission.json delete mode 100644 records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train.log diff --git a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/README.md b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/README.md new file mode 100644 index 0000000000..ea0b1d8d34 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/README.md @@ -0,0 +1,104 @@ +Four orthogonal improvements over the naive baseline, each contributing independently to the final score. + +### Changes from Baseline + +**1. Wider MLP (MLP\_MULT=3)** + +The baseline uses a 2x MLP expansion (hidden=1024). We widen to 3x (hidden=1536), increasing total parameters from 17.1M to 21.8M. The wider MLP is enabled by the int6 quantization scheme below which keeps the artifact under 16MB. + +**2. Mixed-Precision Post-Training Quantization** + +Key insight: during training, all `CastedLinear` weights get **fake int6 quantization via Straight-Through Estimator (STE)** — the forward pass uses quantized weights while gradients flow through the originals. This teaches the model weight distributions that survive int6 (31-level) quantization. However, the token embedding (`tok\_emb.weight`) is a plain `nn.Embedding` that **never sees fake quantization during training**. + +Previous approaches applied uniform int6 to all 2D tensors, causing a +0.048 BPB quantization penalty dominated by embedding degradation. Our mixed scheme: + +* **int6 per-row** (31 levels) on all 2D block weights (attention projections, MLP layers) — these have STE protection +* **int8 per-row** (127 levels) on the token embedding — no STE protection, needs gentler quantization +* Small/control tensors pass through as fp16/fp32 + +This reduces the quantization penalty from +0.048 to +0.0015 BPB — a 32x improvement. The int6 values are stored in int8 containers; zlib-9 compresses the zero high bits efficiently. + +**3. Optimized Training Configuration** + +* `TRAIN\_SEQ\_LEN=1024` (down from 4096): Attention is O(N²) in sequence length. Shorter sequences = faster steps (48.4ms vs 55.5ms) = more total training in the 10-minute window. The 512-dim model cannot meaningfully exploit 4K context. +* `TRAIN\_BATCH\_TOKENS=524,288` (up from 393,216): Better GPU saturation at seq\_len=1024, \~33% more tokens per step. +* Result: 12,395 steps × 524K tokens = \~6.50B total tokens (vs \~4.25B with the old config). + +**4. Sliding Window Evaluation (stride=64)** + +Instead of non-overlapping evaluation where early tokens in each chunk get minimal context, we use overlapping windows advanced by 64 tokens. Each window runs the full 1024-token forward pass, but only the last 64 tokens are scored. Every scored token gets 960 tokens of preceding context. + +Sliding window eval improves val\_bpb by \~0.034 with zero artifact cost. stride=64 gives more context per token than stride=256 (960 vs 768), at the cost of longer eval time (\~73s vs \~18s). + +### Configuration + +``` +MLP\_MULT=3 +NUM\_LAYERS=9 +MODEL\_DIM=512 +NUM\_HEADS=8 +NUM\_KV\_HEADS=4 +VOCAB\_SIZE=1024 +TRAIN\_SEQ\_LEN=1024 +TRAIN\_BATCH\_TOKENS=524288 +TIE\_EMBEDDINGS=1 +EVAL\_STRIDE=64 +``` + +Optimizer settings (tuned via env vars, no code changes from baseline optimizer structure): + +``` +MATRIX\_LR=0.020 +SCALAR\_LR=0.020 +TIED\_EMBED\_LR=0.030 +MUON\_MOMENTUM=0.99 +MUON\_MOMENTUM\_WARMUP\_STEPS=1500 +MUON\_MOMENTUM\_WARMUP\_START=0.92 +WARMDOWN\_ITERS=3000 +``` + +### Run Command + +```bash +RUN\_ID=v2\_int6\_qat\_mlp3 \\ +MAX\_WALLCLOCK\_SECONDS=600 \\ +VAL\_LOSS\_EVERY=2000 \\ +TRAIN\_LOG\_EVERY=200 \\ +torchrun --standalone --nproc\_per\_node=8 train\_gpt.py +``` + +### Key Metrics + +* Training stopped at **12,395/20,000** steps due to 10-minute wallclock cap +* Step time: **48.41ms** average on 8xH100 SXM +* Total train tokens: \~6,499,880,000 (12,395 steps × 524,288 tokens/step) +* Peak memory: **11,251 MiB** allocated per GPU + +|Metric|Value| +|-|-| +|Pre-quant val\_bpb (last step)|1.1950| +|int6/int8 mixed roundtrip val\_bpb (standard)|1.1965| +|**int6/int8 mixed roundtrip val\_bpb (sliding, stride=64)**|**1.1630**| +|Quantization penalty (standard eval)|+0.0015 BPB| +|Sliding window eval time|72.6s| +|Compressed artifact (int6+zlib-9)|15,296,720 bytes| +|Code size|56,770 bytes| +|**Total submission size**|**15,353,490 bytes**| + +### Improvement Breakdown + +|Component|val\_bpb|Improvement vs baseline| +|-|-|-| +|Naive baseline (int8, standard eval)|1.2244|—| +|+ Wider MLP 3x + seq1024 + 524K batch|1.1950|-0.0294| +|+ Mixed quant (int6 blocks, int8 embed)|1.1965|+0.0015 penalty| +|+ Sliding window stride=64|**1.1630**|-0.0335 additional| +|**Total improvement**||**-0.0614**| + +### Included Files + +* `train\_gpt.py` — full training + mixed quantization + evaluation script +* `train.log` — complete training log from the 8xH100 run +* `submission.json` — leaderboard metadata +* `README.md` — this file + diff --git a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/submission.json b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/submission.json new file mode 100644 index 0000000000..25922290cf --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/submission.json @@ -0,0 +1,11 @@ +{ + "author": "aquariouseworkman", + "github_id": "aquariouseworkman", + "name": "Mixed Quant (int6 blocks + int8 embeddings) + Sliding Window Eval, val_bpb=1.1630", + "blurb": "3x MLP expansion with mixed-precision quantization: int6 per-row (31 levels) on STE-protected block weights, int8 per-row (127 levels) on embedding, zlib-9 compression, sliding window evaluation at stride=64.", + "date": "2026-03-19T10:30:00Z", + "val_loss": 1.96369923, + "val_bpb": 1.16301431, + "bytes_total": 15353490, + "bytes_code": 56770 +} diff --git a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train.log b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train.log new file mode 100644 index 0000000000..cc4f1477f4 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train.log @@ -0,0 +1,4464 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + # Sliding window eval: stride controls how many tokens to slide between windows. + # Only the last `eval_stride` tokens per window are scored, giving each scored token + # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 393_216)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 4096)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int = 64, +) -> tuple[float, float]: + """Sliding window evaluation: each scored token gets (seq_len - stride) context. + + Instead of chopping validation into non-overlapping 1024-token blocks (where + the first token in each block gets zero context), we slide a 1024-token window + by `stride` tokens at a time and only score the last `stride` tokens per window. + Every scored token sees 960+ tokens of context, dramatically improving BPB. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 # need 1 extra for final target + + # Number of windows: first window at offset 0, then slide by stride + num_windows = max((total_tokens - seq_len) // stride + 1, 1) + + # Distribute windows across ranks + win_start = (num_windows * rank) // world_size + win_end = (num_windows * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Eval batch: how many windows per forward pass. Tune for memory. + eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) + + base_model.eval() + with torch.inference_mode(): + window_list = list(range(win_start, win_end)) + num_batches = (len(window_list) + eval_batch - 1) // eval_batch + for batch_idx in range(num_batches): + batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] + bsz = len(batch_wins) + + # Build input: each window is val_tokens[w*stride : w*stride + seq_len] + inputs = torch.stack([ + val_tokens[w * stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64) # [B, seq_len] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] + + # Score only the last `stride` positions per window. + # logits[:, j, :] predicts the token AFTER position j in the input. + # So logits[:, -stride:, :] predicts tokens at input positions + # [seq_len - stride + 1, seq_len + 1) — which are the targets. + scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] + + # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] + targets = torch.stack([ + val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") + val_loss_sum += loss.to(torch.float64) + val_token_count += float(targets.numel()) + + # BPB: prev_ids are the input tokens at each scored position + prev_ids = torch.stack([ + val_tokens[w * stride + seq_len - stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + tgt_ids = targets + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if rank == 0 and (batch_idx + 1) % 100 == 0: + print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Int6 per-row quantization: [-31, 31] range stored in int8 containers. + # The unused high bits compress extremely well with zstd. + clip_abs = ( + torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(INT6_QUANT_RANGE)).clamp_min(1.0 / float(INT6_QUANT_RANGE)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_QUANT_RANGE, INT6_QUANT_RANGE).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars still use int8 per-tensor scale (they're tiny). + clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +INT6_QUANT_RANGE = 31 # int6: [-31, 31] +INT6_CLIP_Q = 0.9999984 + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # During training, apply fake int6 quantization (STE) so the model learns to be robust + # to the post-training quantization that will be applied for the 16MB artifact. + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if self.training and w.ndim == 2: + # Fake int6 per-row quantization with Straight-Through Estimator: + # Forward uses quantized weights, backward passes gradients through as-is. + with torch.no_grad(): + w32 = w.float() + clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) + scale = clip_abs / INT6_QUANT_RANGE + w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) + w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE: value of w_q, gradient of w + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits without computing loss. Used for sliding window eval.""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_name = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_name = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" + with open(f"final_model.{quant_ext}", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + quant_ext = "int6.ptz" + with open(f"final_model.{quant_ext}", "rb") as f: + quant_blob_disk = f.read() + if HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_decompressed = dctx.decompress(quant_blob_disk) + else: + quant_decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval: gives each scored token (seq_len - stride) context. + if args.eval_stride > 0: + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding_window( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window_eval stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0( + f"final_sliding_window_eval_exact stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Mar 19 09:44:13 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 35C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 30C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:2F:00.0 Off | 0 | +| N/A 29C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:38:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 31C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 2072 C /usr/local/bin/python 1512MiB | +| 1 N/A N/A 2073 C /usr/local/bin/python 1512MiB | +| 2 N/A N/A 2074 C /usr/local/bin/python 1512MiB | +| 3 N/A N/A 2075 C /usr/local/bin/python 1512MiB | +| 4 N/A N/A 2076 C /usr/local/bin/python 1512MiB | +| 5 N/A N/A 2077 C /usr/local/bin/python 1512MiB | +| 6 N/A N/A 2078 C /usr/local/bin/python 1512MiB | +| 7 N/A N/A 2079 C /usr/local/bin/python 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:20 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:21778504 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:393216 train_seq_len:4096 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9373 train_time:59ms step_avg:58.72ms +step:2/20000 train_loss:12.2289 train_time:80ms step_avg:39.82ms +step:3/20000 train_loss:7.5371 train_time:134ms step_avg:44.70ms +step:4/20000 train_loss:6.2662 train_time:189ms step_avg:47.18ms +step:5/20000 train_loss:6.7749 train_time:244ms step_avg:48.71ms +step:6/20000 train_loss:6.8398 train_time:298ms step_avg:49.73ms +step:7/20000 train_loss:6.7133 train_time:353ms step_avg:50.43ms +step:8/20000 train_loss:6.6259 train_time:408ms step_avg:50.98ms +step:9/20000 train_loss:6.2381 train_time:463ms step_avg:51.41ms +step:10/20000 train_loss:6.1987 train_time:517ms step_avg:51.73ms +step:200/20000 train_loss:2.7737 train_time:10932ms step_avg:54.66ms +step:400/20000 train_loss:2.4091 train_time:21963ms step_avg:54.91ms +step:600/20000 train_loss:2.2791 train_time:33022ms step_avg:55.04ms +step:800/20000 train_loss:2.3532 train_time:44111ms step_avg:55.14ms +step:1000/20000 train_loss:2.3124 train_time:55178ms step_avg:55.18ms +step:1200/20000 train_loss:2.3323 train_time:66301ms step_avg:55.25ms +step:1400/20000 train_loss:2.2776 train_time:77425ms step_avg:55.30ms +step:1600/20000 train_loss:2.2437 train_time:88549ms step_avg:55.34ms +step:1800/20000 train_loss:2.0226 train_time:99669ms step_avg:55.37ms +step:2000/20000 train_loss:2.2808 train_time:110737ms step_avg:55.37ms +step:2200/20000 train_loss:2.0338 train_time:121855ms step_avg:55.39ms +step:2400/20000 train_loss:2.1836 train_time:132969ms step_avg:55.40ms +step:2600/20000 train_loss:2.3013 train_time:144085ms step_avg:55.42ms +step:2800/20000 train_loss:2.3646 train_time:155210ms step_avg:55.43ms +step:3000/20000 train_loss:2.1464 train_time:166276ms step_avg:55.43ms +step:3200/20000 train_loss:2.1688 train_time:177394ms step_avg:55.44ms +step:3400/20000 train_loss:2.1724 train_time:188515ms step_avg:55.45ms +step:3600/20000 train_loss:2.0546 train_time:199637ms step_avg:55.45ms +step:3800/20000 train_loss:2.0704 train_time:210709ms step_avg:55.45ms +step:4000/20000 train_loss:1.8991 train_time:221829ms step_avg:55.46ms +step:4200/20000 train_loss:1.9806 train_time:232956ms step_avg:55.47ms +step:4400/20000 train_loss:2.0861 train_time:244084ms step_avg:55.47ms +step:4600/20000 train_loss:2.0158 train_time:255204ms step_avg:55.48ms +step:4800/20000 train_loss:2.0401 train_time:266278ms step_avg:55.47ms +step:5000/20000 train_loss:2.1076 train_time:277494ms step_avg:55.50ms +step:5200/20000 train_loss:2.0759 train_time:288628ms step_avg:55.51ms +step:5400/20000 train_loss:2.1771 train_time:299751ms step_avg:55.51ms +step:5600/20000 train_loss:2.0141 train_time:310878ms step_avg:55.51ms +step:5800/20000 train_loss:2.1045 train_time:321955ms step_avg:55.51ms +step:6000/20000 train_loss:2.0173 train_time:333077ms step_avg:55.51ms +step:6200/20000 train_loss:1.9834 train_time:344195ms step_avg:55.52ms +step:6400/20000 train_loss:2.0574 train_time:355323ms step_avg:55.52ms +step:6600/20000 train_loss:2.0057 train_time:366399ms step_avg:55.51ms +step:6800/20000 train_loss:1.8728 train_time:377522ms step_avg:55.52ms +step:7000/20000 train_loss:2.0247 train_time:388645ms step_avg:55.52ms +step:7200/20000 train_loss:2.0745 train_time:399763ms step_avg:55.52ms +step:7400/20000 train_loss:2.0938 train_time:410884ms step_avg:55.52ms +step:7600/20000 train_loss:1.9440 train_time:421961ms step_avg:55.52ms +step:7800/20000 train_loss:1.8961 train_time:433081ms step_avg:55.52ms +step:8000/20000 train_loss:1.9993 train_time:444199ms step_avg:55.52ms +step:8200/20000 train_loss:2.0314 train_time:455324ms step_avg:55.53ms +step:8400/20000 train_loss:2.0255 train_time:466445ms step_avg:55.53ms +step:8600/20000 train_loss:2.0221 train_time:477511ms step_avg:55.52ms +step:8800/20000 train_loss:1.9800 train_time:488632ms step_avg:55.53ms +step:9000/20000 train_loss:1.9318 train_time:499754ms step_avg:55.53ms +step:9200/20000 train_loss:2.0151 train_time:510878ms step_avg:55.53ms +step:9400/20000 train_loss:1.9677 train_time:521951ms step_avg:55.53ms +step:9600/20000 train_loss:1.8810 train_time:533075ms step_avg:55.53ms +step:9800/20000 train_loss:1.8309 train_time:544198ms step_avg:55.53ms +step:10000/20000 train_loss:2.2471 train_time:555333ms step_avg:55.53ms +step:10200/20000 train_loss:1.9500 train_time:566455ms step_avg:55.53ms +step:10400/20000 train_loss:2.0805 train_time:577533ms step_avg:55.53ms +step:10600/20000 train_loss:1.9196 train_time:588662ms step_avg:55.53ms +step:10800/20000 train_loss:1.9962 train_time:599781ms step_avg:55.54ms +step:10804/20000 val_loss:1.9787 val_bpb:1.1719 train_time:600037ms step_avg:55.54ms +stopping_early: wallclock_cap train_time:600037ms step:10804/20000 +peak memory allocated: 8517 MiB reserved: 9032 MiB +Serialized model: 86099351 bytes +Code size: 55931 bytes +Total submission size: 86155282 bytes +Serialized model int6+zlib-9: 15160388 bytes (payload:21906720 raw_torch:21951833 payload_ratio:3.93x) +Total submission size int6+zlib-9: 15216319 bytes +final_int6_roundtrip val_loss:2.0436 val_bpb:1.2103 eval_time:2262ms +final_int6_roundtrip_exact val_loss:2.04356694 val_bpb:1.21031545 +final_sliding_window_eval stride:64 val_loss:2.0222 val_bpb:1.1976 eval_time:304360ms +final_sliding_window_eval_exact stride:64 val_loss:2.02215627 val_bpb:1.19763674 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + # Sliding window eval: stride controls how many tokens to slide between windows. + # Only the last `eval_stride` tokens per window are scored, giving each scored token + # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int = 64, +) -> tuple[float, float]: + """Sliding window evaluation: each scored token gets (seq_len - stride) context. + + Instead of chopping validation into non-overlapping 1024-token blocks (where + the first token in each block gets zero context), we slide a 1024-token window + by `stride` tokens at a time and only score the last `stride` tokens per window. + Every scored token sees 960+ tokens of context, dramatically improving BPB. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 # need 1 extra for final target + + # Number of windows: first window at offset 0, then slide by stride + num_windows = max((total_tokens - seq_len) // stride + 1, 1) + + # Distribute windows across ranks + win_start = (num_windows * rank) // world_size + win_end = (num_windows * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Eval batch: how many windows per forward pass. Tune for memory. + eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) + + base_model.eval() + with torch.inference_mode(): + window_list = list(range(win_start, win_end)) + num_batches = (len(window_list) + eval_batch - 1) // eval_batch + for batch_idx in range(num_batches): + batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] + bsz = len(batch_wins) + + # Build input: each window is val_tokens[w*stride : w*stride + seq_len] + inputs = torch.stack([ + val_tokens[w * stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64) # [B, seq_len] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] + + # Score only the last `stride` positions per window. + # logits[:, j, :] predicts the token AFTER position j in the input. + # So logits[:, -stride:, :] predicts tokens at input positions + # [seq_len - stride + 1, seq_len + 1) — which are the targets. + scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] + + # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] + targets = torch.stack([ + val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") + val_loss_sum += loss.to(torch.float64) + val_token_count += float(targets.numel()) + + # BPB: prev_ids are the input tokens at each scored position + prev_ids = torch.stack([ + val_tokens[w * stride + seq_len - stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + tgt_ids = targets + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if rank == 0 and (batch_idx + 1) % 100 == 0: + print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Int6 per-row quantization: [-31, 31] range stored in int8 containers. + # The unused high bits compress extremely well with zstd. + clip_abs = ( + torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(INT6_QUANT_RANGE)).clamp_min(1.0 / float(INT6_QUANT_RANGE)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_QUANT_RANGE, INT6_QUANT_RANGE).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars still use int8 per-tensor scale (they're tiny). + clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +INT6_QUANT_RANGE = 31 # int6: [-31, 31] +INT6_CLIP_Q = 0.9999984 + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # During training, apply fake int6 quantization (STE) so the model learns to be robust + # to the post-training quantization that will be applied for the 16MB artifact. + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if self.training and w.ndim == 2: + # Fake int6 per-row quantization with Straight-Through Estimator: + # Forward uses quantized weights, backward passes gradients through as-is. + with torch.no_grad(): + w32 = w.float() + clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) + scale = clip_abs / INT6_QUANT_RANGE + w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) + w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE: value of w_q, gradient of w + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits without computing loss. Used for sliding window eval.""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_name = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_name = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" + with open(f"final_model.{quant_ext}", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + quant_ext = "int6.ptz" + with open(f"final_model.{quant_ext}", "rb") as f: + quant_blob_disk = f.read() + if HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_decompressed = dctx.decompress(quant_blob_disk) + else: + quant_decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval: gives each scored token (seq_len - stride) context. + if args.eval_stride > 0: + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding_window( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window_eval stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0( + f"final_sliding_window_eval_exact stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Mar 19 10:08:37 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 35C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 30C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:2F:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:38:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 31C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 36375 C /usr/local/bin/python 1512MiB | +| 1 N/A N/A 36376 C /usr/local/bin/python 1512MiB | +| 2 N/A N/A 36377 C /usr/local/bin/python 1512MiB | +| 3 N/A N/A 36378 C /usr/local/bin/python 1512MiB | +| 4 N/A N/A 36379 C /usr/local/bin/python 1512MiB | +| 5 N/A N/A 36380 C /usr/local/bin/python 1512MiB | +| 6 N/A N/A 36381 C /usr/local/bin/python 1512MiB | +| 7 N/A N/A 36382 C /usr/local/bin/python 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:20 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:21778504 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9378 val_bpb:4.1090 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9364 train_time:25ms step_avg:25.10ms +step:2/20000 train_loss:12.3366 train_time:70ms step_avg:35.22ms +step:3/20000 train_loss:7.4143 train_time:118ms step_avg:39.45ms +step:4/20000 train_loss:6.3474 train_time:165ms step_avg:41.32ms +step:5/20000 train_loss:6.7814 train_time:213ms step_avg:42.66ms +step:6/20000 train_loss:7.4895 train_time:261ms step_avg:43.54ms +step:7/20000 train_loss:6.7486 train_time:309ms step_avg:44.11ms +step:8/20000 train_loss:6.4461 train_time:356ms step_avg:44.54ms +step:9/20000 train_loss:6.3008 train_time:404ms step_avg:44.87ms +step:10/20000 train_loss:6.1806 train_time:451ms step_avg:45.14ms +step:200/20000 train_loss:2.7609 train_time:9576ms step_avg:47.88ms +step:400/20000 train_loss:2.2985 train_time:19184ms step_avg:47.96ms +step:600/20000 train_loss:2.5044 train_time:28808ms step_avg:48.01ms +step:800/20000 train_loss:2.2650 train_time:38451ms step_avg:48.06ms +step:1000/20000 train_loss:2.3522 train_time:48118ms step_avg:48.12ms +step:1200/20000 train_loss:2.3722 train_time:57803ms step_avg:48.17ms +step:1400/20000 train_loss:2.4249 train_time:67502ms step_avg:48.22ms +step:1600/20000 train_loss:2.0953 train_time:77188ms step_avg:48.24ms +step:1800/20000 train_loss:2.1866 train_time:86887ms step_avg:48.27ms +step:2000/20000 train_loss:2.2375 train_time:96583ms step_avg:48.29ms +step:2000/20000 val_loss:2.2181 val_bpb:1.3137 train_time:96613ms step_avg:48.31ms +step:2200/20000 train_loss:2.0498 train_time:106273ms step_avg:48.31ms +step:2400/20000 train_loss:2.1729 train_time:115967ms step_avg:48.32ms +step:2600/20000 train_loss:2.3764 train_time:125661ms step_avg:48.33ms +step:2800/20000 train_loss:2.1998 train_time:135356ms step_avg:48.34ms +step:3000/20000 train_loss:2.1901 train_time:145046ms step_avg:48.35ms +step:3200/20000 train_loss:2.1510 train_time:154732ms step_avg:48.35ms +step:3400/20000 train_loss:2.1192 train_time:164420ms step_avg:48.36ms +step:3600/20000 train_loss:2.0793 train_time:174112ms step_avg:48.36ms +step:3800/20000 train_loss:2.1825 train_time:183806ms step_avg:48.37ms +step:4000/20000 train_loss:2.2339 train_time:193495ms step_avg:48.37ms +step:4000/20000 val_loss:2.1286 val_bpb:1.2607 train_time:193525ms step_avg:48.38ms +step:4200/20000 train_loss:2.1753 train_time:203239ms step_avg:48.39ms +step:4400/20000 train_loss:2.1257 train_time:212931ms step_avg:48.39ms +step:4600/20000 train_loss:2.1593 train_time:222632ms step_avg:48.40ms +step:4800/20000 train_loss:2.1010 train_time:232323ms step_avg:48.40ms +step:5000/20000 train_loss:2.1841 train_time:242009ms step_avg:48.40ms +step:5200/20000 train_loss:2.2419 train_time:251704ms step_avg:48.40ms +step:5400/20000 train_loss:2.1947 train_time:261390ms step_avg:48.41ms +step:5600/20000 train_loss:2.1010 train_time:271090ms step_avg:48.41ms +step:5800/20000 train_loss:2.2822 train_time:280782ms step_avg:48.41ms +step:6000/20000 train_loss:2.0261 train_time:290480ms step_avg:48.41ms +step:6000/20000 val_loss:2.0964 val_bpb:1.2416 train_time:290510ms step_avg:48.42ms +step:6200/20000 train_loss:2.1061 train_time:300162ms step_avg:48.41ms +step:6400/20000 train_loss:1.9307 train_time:309857ms step_avg:48.42ms +step:6600/20000 train_loss:2.0882 train_time:319546ms step_avg:48.42ms +step:6800/20000 train_loss:2.1693 train_time:329233ms step_avg:48.42ms +step:7000/20000 train_loss:2.0145 train_time:338923ms step_avg:48.42ms +step:7200/20000 train_loss:2.0029 train_time:348614ms step_avg:48.42ms +step:7400/20000 train_loss:1.9361 train_time:358310ms step_avg:48.42ms +step:7600/20000 train_loss:2.0350 train_time:368000ms step_avg:48.42ms +step:7800/20000 train_loss:2.0813 train_time:377688ms step_avg:48.42ms +step:8000/20000 train_loss:2.0144 train_time:387378ms step_avg:48.42ms +step:8000/20000 val_loss:2.0773 val_bpb:1.2303 train_time:387408ms step_avg:48.43ms +step:8200/20000 train_loss:2.1489 train_time:397067ms step_avg:48.42ms +step:8400/20000 train_loss:2.1424 train_time:406801ms step_avg:48.43ms +step:8600/20000 train_loss:2.1657 train_time:416481ms step_avg:48.43ms +step:8800/20000 train_loss:2.0297 train_time:426178ms step_avg:48.43ms +step:9000/20000 train_loss:2.0599 train_time:435859ms step_avg:48.43ms +step:9200/20000 train_loss:2.1330 train_time:445542ms step_avg:48.43ms +step:9400/20000 train_loss:1.9739 train_time:455237ms step_avg:48.43ms +step:9600/20000 train_loss:2.0265 train_time:464918ms step_avg:48.43ms +step:9800/20000 train_loss:2.0734 train_time:474609ms step_avg:48.43ms +step:10000/20000 train_loss:2.1375 train_time:484302ms step_avg:48.43ms +step:10000/20000 val_loss:2.0568 val_bpb:1.2182 train_time:484331ms step_avg:48.43ms +step:10200/20000 train_loss:1.9916 train_time:493993ms step_avg:48.43ms +step:10400/20000 train_loss:2.2391 train_time:503676ms step_avg:48.43ms +step:10600/20000 train_loss:1.8816 train_time:513372ms step_avg:48.43ms +step:10800/20000 train_loss:2.0403 train_time:523069ms step_avg:48.43ms +step:11000/20000 train_loss:2.1571 train_time:532770ms step_avg:48.43ms +step:11200/20000 train_loss:2.1207 train_time:542462ms step_avg:48.43ms +step:11400/20000 train_loss:2.1568 train_time:552158ms step_avg:48.43ms +step:11600/20000 train_loss:2.0102 train_time:561851ms step_avg:48.44ms +step:11800/20000 train_loss:2.0893 train_time:571552ms step_avg:48.44ms +step:12000/20000 train_loss:1.9977 train_time:581254ms step_avg:48.44ms +step:12000/20000 val_loss:2.0216 val_bpb:1.1973 train_time:581283ms step_avg:48.44ms +step:12200/20000 train_loss:2.1668 train_time:590989ms step_avg:48.44ms +step:12386/20000 val_loss:2.0178 val_bpb:1.1951 train_time:600039ms step_avg:48.44ms +stopping_early: wallclock_cap train_time:600039ms step:12386/20000 +peak memory allocated: 11254 MiB reserved: 11520 MiB +Serialized model: 86099351 bytes +Code size: 55932 bytes +Total submission size: 86155283 bytes +Serialized model int6+zlib-9: 15164668 bytes (payload:21906720 raw_torch:21951833 payload_ratio:3.93x) +Total submission size int6+zlib-9: 15220600 bytes +final_int6_roundtrip val_loss:2.0991 val_bpb:1.2432 eval_time:1548ms +final_int6_roundtrip_exact val_loss:2.09906746 val_bpb:1.24318598 +final_sliding_window_eval stride:256 val_loss:2.0403 val_bpb:1.2084 eval_time:18256ms +final_sliding_window_eval_exact stride:256 val_loss:2.04033873 val_bpb:1.20840424 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + # Sliding window eval: stride controls how many tokens to slide between windows. + # Only the last `eval_stride` tokens per window are scored, giving each scored token + # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int = 64, +) -> tuple[float, float]: + """Sliding window evaluation: each scored token gets (seq_len - stride) context. + + Instead of chopping validation into non-overlapping 1024-token blocks (where + the first token in each block gets zero context), we slide a 1024-token window + by `stride` tokens at a time and only score the last `stride` tokens per window. + Every scored token sees 960+ tokens of context, dramatically improving BPB. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 # need 1 extra for final target + + # Number of windows: first window at offset 0, then slide by stride + num_windows = max((total_tokens - seq_len) // stride + 1, 1) + + # Distribute windows across ranks + win_start = (num_windows * rank) // world_size + win_end = (num_windows * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Eval batch: how many windows per forward pass. Tune for memory. + eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) + + base_model.eval() + with torch.inference_mode(): + window_list = list(range(win_start, win_end)) + num_batches = (len(window_list) + eval_batch - 1) // eval_batch + for batch_idx in range(num_batches): + batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] + bsz = len(batch_wins) + + # Build input: each window is val_tokens[w*stride : w*stride + seq_len] + inputs = torch.stack([ + val_tokens[w * stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64) # [B, seq_len] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] + + # Score only the last `stride` positions per window. + # logits[:, j, :] predicts the token AFTER position j in the input. + # So logits[:, -stride:, :] predicts tokens at input positions + # [seq_len - stride + 1, seq_len + 1) — which are the targets. + scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] + + # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] + targets = torch.stack([ + val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") + val_loss_sum += loss.to(torch.float64) + val_token_count += float(targets.numel()) + + # BPB: prev_ids are the input tokens at each scored position + prev_ids = torch.stack([ + val_tokens[w * stride + seq_len - stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + tgt_ids = targets + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if rank == 0 and (batch_idx + 1) % 100 == 0: + print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 + +INT6_QUANT_RANGE = 31 # int6: [-31, 31] +INT8_QUANT_RANGE = 127 # int8: [-127, 127] +INT6_CLIP_Q = 0.9999984 + +# Tensors matching these patterns get int8 (127 levels) instead of int6 (31 levels) +# during post-training quantization. This is for weights that DON'T have fake-quant +# STE protection during training (e.g. nn.Embedding), so they need gentler quantization. +INT8_FULL_RANGE_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_FULL_RANGE_PATTERNS", + "tok_emb", + ).split(",") + if pattern +) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, quant_range: int = INT6_QUANT_RANGE) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Per-row quantization: [-quant_range, quant_range] stored in int8 containers. + # For int6 (range=31), the unused high bits compress extremely well with zstd. + clip_abs = ( + torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(quant_range)).clamp_min(1.0 / float(quant_range)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -quant_range, quant_range).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars still use int8 per-tensor scale (they're tiny). + clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Use int8 (127 levels) for tensors without STE fake-quant protection + # (e.g. tok_emb), int6 (31 levels) for block weights that trained with STE. + use_int8 = any(pattern in name for pattern in INT8_FULL_RANGE_PATTERNS) + qr = INT8_QUANT_RANGE if use_int8 else INT6_QUANT_RANGE + q, s = quantize_float_tensor(t, quant_range=qr) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # During training, apply fake int6 quantization (STE) so the model learns to be robust + # to the post-training quantization that will be applied for the 16MB artifact. + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if self.training and w.ndim == 2: + # Fake int6 per-row quantization with Straight-Through Estimator: + # Forward uses quantized weights, backward passes gradients through as-is. + with torch.no_grad(): + w32 = w.float() + clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) + scale = clip_abs / INT6_QUANT_RANGE + w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) + w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE: value of w_q, gradient of w + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits without computing loss. Used for sliding window eval.""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_name = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_name = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" + with open(f"final_model.{quant_ext}", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + quant_ext = "int6.ptz" + with open(f"final_model.{quant_ext}", "rb") as f: + quant_blob_disk = f.read() + if HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_decompressed = dctx.decompress(quant_blob_disk) + else: + quant_decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval: gives each scored token (seq_len - stride) context. + if args.eval_stride > 0: + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding_window( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window_eval stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0( + f"final_sliding_window_eval_exact stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Mar 19 10:32:00 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 29C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:2F:00.0 Off | 0 | +| N/A 29C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:38:00.0 Off | 0 | +| N/A 32C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 31C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 57118 C /usr/local/bin/python 1512MiB | +| 1 N/A N/A 57119 C /usr/local/bin/python 1512MiB | +| 2 N/A N/A 57120 C /usr/local/bin/python 1512MiB | +| 3 N/A N/A 57121 C /usr/local/bin/python 1512MiB | +| 4 N/A N/A 57122 C /usr/local/bin/python 1512MiB | +| 5 N/A N/A 57123 C /usr/local/bin/python 1512MiB | +| 6 N/A N/A 57124 C /usr/local/bin/python 1512MiB | +| 7 N/A N/A 57125 C /usr/local/bin/python 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:20 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:21778504 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9378 val_bpb:4.1090 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9364 train_time:25ms step_avg:25.39ms +step:2/20000 train_loss:12.3366 train_time:69ms step_avg:34.49ms +step:3/20000 train_loss:7.4142 train_time:116ms step_avg:38.83ms +step:4/20000 train_loss:6.3475 train_time:164ms step_avg:40.94ms +step:5/20000 train_loss:6.7813 train_time:211ms step_avg:42.24ms +step:6/20000 train_loss:7.4896 train_time:259ms step_avg:43.11ms +step:7/20000 train_loss:6.7491 train_time:307ms step_avg:43.79ms +step:8/20000 train_loss:6.4461 train_time:354ms step_avg:44.27ms +step:9/20000 train_loss:6.3003 train_time:402ms step_avg:44.63ms +step:10/20000 train_loss:6.1808 train_time:449ms step_avg:44.92ms +step:200/20000 train_loss:2.7644 train_time:9564ms step_avg:47.82ms +step:400/20000 train_loss:2.3035 train_time:19162ms step_avg:47.91ms +step:600/20000 train_loss:2.5030 train_time:28769ms step_avg:47.95ms +step:800/20000 train_loss:2.2665 train_time:38403ms step_avg:48.00ms +step:1000/20000 train_loss:2.3514 train_time:48062ms step_avg:48.06ms +step:1200/20000 train_loss:2.3714 train_time:57723ms step_avg:48.10ms +step:1400/20000 train_loss:2.4243 train_time:67404ms step_avg:48.15ms +step:1600/20000 train_loss:2.0992 train_time:77085ms step_avg:48.18ms +step:1800/20000 train_loss:2.1921 train_time:86767ms step_avg:48.20ms +step:2000/20000 train_loss:2.2339 train_time:96459ms step_avg:48.23ms +step:2000/20000 val_loss:2.2180 val_bpb:1.3136 train_time:96488ms step_avg:48.24ms +step:2200/20000 train_loss:2.0467 train_time:106139ms step_avg:48.24ms +step:2400/20000 train_loss:2.1760 train_time:115822ms step_avg:48.26ms +step:2600/20000 train_loss:2.3846 train_time:125507ms step_avg:48.27ms +step:2800/20000 train_loss:2.2012 train_time:135190ms step_avg:48.28ms +step:3000/20000 train_loss:2.1922 train_time:144871ms step_avg:48.29ms +step:3200/20000 train_loss:2.1540 train_time:154556ms step_avg:48.30ms +step:3400/20000 train_loss:2.1266 train_time:164240ms step_avg:48.31ms +step:3600/20000 train_loss:2.0762 train_time:173935ms step_avg:48.32ms +step:3800/20000 train_loss:2.1830 train_time:183612ms step_avg:48.32ms +step:4000/20000 train_loss:2.2349 train_time:193304ms step_avg:48.33ms +step:4000/20000 val_loss:2.1282 val_bpb:1.2604 train_time:193334ms step_avg:48.33ms +step:4200/20000 train_loss:2.1801 train_time:203036ms step_avg:48.34ms +step:4400/20000 train_loss:2.1307 train_time:212724ms step_avg:48.35ms +step:4600/20000 train_loss:2.1653 train_time:222417ms step_avg:48.35ms +step:4800/20000 train_loss:2.0977 train_time:232099ms step_avg:48.35ms +step:5000/20000 train_loss:2.1822 train_time:241786ms step_avg:48.36ms +step:5200/20000 train_loss:2.2444 train_time:251474ms step_avg:48.36ms +step:5400/20000 train_loss:2.1939 train_time:261163ms step_avg:48.36ms +step:5600/20000 train_loss:2.0976 train_time:270860ms step_avg:48.37ms +step:5800/20000 train_loss:2.2809 train_time:280540ms step_avg:48.37ms +step:6000/20000 train_loss:2.0262 train_time:290222ms step_avg:48.37ms +step:6000/20000 val_loss:2.0966 val_bpb:1.2417 train_time:290251ms step_avg:48.38ms +step:6200/20000 train_loss:2.1005 train_time:299909ms step_avg:48.37ms +step:6400/20000 train_loss:1.9273 train_time:309598ms step_avg:48.37ms +step:6600/20000 train_loss:2.0822 train_time:319283ms step_avg:48.38ms +step:6800/20000 train_loss:2.1677 train_time:328963ms step_avg:48.38ms +step:7000/20000 train_loss:2.0135 train_time:338650ms step_avg:48.38ms +step:7200/20000 train_loss:2.0018 train_time:348331ms step_avg:48.38ms +step:7400/20000 train_loss:1.9328 train_time:358015ms step_avg:48.38ms +step:7600/20000 train_loss:2.0359 train_time:367703ms step_avg:48.38ms +step:7800/20000 train_loss:2.0807 train_time:377384ms step_avg:48.38ms +step:8000/20000 train_loss:2.0150 train_time:387064ms step_avg:48.38ms +step:8000/20000 val_loss:2.0768 val_bpb:1.2300 train_time:387094ms step_avg:48.39ms +step:8200/20000 train_loss:2.1487 train_time:396747ms step_avg:48.38ms +step:8400/20000 train_loss:2.1440 train_time:406469ms step_avg:48.39ms +step:8600/20000 train_loss:2.1641 train_time:416138ms step_avg:48.39ms +step:8800/20000 train_loss:2.0296 train_time:425826ms step_avg:48.39ms +step:9000/20000 train_loss:2.0574 train_time:435503ms step_avg:48.39ms +step:9200/20000 train_loss:2.1356 train_time:445268ms step_avg:48.40ms +step:9400/20000 train_loss:1.9742 train_time:454945ms step_avg:48.40ms +step:9600/20000 train_loss:2.0251 train_time:464630ms step_avg:48.40ms +step:9800/20000 train_loss:2.0776 train_time:474313ms step_avg:48.40ms +step:10000/20000 train_loss:2.1372 train_time:483997ms step_avg:48.40ms +step:10000/20000 val_loss:2.0570 val_bpb:1.2183 train_time:484027ms step_avg:48.40ms +step:10200/20000 train_loss:1.9914 train_time:493676ms step_avg:48.40ms +step:10400/20000 train_loss:2.2388 train_time:503361ms step_avg:48.40ms +step:10600/20000 train_loss:1.8790 train_time:513046ms step_avg:48.40ms +step:10800/20000 train_loss:2.0378 train_time:522738ms step_avg:48.40ms +step:11000/20000 train_loss:2.1572 train_time:532426ms step_avg:48.40ms +step:11200/20000 train_loss:2.1221 train_time:542116ms step_avg:48.40ms +step:11400/20000 train_loss:2.1492 train_time:551800ms step_avg:48.40ms +step:11600/20000 train_loss:2.0106 train_time:561488ms step_avg:48.40ms +step:11800/20000 train_loss:2.0874 train_time:571185ms step_avg:48.41ms +step:12000/20000 train_loss:1.9972 train_time:580870ms step_avg:48.41ms +step:12000/20000 val_loss:2.0215 val_bpb:1.1973 train_time:580900ms step_avg:48.41ms +step:12200/20000 train_loss:2.1699 train_time:590570ms step_avg:48.41ms +step:12395/20000 val_loss:2.0177 val_bpb:1.1950 train_time:600044ms step_avg:48.41ms +stopping_early: wallclock_cap train_time:600044ms step:12395/20000 +peak memory allocated: 11251 MiB reserved: 11396 MiB +Serialized model: 86099351 bytes +Code size: 56770 bytes +Total submission size: 86156121 bytes +Serialized model int6+zlib-9: 15296720 bytes (payload:21906720 raw_torch:21951833 payload_ratio:3.93x) +Total submission size int6+zlib-9: 15353490 bytes +final_int6_roundtrip val_loss:2.0203 val_bpb:1.1965 eval_time:1547ms +final_int6_roundtrip_exact val_loss:2.02025601 val_bpb:1.19650941 +final_sliding_window_eval stride:64 val_loss:1.9637 val_bpb:1.1630 eval_time:72650ms +final_sliding_window_eval_exact stride:64 val_loss:1.96369923 val_bpb:1.16301431 diff --git a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train_gpt.py b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train_gpt.py similarity index 93% rename from records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train_gpt.py rename to records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train_gpt.py index 421ce35f1f..284a42bd72 100644 --- a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train_gpt.py +++ b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train_gpt.py @@ -19,6 +19,12 @@ import zlib from pathlib import Path +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + import numpy as np import sentencepiece as spm import torch @@ -47,7 +53,7 @@ class Hyperparameters: # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) # Sliding window eval: stride controls how many tokens to slide between windows. # Only the last `eval_stride` tokens per window are scored, giving each scored token @@ -58,8 +64,8 @@ class Hyperparameters: iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 393_216)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 4096)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) @@ -69,7 +75,7 @@ class Hyperparameters: num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) @@ -403,8 +409,22 @@ def eval_val_sliding_window( INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +INT6_QUANT_RANGE = 31 # int6: [-31, 31] +INT8_QUANT_RANGE = 127 # int8: [-127, 127] +INT6_CLIP_Q = 0.9999984 + +# Tensors matching these patterns get int8 (127 levels) instead of int6 (31 levels) +# during post-training quantization. This is for weights that DON'T have fake-quant +# STE protection during training (e.g. nn.Embedding), so they need gentler quantization. +INT8_FULL_RANGE_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_FULL_RANGE_PATTERNS", + "tok_emb", + ).split(",") + if pattern +) def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) @@ -417,23 +437,23 @@ def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, s return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: +def quantize_float_tensor(t: Tensor, quant_range: int = INT6_QUANT_RANGE) -> tuple[Tensor, Tensor]: t32 = t.float() if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. + # Per-row quantization: [-quant_range, quant_range] stored in int8 containers. + # For int6 (range=31), the unused high bits compress extremely well with zstd. clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) ) clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + scale = (clip_abs / float(quant_range)).clamp_min(1.0 / float(quant_range)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -quant_range, quant_range).to(torch.int8).contiguous() return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + # Vectors / scalars still use int8 per-tensor scale (they're tiny). + clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() return q, scale @@ -476,7 +496,11 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): continue stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) + # Use int8 (127 levels) for tensors without STE fake-quant protection + # (e.g. tok_emb), int6 (31 levels) for block weights that trained with STE. + use_int8 = any(pattern in name for pattern in INT8_FULL_RANGE_PATTERNS) + qr = INT8_QUANT_RANGE if use_int8 else INT6_QUANT_RANGE + q, s = quantize_float_tensor(t, quant_range=qr) if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} quantized[name] = q @@ -607,9 +631,22 @@ def forward(self, x: Tensor) -> Tensor: class CastedLinear(nn.Linear): # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # During training, apply fake int6 quantization (STE) so the model learns to be robust + # to the post-training quantization that will be applied for the 16MB artifact. def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if self.training and w.ndim == 2: + # Fake int6 per-row quantization with Straight-Through Estimator: + # Forward uses quantized weights, backward passes gradients through as-is. + with torch.no_grad(): + w32 = w.float() + clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) + scale = clip_abs / INT6_QUANT_RANGE + w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) + w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE: value of w_q, gradient of w bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) + return F.linear(x, w, bias) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: @@ -1198,25 +1235,38 @@ def lr_mul(step: int, elapsed_ms: float) -> float: quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_name = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_name = "zlib-9" quant_raw_bytes = len(quant_raw) if master_process: - with open("final_model.int8.ptz", "wb") as f: + quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" + with open(f"final_model.{quant_ext}", "wb") as f: f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") + quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") code_bytes = len(code.encode("utf-8")) ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") if distributed: dist.barrier() - with open("final_model.int8.ptz", "rb") as f: + quant_ext = "int6.ptz" + with open(f"final_model.{quant_ext}", "rb") as f: quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + if HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_decompressed = dctx.decompress(quant_blob_disk) + else: + quant_decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) torch.cuda.synchronize() t_qeval = time.perf_counter() @@ -1234,10 +1284,10 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) torch.cuda.synchronize() log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") # Sliding window eval: gives each scored token (seq_len - stride) context. if args.eval_stride > 0: diff --git a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/README.md b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/README.md deleted file mode 100644 index a51057cd53..0000000000 --- a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/README.md +++ /dev/null @@ -1,76 +0,0 @@ -# Training Opt Seq4096 + Sliding Window Eval (stride=64) - -This record combines two independent improvements over the naive baseline: - -1. **Longer training context (seq_len=4096) with aggressively tuned Muon optimizer** -2. **Sliding window evaluation (stride=64) for near-full context on every scored token** - -## Configuration - -- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` -- Tied output/input embeddings: `TIE_EMBEDDINGS=1` -- Sequence length: `TRAIN_SEQ_LEN=4096` -- Batching: `TRAIN_BATCH_TOKENS=393216` (3/4 batch for more optimizer updates per second) -- Learning rates: `TIED_EMBED_LR=0.030 MATRIX_LR=0.020 SCALAR_LR=0.020` -- Muon optimizer: `MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_STEPS=1500 MUON_MOMENTUM_WARMUP_START=0.92` -- Schedule: `WARMDOWN_ITERS=3000` -- Eval: `EVAL_STRIDE=64` (sliding window, each scored token gets 4032 tokens of context) - -## Command - -```bash -RUN_ID=combined_seq4096_sw64 \ -DATA_PATH=./data/datasets/fineweb10B_sp1024 \ -TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -MAX_WALLCLOCK_SECONDS=600 \ -TRAIN_LOG_EVERY=50 \ -VAL_LOSS_EVERY=1000 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -## Key Metrics - -- Timed training stopped at `11478/20000` steps due to the wallclock cap. -- Pre-quant eval at stop: `val_loss:2.0071`, `val_bpb:1.1887` -- Post-quant roundtrip eval (standard): `val_loss:2.0142`, `val_bpb:1.1929` -- **Post-quant sliding window eval: `val_loss:1.9937`, `val_bpb:1.1808`** -- Exact printed metric: `final_sliding_window_eval_exact stride:64 val_loss:1.99374987 val_bpb:1.18081285` -- Train time: `599949ms` (`step_avg:52.27ms`) -- Sliding window eval time: `278940ms` (under 10-minute eval budget) -- Peak memory: `7655 MiB allocated`, `8156 MiB reserved` -- Serialized model int8+zlib: `15824445 bytes` -- Code size: `54433 bytes` -- Total submission size int8+zlib: `15878878 bytes` - -## Approach - -This submission stacks three orthogonal improvements: - -### 1. Longer training context (seq_len=4096) -Each training sequence sees 4x more context than the 1024-token baseline, giving the autoregressive model much better signal per token. This costs ~52ms/step (vs ~43ms at seq_len=1024), but the quality improvement far outweighs the fewer total steps (11,478 vs ~13,780). - -### 2. Aggressive Muon optimizer tuning -- **Higher momentum (0.99 vs 0.95):** Stronger gradient smoothing for better convergence. -- **Lower learning rates (0.020 vs 0.04):** Dramatically reduces int8 quantization loss (0.004 BPB quant penalty vs 0.007+ at default LR) while maintaining similar pre-quant quality. -- **3/4 batch (393K vs 524K tokens):** More optimizer updates per wallclock second. -- **Extended momentum warmup (1500 steps from 0.92):** Prevents early instability with the higher momentum. -- **Longer warmdown (3000 steps):** Proportionally longer LR decay for the ~11,500-step run. - -### 3. Sliding window evaluation (stride=64) -Standard evaluation chops the validation set into non-overlapping 4096-token blocks, meaning the first tokens in each block get minimal context. Sliding window evaluation shifts the window by only 64 tokens at a time, scoring only the last 64 tokens per window. Each scored token sees 4032 tokens of prior context, dramatically reducing the "cold start" penalty. This improves BPB by ~0.012 over standard post-quant eval (1.1808 vs 1.1929). - -The evaluation uses a `forward_logits()` method that returns logits without computing loss, allowing us to score only the relevant trailing tokens per window. The sliding window runs on 8 GPUs in parallel with batched windows (32 per forward pass) and completes in ~279 seconds. - -## Training Volume - -- Global batch: `393216` tokens/step -- Total train tokens seen: `~4.5B` (11,478 steps × 393,216 tokens) -- Dataset: 20 shards of fineweb10B_sp1024 - -## Included Files - -- `train_gpt.py` (standalone script with all changes baked in) -- `README.md` (this file) -- `submission.json` (leaderboard metadata) -- `train.log` (full training log from the canonical run) diff --git a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/submission.json b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/submission.json deleted file mode 100644 index 9df0d68362..0000000000 --- a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "aquariouseworkman", - "github_id": "aquariouseworkman", - "name": "Seq4096 Muon Tuned + Sliding Window Eval", - "blurb": "SP-1024 9x512 KV4 trained at seq_len=4096 with tuned Muon optimizer (momentum 0.99, LR 0.020, 3/4 batch, warmdown 3000). Sliding window eval (stride=64) gives each scored token 4032 tokens of context. Post-quant sliding window val_bpb=1.1808.", - "date": "2026-03-19T07:00:00Z", - "val_loss": 1.99374987, - "val_bpb": 1.18081285, - "bytes_total": 15878878, - "bytes_code": 54433 -} diff --git a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train.log b/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train.log deleted file mode 100644 index 31ee9659b8..0000000000 --- a/records/track_10min_16mb/2026-03-19_Seq4096SlidingWindow/train.log +++ /dev/null @@ -1,1627 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride controls how many tokens to slide between windows. - # Only the last `eval_stride` tokens per window are scored, giving each scored token - # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 393_216)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 4096)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding_window( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int = 64, -) -> tuple[float, float]: - """Sliding window evaluation: each scored token gets (seq_len - stride) context. - - Instead of chopping validation into non-overlapping 1024-token blocks (where - the first token in each block gets zero context), we slide a 1024-token window - by `stride` tokens at a time and only score the last `stride` tokens per window. - Every scored token sees 960+ tokens of context, dramatically improving BPB. - """ - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 # need 1 extra for final target - - # Number of windows: first window at offset 0, then slide by stride - num_windows = max((total_tokens - seq_len) // stride + 1, 1) - - # Distribute windows across ranks - win_start = (num_windows * rank) // world_size - win_end = (num_windows * (rank + 1)) // world_size - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Eval batch: how many windows per forward pass. Tune for memory. - eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) - - base_model.eval() - with torch.inference_mode(): - window_list = list(range(win_start, win_end)) - num_batches = (len(window_list) + eval_batch - 1) // eval_batch - for batch_idx in range(num_batches): - batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] - bsz = len(batch_wins) - - # Build input: each window is val_tokens[w*stride : w*stride + seq_len] - inputs = torch.stack([ - val_tokens[w * stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64) # [B, seq_len] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] - - # Score only the last `stride` positions per window. - # logits[:, j, :] predicts the token AFTER position j in the input. - # So logits[:, -stride:, :] predicts tokens at input positions - # [seq_len - stride + 1, seq_len + 1) — which are the targets. - scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] - - # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] - targets = torch.stack([ - val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") - val_loss_sum += loss.to(torch.float64) - val_token_count += float(targets.numel()) - - # BPB: prev_ids are the input tokens at each scored position - prev_ids = torch.stack([ - val_tokens[w * stride + seq_len - stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - tgt_ids = targets - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if rank == 0 and (batch_idx + 1) % 100 == 0: - print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits without computing loss. Used for sliding window eval.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval: gives each scored token (seq_len - stride) context. - if args.eval_stride > 0: - torch.cuda.synchronize() - t_sw = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding_window( - args, - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window_eval stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" - ) - log0( - f"final_sliding_window_eval_exact stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Thu Mar 19 06:53:44 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | -| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | 0 | -| N/A 29C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:07:00.0 Off | 0 | -| N/A 29C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:08:00.0 Off | 0 | -| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:09:00.0 Off | 0 | -| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | -| N/A 29C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:0C:00.0 Off | 0 | -| N/A 33C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 1923 C /usr/local/bin/python 1510MiB | -| 1 N/A N/A 1924 C /usr/local/bin/python 1510MiB | -| 2 N/A N/A 1925 C /usr/local/bin/python 1510MiB | -| 3 N/A N/A 1926 C /usr/local/bin/python 1510MiB | -| 4 N/A N/A 1927 C /usr/local/bin/python 1510MiB | -| 5 N/A N/A 1928 C /usr/local/bin/python 1510MiB | -| 6 N/A N/A 1929 C /usr/local/bin/python 1510MiB | -| 7 N/A N/A 1930 C /usr/local/bin/python 1510MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:20 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:393216 train_seq_len:4096 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.03ms -step:1/20000 train_loss:6.9376 train_time:33ms step_avg:32.90ms -step:2/20000 train_loss:12.1428 train_time:88ms step_avg:43.95ms -step:3/20000 train_loss:7.4539 train_time:147ms step_avg:49.10ms -step:4/20000 train_loss:6.3261 train_time:207ms step_avg:51.67ms -step:5/20000 train_loss:6.9176 train_time:263ms step_avg:52.65ms -step:6/20000 train_loss:6.9130 train_time:325ms step_avg:54.17ms -step:7/20000 train_loss:6.7021 train_time:380ms step_avg:54.27ms -step:8/20000 train_loss:6.6697 train_time:440ms step_avg:54.97ms -step:9/20000 train_loss:6.3152 train_time:497ms step_avg:55.19ms -step:10/20000 train_loss:6.2095 train_time:556ms step_avg:55.56ms -step:50/20000 train_loss:4.1814 train_time:2618ms step_avg:52.36ms -step:100/20000 train_loss:3.3915 train_time:5201ms step_avg:52.01ms -step:150/20000 train_loss:3.0600 train_time:7774ms step_avg:51.83ms -step:200/20000 train_loss:2.7600 train_time:10349ms step_avg:51.75ms -step:250/20000 train_loss:2.6022 train_time:12924ms step_avg:51.70ms -step:300/20000 train_loss:2.6200 train_time:15591ms step_avg:51.97ms -step:350/20000 train_loss:2.6057 train_time:18161ms step_avg:51.89ms -step:400/20000 train_loss:2.4191 train_time:20735ms step_avg:51.84ms -step:450/20000 train_loss:1.9559 train_time:23309ms step_avg:51.80ms -step:500/20000 train_loss:2.4657 train_time:25883ms step_avg:51.77ms -step:550/20000 train_loss:2.4548 train_time:28581ms step_avg:51.97ms -step:600/20000 train_loss:2.2913 train_time:31168ms step_avg:51.95ms -step:650/20000 train_loss:2.3561 train_time:33754ms step_avg:51.93ms -step:700/20000 train_loss:2.3823 train_time:36345ms step_avg:51.92ms -step:750/20000 train_loss:2.3120 train_time:38937ms step_avg:51.92ms -step:800/20000 train_loss:2.3685 train_time:41646ms step_avg:52.06ms -step:850/20000 train_loss:2.3882 train_time:44237ms step_avg:52.04ms -step:900/20000 train_loss:2.1988 train_time:46828ms step_avg:52.03ms -step:950/20000 train_loss:2.2929 train_time:49420ms step_avg:52.02ms -step:1000/20000 train_loss:2.3322 train_time:52014ms step_avg:52.01ms -step:1000/20000 val_loss:2.3041 val_bpb:1.3646 train_time:52038ms step_avg:52.04ms -step:1050/20000 train_loss:2.3307 train_time:54706ms step_avg:52.10ms -step:1100/20000 train_loss:2.3880 train_time:57297ms step_avg:52.09ms -step:1150/20000 train_loss:2.3096 train_time:59889ms step_avg:52.08ms -step:1200/20000 train_loss:2.3520 train_time:62485ms step_avg:52.07ms -step:1250/20000 train_loss:2.2906 train_time:65076ms step_avg:52.06ms -step:1300/20000 train_loss:2.4268 train_time:67768ms step_avg:52.13ms -step:1350/20000 train_loss:2.2162 train_time:70360ms step_avg:52.12ms -step:1400/20000 train_loss:2.2977 train_time:72956ms step_avg:52.11ms -step:1450/20000 train_loss:2.3528 train_time:75550ms step_avg:52.10ms -step:1500/20000 train_loss:2.4259 train_time:78146ms step_avg:52.10ms -step:1550/20000 train_loss:2.3027 train_time:80840ms step_avg:52.15ms -step:1600/20000 train_loss:2.2671 train_time:83436ms step_avg:52.15ms -step:1650/20000 train_loss:2.2047 train_time:86028ms step_avg:52.14ms -step:1700/20000 train_loss:2.2044 train_time:88625ms step_avg:52.13ms -step:1750/20000 train_loss:2.3344 train_time:91217ms step_avg:52.12ms -step:1800/20000 train_loss:2.0434 train_time:93920ms step_avg:52.18ms -step:1850/20000 train_loss:2.2787 train_time:96516ms step_avg:52.17ms -step:1900/20000 train_loss:2.2827 train_time:99110ms step_avg:52.16ms -step:1950/20000 train_loss:2.2277 train_time:101703ms step_avg:52.16ms -step:2000/20000 train_loss:2.3059 train_time:104296ms step_avg:52.15ms -step:2000/20000 val_loss:2.2113 val_bpb:1.3097 train_time:104328ms step_avg:52.16ms -step:2050/20000 train_loss:2.8644 train_time:106995ms step_avg:52.19ms -step:2100/20000 train_loss:2.1551 train_time:109592ms step_avg:52.19ms -step:2150/20000 train_loss:2.2515 train_time:112182ms step_avg:52.18ms -step:2200/20000 train_loss:2.0542 train_time:114772ms step_avg:52.17ms -step:2250/20000 train_loss:2.4315 train_time:117366ms step_avg:52.16ms -step:2300/20000 train_loss:2.1390 train_time:120068ms step_avg:52.20ms -step:2350/20000 train_loss:2.1749 train_time:122661ms step_avg:52.20ms -step:2400/20000 train_loss:2.2065 train_time:125264ms step_avg:52.19ms -step:2450/20000 train_loss:2.3101 train_time:127855ms step_avg:52.19ms -step:2500/20000 train_loss:2.2842 train_time:130447ms step_avg:52.18ms -step:2550/20000 train_loss:2.2999 train_time:133134ms step_avg:52.21ms -step:2600/20000 train_loss:2.3269 train_time:135727ms step_avg:52.20ms -step:2650/20000 train_loss:2.2548 train_time:138321ms step_avg:52.20ms -step:2700/20000 train_loss:2.1140 train_time:140909ms step_avg:52.19ms -step:2750/20000 train_loss:2.1784 train_time:143501ms step_avg:52.18ms -step:2800/20000 train_loss:2.3925 train_time:146200ms step_avg:52.21ms -step:2850/20000 train_loss:2.1333 train_time:148792ms step_avg:52.21ms -step:2900/20000 train_loss:2.1318 train_time:151387ms step_avg:52.20ms -step:2950/20000 train_loss:2.1459 train_time:153977ms step_avg:52.20ms -step:3000/20000 train_loss:2.1693 train_time:156571ms step_avg:52.19ms -step:3000/20000 val_loss:2.1549 val_bpb:1.2763 train_time:156597ms step_avg:52.20ms -step:3050/20000 train_loss:2.0718 train_time:159165ms step_avg:52.19ms -step:3100/20000 train_loss:2.0741 train_time:161860ms step_avg:52.21ms -step:3150/20000 train_loss:2.2964 train_time:164450ms step_avg:52.21ms -step:3200/20000 train_loss:2.1957 train_time:167038ms step_avg:52.20ms -step:3250/20000 train_loss:2.1752 train_time:169631ms step_avg:52.19ms -step:3300/20000 train_loss:2.0950 train_time:172227ms step_avg:52.19ms -step:3350/20000 train_loss:2.2272 train_time:174916ms step_avg:52.21ms -step:3400/20000 train_loss:2.1998 train_time:177509ms step_avg:52.21ms -step:3450/20000 train_loss:2.0994 train_time:180105ms step_avg:52.20ms -step:3500/20000 train_loss:2.1930 train_time:182703ms step_avg:52.20ms -step:3550/20000 train_loss:2.1412 train_time:185300ms step_avg:52.20ms -step:3600/20000 train_loss:2.0727 train_time:188011ms step_avg:52.23ms -step:3650/20000 train_loss:2.1659 train_time:190609ms step_avg:52.22ms -step:3700/20000 train_loss:2.1052 train_time:193207ms step_avg:52.22ms -step:3750/20000 train_loss:2.1134 train_time:195805ms step_avg:52.21ms -step:3800/20000 train_loss:2.0930 train_time:198401ms step_avg:52.21ms -step:3850/20000 train_loss:2.1268 train_time:201119ms step_avg:52.24ms -step:3900/20000 train_loss:2.0784 train_time:203717ms step_avg:52.24ms -step:3950/20000 train_loss:2.1538 train_time:206315ms step_avg:52.23ms -step:4000/20000 train_loss:1.9293 train_time:208915ms step_avg:52.23ms -step:4000/20000 val_loss:2.1208 val_bpb:1.2561 train_time:208940ms step_avg:52.24ms -step:4050/20000 train_loss:2.1241 train_time:211509ms step_avg:52.22ms -step:4100/20000 train_loss:2.1416 train_time:214210ms step_avg:52.25ms -step:4150/20000 train_loss:2.0450 train_time:216807ms step_avg:52.24ms -step:4200/20000 train_loss:2.0048 train_time:219404ms step_avg:52.24ms -step:4250/20000 train_loss:2.1380 train_time:221999ms step_avg:52.23ms -step:4300/20000 train_loss:2.1309 train_time:224595ms step_avg:52.23ms -step:4350/20000 train_loss:2.0464 train_time:227297ms step_avg:52.25ms -step:4400/20000 train_loss:2.1155 train_time:229888ms step_avg:52.25ms -step:4450/20000 train_loss:2.0994 train_time:232479ms step_avg:52.24ms -step:4500/20000 train_loss:2.1875 train_time:235075ms step_avg:52.24ms -step:4550/20000 train_loss:2.1034 train_time:237666ms step_avg:52.23ms -step:4600/20000 train_loss:2.0381 train_time:240368ms step_avg:52.25ms -step:4650/20000 train_loss:2.0448 train_time:242965ms step_avg:52.25ms -step:4700/20000 train_loss:2.1083 train_time:245559ms step_avg:52.25ms -step:4750/20000 train_loss:2.0594 train_time:248155ms step_avg:52.24ms -step:4800/20000 train_loss:2.0705 train_time:250751ms step_avg:52.24ms -step:4850/20000 train_loss:2.1384 train_time:253464ms step_avg:52.26ms -step:4900/20000 train_loss:2.1304 train_time:256063ms step_avg:52.26ms -step:4950/20000 train_loss:2.4399 train_time:258658ms step_avg:52.25ms -step:5000/20000 train_loss:2.1459 train_time:261255ms step_avg:52.25ms -step:5000/20000 val_loss:2.1024 val_bpb:1.2452 train_time:261287ms step_avg:52.26ms -step:5050/20000 train_loss:2.0569 train_time:263854ms step_avg:52.25ms -step:5100/20000 train_loss:2.1764 train_time:266552ms step_avg:52.27ms -step:5150/20000 train_loss:1.9255 train_time:269147ms step_avg:52.26ms -step:5200/20000 train_loss:2.1012 train_time:271751ms step_avg:52.26ms -step:5250/20000 train_loss:2.0996 train_time:274349ms step_avg:52.26ms -step:5300/20000 train_loss:2.0296 train_time:276947ms step_avg:52.25ms -step:5350/20000 train_loss:2.0501 train_time:279653ms step_avg:52.27ms -step:5400/20000 train_loss:2.2050 train_time:282252ms step_avg:52.27ms -step:5450/20000 train_loss:2.0719 train_time:284854ms step_avg:52.27ms -step:5500/20000 train_loss:2.0701 train_time:287454ms step_avg:52.26ms -step:5550/20000 train_loss:2.1761 train_time:290052ms step_avg:52.26ms -step:5600/20000 train_loss:2.0441 train_time:292763ms step_avg:52.28ms -step:5650/20000 train_loss:2.1013 train_time:295371ms step_avg:52.28ms -step:5700/20000 train_loss:2.1353 train_time:297974ms step_avg:52.28ms -step:5750/20000 train_loss:2.1457 train_time:300579ms step_avg:52.27ms -step:5800/20000 train_loss:2.1356 train_time:303181ms step_avg:52.27ms -step:5850/20000 train_loss:2.0896 train_time:305889ms step_avg:52.29ms -step:5900/20000 train_loss:2.2131 train_time:308488ms step_avg:52.29ms -step:5950/20000 train_loss:1.9779 train_time:311089ms step_avg:52.28ms -step:6000/20000 train_loss:2.0487 train_time:313682ms step_avg:52.28ms -step:6000/20000 val_loss:2.0880 val_bpb:1.2366 train_time:313709ms step_avg:52.28ms -step:6050/20000 train_loss:2.0857 train_time:316281ms step_avg:52.28ms -step:6100/20000 train_loss:2.0682 train_time:318877ms step_avg:52.27ms -step:6150/20000 train_loss:2.0564 train_time:321572ms step_avg:52.29ms -step:6200/20000 train_loss:2.0141 train_time:324175ms step_avg:52.29ms -step:6250/20000 train_loss:2.1140 train_time:326773ms step_avg:52.28ms -step:6300/20000 train_loss:2.0920 train_time:329369ms step_avg:52.28ms -step:6350/20000 train_loss:2.1695 train_time:331965ms step_avg:52.28ms -step:6400/20000 train_loss:2.0887 train_time:334656ms step_avg:52.29ms -step:6450/20000 train_loss:2.0925 train_time:337250ms step_avg:52.29ms -step:6500/20000 train_loss:2.1196 train_time:339844ms step_avg:52.28ms -step:6550/20000 train_loss:2.2899 train_time:342438ms step_avg:52.28ms -step:6600/20000 train_loss:2.0384 train_time:345033ms step_avg:52.28ms -step:6650/20000 train_loss:2.0170 train_time:347728ms step_avg:52.29ms -step:6700/20000 train_loss:2.1356 train_time:350323ms step_avg:52.29ms -step:6750/20000 train_loss:1.9176 train_time:352920ms step_avg:52.28ms -step:6800/20000 train_loss:1.9066 train_time:355518ms step_avg:52.28ms -step:6850/20000 train_loss:1.9774 train_time:358115ms step_avg:52.28ms -step:6900/20000 train_loss:2.0591 train_time:360817ms step_avg:52.29ms -step:6950/20000 train_loss:1.9967 train_time:363414ms step_avg:52.29ms -step:7000/20000 train_loss:2.0529 train_time:366003ms step_avg:52.29ms -step:7000/20000 val_loss:2.0760 val_bpb:1.2295 train_time:366035ms step_avg:52.29ms -step:7050/20000 train_loss:2.1090 train_time:368603ms step_avg:52.28ms -step:7100/20000 train_loss:2.1851 train_time:371201ms step_avg:52.28ms -step:7150/20000 train_loss:2.0422 train_time:373893ms step_avg:52.29ms -step:7200/20000 train_loss:2.1037 train_time:376483ms step_avg:52.29ms -step:7250/20000 train_loss:2.0128 train_time:379076ms step_avg:52.29ms -step:7300/20000 train_loss:2.1154 train_time:381668ms step_avg:52.28ms -step:7350/20000 train_loss:2.0027 train_time:384261ms step_avg:52.28ms -step:7400/20000 train_loss:2.1242 train_time:386965ms step_avg:52.29ms -step:7450/20000 train_loss:2.0912 train_time:389554ms step_avg:52.29ms -step:7500/20000 train_loss:2.0547 train_time:392149ms step_avg:52.29ms -step:7550/20000 train_loss:2.0165 train_time:394738ms step_avg:52.28ms -step:7600/20000 train_loss:1.9748 train_time:397328ms step_avg:52.28ms -step:7650/20000 train_loss:2.0099 train_time:400017ms step_avg:52.29ms -step:7700/20000 train_loss:2.0532 train_time:402608ms step_avg:52.29ms -step:7750/20000 train_loss:2.0572 train_time:405196ms step_avg:52.28ms -step:7800/20000 train_loss:1.9216 train_time:407789ms step_avg:52.28ms -step:7850/20000 train_loss:2.0621 train_time:410383ms step_avg:52.28ms -step:7900/20000 train_loss:2.1483 train_time:413081ms step_avg:52.29ms -step:7950/20000 train_loss:1.9810 train_time:415678ms step_avg:52.29ms -step:8000/20000 train_loss:2.0355 train_time:418271ms step_avg:52.28ms -step:8000/20000 val_loss:2.0678 val_bpb:1.2246 train_time:418298ms step_avg:52.29ms -step:8050/20000 train_loss:2.0104 train_time:420857ms step_avg:52.28ms -step:8100/20000 train_loss:2.0138 train_time:423449ms step_avg:52.28ms -step:8150/20000 train_loss:2.0172 train_time:426134ms step_avg:52.29ms -step:8200/20000 train_loss:2.0736 train_time:428721ms step_avg:52.28ms -step:8250/20000 train_loss:2.0940 train_time:431308ms step_avg:52.28ms -step:8300/20000 train_loss:2.1151 train_time:433896ms step_avg:52.28ms -step:8350/20000 train_loss:2.0713 train_time:436485ms step_avg:52.27ms -step:8400/20000 train_loss:2.0671 train_time:439177ms step_avg:52.28ms -step:8450/20000 train_loss:2.0876 train_time:441765ms step_avg:52.28ms -step:8500/20000 train_loss:1.9228 train_time:444353ms step_avg:52.28ms -step:8550/20000 train_loss:2.0541 train_time:446943ms step_avg:52.27ms -step:8600/20000 train_loss:2.0610 train_time:449531ms step_avg:52.27ms -step:8650/20000 train_loss:2.1088 train_time:452222ms step_avg:52.28ms -step:8700/20000 train_loss:1.9899 train_time:454818ms step_avg:52.28ms -step:8750/20000 train_loss:2.1527 train_time:457407ms step_avg:52.28ms -step:8800/20000 train_loss:2.0140 train_time:459998ms step_avg:52.27ms -step:8850/20000 train_loss:2.1463 train_time:462591ms step_avg:52.27ms -step:8900/20000 train_loss:2.0604 train_time:465180ms step_avg:52.27ms -step:8950/20000 train_loss:1.9942 train_time:467862ms step_avg:52.28ms -step:9000/20000 train_loss:1.9742 train_time:470453ms step_avg:52.27ms -step:9000/20000 val_loss:2.0529 val_bpb:1.2158 train_time:470487ms step_avg:52.28ms -step:9050/20000 train_loss:2.0073 train_time:473038ms step_avg:52.27ms -step:9100/20000 train_loss:2.1509 train_time:475627ms step_avg:52.27ms -step:9150/20000 train_loss:1.9811 train_time:478216ms step_avg:52.26ms -step:9200/20000 train_loss:2.0541 train_time:480916ms step_avg:52.27ms -step:9250/20000 train_loss:2.0177 train_time:483526ms step_avg:52.27ms -step:9300/20000 train_loss:2.1091 train_time:486114ms step_avg:52.27ms -step:9350/20000 train_loss:2.1775 train_time:488705ms step_avg:52.27ms -step:9400/20000 train_loss:2.0170 train_time:491295ms step_avg:52.27ms -step:9450/20000 train_loss:2.0597 train_time:493972ms step_avg:52.27ms -step:9500/20000 train_loss:2.0508 train_time:496558ms step_avg:52.27ms -step:9550/20000 train_loss:2.0011 train_time:499146ms step_avg:52.27ms -step:9600/20000 train_loss:1.9192 train_time:501732ms step_avg:52.26ms -step:9650/20000 train_loss:2.0606 train_time:504320ms step_avg:52.26ms -step:9700/20000 train_loss:2.1116 train_time:506988ms step_avg:52.27ms -step:9750/20000 train_loss:2.0981 train_time:509574ms step_avg:52.26ms -step:9800/20000 train_loss:1.8778 train_time:512165ms step_avg:52.26ms -step:9850/20000 train_loss:1.9928 train_time:514755ms step_avg:52.26ms -step:9900/20000 train_loss:2.1629 train_time:517347ms step_avg:52.26ms -step:9950/20000 train_loss:2.0106 train_time:520035ms step_avg:52.26ms -step:10000/20000 train_loss:2.2900 train_time:522641ms step_avg:52.26ms -step:10000/20000 val_loss:2.0325 val_bpb:1.2038 train_time:522672ms step_avg:52.27ms -step:10050/20000 train_loss:2.0670 train_time:525233ms step_avg:52.26ms -step:10100/20000 train_loss:2.0256 train_time:527817ms step_avg:52.26ms -step:10150/20000 train_loss:1.8322 train_time:530407ms step_avg:52.26ms -step:10200/20000 train_loss:1.9949 train_time:533099ms step_avg:52.26ms -step:10250/20000 train_loss:2.0141 train_time:535689ms step_avg:52.26ms -step:10300/20000 train_loss:2.2751 train_time:538284ms step_avg:52.26ms -step:10350/20000 train_loss:2.0744 train_time:540879ms step_avg:52.26ms -step:10400/20000 train_loss:2.1243 train_time:543470ms step_avg:52.26ms -step:10450/20000 train_loss:1.9712 train_time:546158ms step_avg:52.26ms -step:10500/20000 train_loss:2.0482 train_time:548753ms step_avg:52.26ms -step:10550/20000 train_loss:2.0679 train_time:551348ms step_avg:52.26ms -step:10600/20000 train_loss:1.9611 train_time:553938ms step_avg:52.26ms -step:10650/20000 train_loss:2.0545 train_time:556531ms step_avg:52.26ms -step:10700/20000 train_loss:2.1526 train_time:559228ms step_avg:52.26ms -step:10750/20000 train_loss:2.0829 train_time:561829ms step_avg:52.26ms -step:10800/20000 train_loss:2.0368 train_time:564426ms step_avg:52.26ms -step:10850/20000 train_loss:2.1171 train_time:567018ms step_avg:52.26ms -step:10900/20000 train_loss:2.1192 train_time:569611ms step_avg:52.26ms -step:10950/20000 train_loss:2.0715 train_time:572309ms step_avg:52.27ms -step:11000/20000 train_loss:1.9841 train_time:574902ms step_avg:52.26ms -step:11000/20000 val_loss:2.0135 val_bpb:1.1925 train_time:574930ms step_avg:52.27ms -step:11050/20000 train_loss:2.1428 train_time:577501ms step_avg:52.26ms -step:11100/20000 train_loss:1.9385 train_time:580093ms step_avg:52.26ms -step:11150/20000 train_loss:1.9739 train_time:582688ms step_avg:52.26ms -step:11200/20000 train_loss:2.0283 train_time:585381ms step_avg:52.27ms -step:11250/20000 train_loss:1.9992 train_time:587975ms step_avg:52.26ms -step:11300/20000 train_loss:1.9374 train_time:590573ms step_avg:52.26ms -step:11350/20000 train_loss:1.9705 train_time:593170ms step_avg:52.26ms -step:11400/20000 train_loss:2.0143 train_time:595781ms step_avg:52.26ms -step:11450/20000 train_loss:2.1952 train_time:598473ms step_avg:52.27ms -step:11478/20000 val_loss:2.0071 val_bpb:1.1887 train_time:599949ms step_avg:52.27ms -stopping_early: wallclock_cap train_time:599949ms step:11478/20000 -peak memory allocated: 7655 MiB reserved: 8156 MiB -Serialized model: 67224983 bytes -Code size: 54433 bytes -Total submission size: 67279416 bytes -Serialized model int8+zlib: 15824445 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15878878 bytes -final_int8_zlib_roundtrip val_loss:2.0142 val_bpb:1.1929 eval_time:2123ms -final_int8_zlib_roundtrip_exact val_loss:2.01418632 val_bpb:1.19291459 -final_sliding_window_eval stride:64 val_loss:1.9937 val_bpb:1.1808 eval_time:278940ms -final_sliding_window_eval_exact stride:64 val_loss:1.99374987 val_bpb:1.18081285 From e109841147574212e7528af4039840f6ae33f743 Mon Sep 17 00:00:00 2001 From: aquariouseworkman Date: Thu, 19 Mar 2026 23:26:35 -0400 Subject: [PATCH 3/3] SmearGate + OrthoInit + Muon WD + Int6 STE QAT + MLP 3x + Sliding Window MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit val_bpb: 1.1556 (post-quant int6+zstd-22, sliding window eval stride=64) Summary A 22.4M parameter transformer language model trained in under 10 minutes on 8×H100 GPUs, compressed to a 15.1MB artifact via int6 quantization-aware training and zstd-22. The architecture combines a SmearGate bigram embedding layer, orthogonal weight initialization, 3× MLP expansion, U-Net skip connections, and decoupled Muon weight decay, evaluated with sliding window context at stride 64. Architecture Transformer Core A 9-layer, 512-dim transformer with 8 attention heads (4 KV heads via grouped-query attention) and tied input/output embeddings over a 1024-token BPE vocabulary. Sequence length during training is 1024 tokens. SmearGate A learned per-dimension gate (~512 params) that blends each token's embedding with the previous token's embedding before the transformer processes anything: ```python gate = sigmoid(self.gate) # shape \[dim], init ≈ 0.95 output = gate \* current\_emb + (1 - gate) \* prev\_token\_emb ``` This injects bigram (two-token) context directly into the embedding layer. Normally a transformer must discover token-pair relationships through self-attention; SmearGate provides this signal for free. The gate is initialized via `sigmoid(3.0) ≈ 0.95` so it starts near-identity (mostly current token), and the model learns per-dimension how much previous-token blending is useful. Applied after embedding lookup and bigram hash addition, before RMS normalization. Bigram Hash Embedding A 4096-bucket hash table (dim=128, projected to 512) maps consecutive token pairs to learned embeddings via `(prev \* 92821 + cur) % 4096`. This gives the model direct access to token-pair features at minimal parameter cost. MLP 3× Expansion MLP hidden dimension is 3× the model dimension (1536 for a 512-dim model). The space savings from int6 quantization fund this extra capacity — wider MLPs allow more expressive nonlinear feature transformation between attention operations. U-Net Skip Connections The 9-layer transformer is split into an encoder half (4 layers) and a decoder half (5 layers) with learned skip weights connecting corresponding encoder/decoder layers. This gives the decoder direct access to earlier representations without relying solely on the residual stream. Training Muon Optimizer with Weight Decay The Muon optimizer (MomentUm Orthogonalized by Newton-Schulz) runs SGD with Nesterov momentum, then post-processes each 2D parameter's gradient update by replacing it with the nearest orthogonal matrix via 5-step Newton-Schulz iteration. This is equivalent to steepest descent under the spectral norm, improving the conditioning of the optimization landscape. Decoupled weight decay (`p.mul\_(1 - wd \* lr)`, wd=0.01) is applied before each gradient update. This keeps weights smaller and better-distributed, which directly benefits both generalization and downstream quantization — tighter weight distributions quantize into fewer int6 buckets with less error and compress better with zstd. Momentum is warmed from 0.92 → 0.99 over the first 1500 steps. Orthogonal Weight Initialization All non-zero-init CastedLinear weight matrices are initialized with `nn.init.orthogonal\_()`. Orthogonal matrices have all singular values equal to 1, meaning gradients flow uniformly through the network at initialization with no vanishing or exploding signals. Additionally, since Muon's Newton-Schulz step orthogonalizes updates, starting from an already-orthogonal matrix means early updates are immediately useful rather than spent correcting a random initialization. With only ~12k steps in the 10-minute budget, faster convergence matters. Int6 Quantization-Aware Training (STE) All 2D weight matrices are fake-quantized to int6 ([-31, 31]) during every forward pass via Straight-Through Estimator — the forward pass sees quantized weights while gradients flow through the rounding operation as if it were identity. The model learns weight configurations that are inherently robust to post-training quantization. The tied embedding matrix is stored as fp16 passthrough (not quantized), since it serves double duty for both input embeddings and output predictions where errors compound in both directions. Learning Rate Schedule Warmup over 20 steps, followed by linear warmdown over the final 3000 steps. Separate learning rates for tied embeddings (0.030), matrix parameters (0.020), and scalar parameters (0.020). Evaluation Sliding Window (stride=64) Instead of chopping validation text into non-overlapping chunks (where tokens near the start of each chunk lack context), sliding window uses overlapping windows with stride 64 and the full 1024-token context window. Each scored token gets 960+ tokens of prior context. This is purely an evaluation-time technique — it does not change the model. Export Int6 + zstd-22 Compression All quantized weights are packed into int8 containers and compressed with zstandard at level 22. The int6 representation plus aggressive compression brings the full submission (model + code) to 15.1MB, under the 16MB cap. Metrics Metric Value Post-quant sliding window val_bpb 1.1556 Post-quant sliding window val_loss 1.9511 Post-quant standard val_bpb 1.1891 Post-quant standard val_loss 2.0077 Quantization gap (standard eval) ~0.0001 BPB Model parameters 22,368,840 Artifact size (int6+zstd-22) 15,878,809 bytes (15.1 MB) Train steps completed 12,047 Train time 600s (10.0 min) Sliding window eval time 75s Peak GPU memory 11,340 MiB Configuration ``` VOCAB\_SIZE=1024 NUM\_LAYERS=9 MODEL\_DIM=512 NUM\_HEADS=8 NUM\_KV\_HEADS=4 MLP\_MULT=3 TIE\_EMBEDDINGS=1 USE\_SMEARGATE=1 TRAIN\_SEQ\_LEN=1024 TRAIN\_BATCH\_TOKENS=524288 LOGIT\_SOFTCAP=30.0 ROPE\_BASE=10000.0 QK\_GAIN\_INIT=1.5 BIGRAM\_HASH\_BUCKETS=4096 BIGRAM\_HASH\_DIM=128 TIED\_EMBED\_LR=0.030 MATRIX\_LR=0.020 SCALAR\_LR=0.020 MUON\_MOMENTUM=0.99 MUON\_MOMENTUM\_WARMUP\_START=0.92 MUON\_MOMENTUM\_WARMUP\_STEPS=1500 MUON\_WEIGHT\_DECAY=0.01 MUON\_BACKEND\_STEPS=5 WARMDOWN\_ITERS=3000 WARMUP\_STEPS=20 EVAL\_STRIDE=64 MAX\_WALLCLOCK\_SECONDS=600 SEED=1337 ``` Command ```bash RUN\_ID=smeargate\_orthoinit\_muonwd \\ DATA\_PATH=./data/datasets/fineweb10B\_sp1024 \\ TOKENIZER\_PATH=./data/tokenizers/fineweb\_1024\_bpe.model \\ torchrun --standalone --nproc\_per\_node=8 train\_gpt.py ``` Hardware 8× NVIDIA H100 80GB HBM3 SXM (RunPod). --- .../train.log | 1540 ++++++++++++++++ .../README.md | 137 ++ .../submission.json | 29 + .../train.log | 1594 +++++++++++++++++ .../train_gpt_v5.py | 1422 +++++++++++++++ 5 files changed, 4722 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-19_int6_STE QAT_ MLP_bigram _U_Net/train.log create mode 100644 records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/README.md create mode 100644 records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/submission.json create mode 100644 records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train.log create mode 100644 records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train_gpt_v5.py diff --git a/records/track_10min_16mb/2026-03-19_int6_STE QAT_ MLP_bigram _U_Net/train.log b/records/track_10min_16mb/2026-03-19_int6_STE QAT_ MLP_bigram _U_Net/train.log new file mode 100644 index 0000000000..505225703f --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_int6_STE QAT_ MLP_bigram _U_Net/train.log @@ -0,0 +1,1540 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + # Sliding window eval: stride controls how many tokens to slide between windows. + # Only the last `eval_stride` tokens per window are scored, giving each scored token + # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # BigramHash: inject token-pair context via a hash-table embedding. + bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 4096)) + bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int = 64, +) -> tuple[float, float]: + """Sliding window evaluation: each scored token gets (seq_len - stride) context. + + Instead of chopping validation into non-overlapping 1024-token blocks (where + the first token in each block gets zero context), we slide a 1024-token window + by `stride` tokens at a time and only score the last `stride` tokens per window. + Every scored token sees 960+ tokens of context, dramatically improving BPB. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 # need 1 extra for final target + + # Number of windows: first window at offset 0, then slide by stride + num_windows = max((total_tokens - seq_len) // stride + 1, 1) + + # Distribute windows across ranks + win_start = (num_windows * rank) // world_size + win_end = (num_windows * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Eval batch: how many windows per forward pass. Tune for memory. + eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) + + base_model.eval() + with torch.inference_mode(): + window_list = list(range(win_start, win_end)) + num_batches = (len(window_list) + eval_batch - 1) // eval_batch + for batch_idx in range(num_batches): + batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] + bsz = len(batch_wins) + + # Build input: each window is val_tokens[w*stride : w*stride + seq_len] + inputs = torch.stack([ + val_tokens[w * stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64) # [B, seq_len] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] + + # Score only the last `stride` positions per window. + # logits[:, j, :] predicts the token AFTER position j in the input. + # So logits[:, -stride:, :] predicts tokens at input positions + # [seq_len - stride + 1, seq_len + 1) — which are the targets. + scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] + + # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] + targets = torch.stack([ + val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") + val_loss_sum += loss.to(torch.float64) + val_token_count += float(targets.numel()) + + # BPB: prev_ids are the input tokens at each scored position + prev_ids = torch.stack([ + val_tokens[w * stride + seq_len - stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + tgt_ids = targets + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if rank == 0 and (batch_idx + 1) % 100 == 0: + print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 + +INT6_QUANT_RANGE = 31 # int6: [-31, 31] +INT6_CLIP_Q = 0.9999984 + +# Tensors matching these patterns are stored as fp16 passthrough (not quantized). +# For weights without STE fake-quant protection (e.g. nn.Embedding, bigram hash table). +FP16_PASSTHROUGH_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "FP16_PASSTHROUGH_PATTERNS", + "tok_emb,bigram_hash", + ).split(",") + if pattern +) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Int6 per-row quantization: [-31, 31] range stored in int8 containers. + # The unused high bits compress extremely well with zstd. + clip_abs = ( + torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(INT6_QUANT_RANGE)).clamp_min(1.0 / float(INT6_QUANT_RANGE)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_QUANT_RANGE, INT6_QUANT_RANGE).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars still use int8 per-tensor scale (they're tiny). + clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # fp16 passthrough for tensors without STE protection (tok_emb, bigram_hash). + if any(pattern in name for pattern in FP16_PASSTHROUGH_PATTERNS): + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # During training, apply fake int6 quantization (STE) so the model learns to be robust + # to the post-training quantization that will be applied for the 16MB artifact. + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if self.training and w.ndim == 2: + # Fake int6 per-row quantization with Straight-Through Estimator: + # Forward uses quantized weights, backward passes gradients through as-is. + with torch.no_grad(): + w32 = w.float() + clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) + scale = clip_abs / INT6_QUANT_RANGE + w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) + w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE: value of w_q, gradient of w + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class BigramHash(nn.Module): + """Hash-table embedding for token bigrams, projected to model dim. + Maps (prev_token, cur_token) pairs via a simple hash to a learned embedding. + Gives the model cheap character-pair / bigram info before attention. + """ + def __init__(self, num_buckets: int, hash_dim: int, model_dim: int): + super().__init__() + self.num_buckets = num_buckets + self.table = nn.Embedding(num_buckets, hash_dim) + self.proj = CastedLinear(hash_dim, model_dim, bias=False) + self.proj._zero_init = True + nn.init.normal_(self.table.weight, std=0.01) + + def forward(self, input_ids: Tensor) -> Tensor: + bsz, seqlen = input_ids.shape + prev_ids = torch.cat([torch.zeros(bsz, 1, dtype=input_ids.dtype, device=input_ids.device), input_ids[:, :-1]], dim=1) + h = ((prev_ids.long() * 92821 + input_ids.long()) % self.num_buckets).long() + return self.proj(self.table(h)) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_hash_buckets: int = 0, + bigram_hash_dim: int = 128, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash_buckets > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits without computing loss. Used for sliding window eval.""" + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_hash_buckets=args.bigram_hash_buckets, + bigram_hash_dim=args.bigram_hash_dim, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + # Collect embedding-like params + embed_params = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.table.weight) + # bigram_hash.proj is a CastedLinear — its weight goes to Muon + matrix_params.append(base_model.bigram_hash.proj.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_name = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_name = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" + with open(f"final_model.{quant_ext}", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + quant_ext = "int6.ptz" + with open(f"final_model.{quant_ext}", "rb") as f: + quant_blob_disk = f.read() + if HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_decompressed = dctx.decompress(quant_blob_disk) + else: + quant_decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval: gives each scored token (seq_len - stride) context. + if args.eval_stride > 0: + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding_window( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window_eval stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0( + f"final_sliding_window_eval_exact stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri Mar 20 02:18:24 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 34C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:07:00.0 Off | 0 | +| N/A 30C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:08:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:09:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 31C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:0C:00.0 Off | 0 | +| N/A 35C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 49921 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 49922 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 49923 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 49924 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 49925 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 49926 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 49927 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 49928 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:22368328 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9364 train_time:31ms step_avg:30.55ms +step:2/20000 train_loss:12.1409 train_time:82ms step_avg:40.86ms +step:3/20000 train_loss:7.1968 train_time:135ms step_avg:44.91ms +step:4/20000 train_loss:6.4291 train_time:190ms step_avg:47.51ms +step:5/20000 train_loss:6.9210 train_time:245ms step_avg:48.95ms +step:6/20000 train_loss:7.5366 train_time:301ms step_avg:50.13ms +step:7/20000 train_loss:6.7198 train_time:355ms step_avg:50.65ms +step:8/20000 train_loss:6.3634 train_time:409ms step_avg:51.09ms +step:9/20000 train_loss:6.2161 train_time:464ms step_avg:51.52ms +step:10/20000 train_loss:6.0286 train_time:517ms step_avg:51.68ms +step:200/20000 train_loss:2.7939 train_time:9860ms step_avg:49.30ms +step:400/20000 train_loss:2.2882 train_time:19689ms step_avg:49.22ms +step:600/20000 train_loss:2.4945 train_time:29565ms step_avg:49.28ms +step:800/20000 train_loss:2.2566 train_time:39444ms step_avg:49.30ms +step:1000/20000 train_loss:2.3514 train_time:49340ms step_avg:49.34ms +step:1200/20000 train_loss:2.3707 train_time:59244ms step_avg:49.37ms +step:1400/20000 train_loss:2.4180 train_time:69137ms step_avg:49.38ms +step:1600/20000 train_loss:2.0918 train_time:79026ms step_avg:49.39ms +step:1800/20000 train_loss:2.1879 train_time:88911ms step_avg:49.40ms +step:2000/20000 train_loss:2.2381 train_time:98804ms step_avg:49.40ms +step:2200/20000 train_loss:2.0439 train_time:108685ms step_avg:49.40ms +step:2400/20000 train_loss:2.1727 train_time:118547ms step_avg:49.39ms +step:2600/20000 train_loss:2.3791 train_time:128432ms step_avg:49.40ms +step:2800/20000 train_loss:2.1977 train_time:138299ms step_avg:49.39ms +step:3000/20000 train_loss:2.1923 train_time:148160ms step_avg:49.39ms +step:3200/20000 train_loss:2.1539 train_time:158049ms step_avg:49.39ms +step:3400/20000 train_loss:2.1235 train_time:167930ms step_avg:49.39ms +step:3600/20000 train_loss:2.0735 train_time:177813ms step_avg:49.39ms +step:3800/20000 train_loss:2.1821 train_time:187714ms step_avg:49.40ms +step:4000/20000 train_loss:2.1263 train_time:197620ms step_avg:49.41ms +step:4000/20000 val_loss:2.1281 val_bpb:1.2604 train_time:197653ms step_avg:49.41ms +step:4200/20000 train_loss:2.1315 train_time:207621ms step_avg:49.43ms +step:4400/20000 train_loss:2.0687 train_time:217521ms step_avg:49.44ms +step:4600/20000 train_loss:1.9267 train_time:227415ms step_avg:49.44ms +step:4800/20000 train_loss:2.2176 train_time:237314ms step_avg:49.44ms +step:5000/20000 train_loss:1.9865 train_time:247207ms step_avg:49.44ms +step:5200/20000 train_loss:2.1270 train_time:257113ms step_avg:49.44ms +step:5400/20000 train_loss:2.1403 train_time:267008ms step_avg:49.45ms +step:5600/20000 train_loss:2.1378 train_time:276906ms step_avg:49.45ms +step:5800/20000 train_loss:2.0980 train_time:286918ms step_avg:49.47ms +step:6000/20000 train_loss:2.1798 train_time:297160ms step_avg:49.53ms +step:6200/20000 train_loss:2.0437 train_time:307259ms step_avg:49.56ms +step:6400/20000 train_loss:2.1160 train_time:317124ms step_avg:49.55ms +step:6600/20000 train_loss:2.0748 train_time:326994ms step_avg:49.54ms +step:6800/20000 train_loss:2.1477 train_time:336875ms step_avg:49.54ms +step:7000/20000 train_loss:2.1806 train_time:346747ms step_avg:49.54ms +step:7200/20000 train_loss:2.1479 train_time:356620ms step_avg:49.53ms +step:7400/20000 train_loss:2.0732 train_time:366491ms step_avg:49.53ms +step:7600/20000 train_loss:1.9514 train_time:376370ms step_avg:49.52ms +step:7800/20000 train_loss:2.1021 train_time:386252ms step_avg:49.52ms +step:8000/20000 train_loss:2.0685 train_time:396124ms step_avg:49.52ms +step:8000/20000 val_loss:2.0762 val_bpb:1.2297 train_time:396152ms step_avg:49.52ms +step:8200/20000 train_loss:2.1457 train_time:405994ms step_avg:49.51ms +step:8400/20000 train_loss:2.0894 train_time:415948ms step_avg:49.52ms +step:8600/20000 train_loss:2.0913 train_time:425810ms step_avg:49.51ms +step:8800/20000 train_loss:2.0542 train_time:435719ms step_avg:49.51ms +step:9000/20000 train_loss:1.9794 train_time:445595ms step_avg:49.51ms +step:9200/20000 train_loss:2.0395 train_time:455473ms step_avg:49.51ms +step:9400/20000 train_loss:2.0786 train_time:465361ms step_avg:49.51ms +step:9600/20000 train_loss:2.0926 train_time:475238ms step_avg:49.50ms +step:9800/20000 train_loss:2.0138 train_time:485125ms step_avg:49.50ms +step:10000/20000 train_loss:2.0578 train_time:495051ms step_avg:49.51ms +step:10200/20000 train_loss:2.0050 train_time:504951ms step_avg:49.50ms +step:10400/20000 train_loss:2.0338 train_time:514831ms step_avg:49.50ms +step:10600/20000 train_loss:1.9155 train_time:524714ms step_avg:49.50ms +step:10800/20000 train_loss:2.1173 train_time:534599ms step_avg:49.50ms +step:11000/20000 train_loss:2.0472 train_time:544503ms step_avg:49.50ms +step:11200/20000 train_loss:1.9902 train_time:554367ms step_avg:49.50ms +step:11400/20000 train_loss:1.9795 train_time:564228ms step_avg:49.49ms +step:11600/20000 train_loss:1.9797 train_time:574096ms step_avg:49.49ms +step:11800/20000 train_loss:2.0115 train_time:583977ms step_avg:49.49ms +step:12000/20000 train_loss:1.9875 train_time:593849ms step_avg:49.49ms +step:12000/20000 val_loss:2.0165 val_bpb:1.1943 train_time:593877ms step_avg:49.49ms +step:12123/20000 val_loss:2.0158 val_bpb:1.1939 train_time:599995ms step_avg:49.49ms +stopping_early: wallclock_cap train_time:599995ms step:12123/20000 +peak memory allocated: 11273 MiB reserved: 11438 MiB +Serialized model: 87410841 bytes +Code size: 58666 bytes +Total submission size: 87469507 bytes +Serialized model int6+zstd-22: 16127834 bytes (payload:23608608 raw_torch:23654103 payload_ratio:3.70x) +Total submission size int6+zstd-22: 16186500 bytes +final_int6_roundtrip val_loss:2.0145 val_bpb:1.1931 eval_time:1530ms +final_int6_roundtrip_exact val_loss:2.01453600 val_bpb:1.19312169 +final_sliding_window_eval stride:64 val_loss:1.9582 val_bpb:1.1598 eval_time:74131ms +final_sliding_window_eval_exact stride:64 val_loss:1.95823653 val_bpb:1.15977898 diff --git a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/README.md b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/README.md new file mode 100644 index 0000000000..843604db97 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/README.md @@ -0,0 +1,137 @@ +# SmearGate + OrthoInit + Muon WD + Int6 STE QAT + MLP 3x + Sliding Window + +**val\_bpb: 1.1556** (post-quant int6+zstd-22, sliding window eval stride=64) + +## Summary + +A 22.4M parameter transformer language model trained in under 10 minutes on 8×H100 GPUs, compressed to a 15.1MB artifact via int6 quantization-aware training and zstd-22. The architecture combines a SmearGate bigram embedding layer, orthogonal weight initialization, 3× MLP expansion, U-Net skip connections, and decoupled Muon weight decay, evaluated with sliding window context at stride 64. + +## Architecture + +### Transformer Core + +A 9-layer, 512-dim transformer with 8 attention heads (4 KV heads via grouped-query attention) and tied input/output embeddings over a 1024-token BPE vocabulary. Sequence length during training is 1024 tokens. + +### SmearGate + +A learned per-dimension gate (\~512 params) that blends each token's embedding with the previous token's embedding before the transformer processes anything: + +```python +gate = sigmoid(self.gate) # shape \[dim], init ≈ 0.95 +output = gate \* current\_emb + (1 - gate) \* prev\_token\_emb +``` + +This injects bigram (two-token) context directly into the embedding layer. Normally a transformer must discover token-pair relationships through self-attention; SmearGate provides this signal for free. The gate is initialized via `sigmoid(3.0) ≈ 0.95` so it starts near-identity (mostly current token), and the model learns per-dimension how much previous-token blending is useful. + +Applied after embedding lookup and bigram hash addition, before RMS normalization. + +### Bigram Hash Embedding + +A 4096-bucket hash table (dim=128, projected to 512) maps consecutive token pairs to learned embeddings via `(prev \* 92821 + cur) % 4096`. This gives the model direct access to token-pair features at minimal parameter cost. + +### MLP 3× Expansion + +MLP hidden dimension is 3× the model dimension (1536 for a 512-dim model). The space savings from int6 quantization fund this extra capacity — wider MLPs allow more expressive nonlinear feature transformation between attention operations. + +### U-Net Skip Connections + +The 9-layer transformer is split into an encoder half (4 layers) and a decoder half (5 layers) with learned skip weights connecting corresponding encoder/decoder layers. This gives the decoder direct access to earlier representations without relying solely on the residual stream. + +## Training + +### Muon Optimizer with Weight Decay + +The Muon optimizer (MomentUm Orthogonalized by Newton-Schulz) runs SGD with Nesterov momentum, then post-processes each 2D parameter's gradient update by replacing it with the nearest orthogonal matrix via 5-step Newton-Schulz iteration. This is equivalent to steepest descent under the spectral norm, improving the conditioning of the optimization landscape. + +Decoupled weight decay (`p.mul\_(1 - wd \* lr)`, wd=0.01) is applied before each gradient update. This keeps weights smaller and better-distributed, which directly benefits both generalization and downstream quantization — tighter weight distributions quantize into fewer int6 buckets with less error and compress better with zstd. + +Momentum is warmed from 0.92 → 0.99 over the first 1500 steps. + +### Orthogonal Weight Initialization + +All non-zero-init CastedLinear weight matrices are initialized with `nn.init.orthogonal\_()`. Orthogonal matrices have all singular values equal to 1, meaning gradients flow uniformly through the network at initialization with no vanishing or exploding signals. Additionally, since Muon's Newton-Schulz step orthogonalizes updates, starting from an already-orthogonal matrix means early updates are immediately useful rather than spent correcting a random initialization. With only \~12k steps in the 10-minute budget, faster convergence matters. + +### Int6 Quantization-Aware Training (STE) + +All 2D weight matrices are fake-quantized to int6 (\[-31, 31]) during every forward pass via Straight-Through Estimator — the forward pass sees quantized weights while gradients flow through the rounding operation as if it were identity. The model learns weight configurations that are inherently robust to post-training quantization. The tied embedding matrix is stored as fp16 passthrough (not quantized), since it serves double duty for both input embeddings and output predictions where errors compound in both directions. + +### Learning Rate Schedule + +Warmup over 20 steps, followed by linear warmdown over the final 3000 steps. Separate learning rates for tied embeddings (0.030), matrix parameters (0.020), and scalar parameters (0.020). + +## Evaluation + +### Sliding Window (stride=64) + +Instead of chopping validation text into non-overlapping chunks (where tokens near the start of each chunk lack context), sliding window uses overlapping windows with stride 64 and the full 1024-token context window. Each scored token gets 960+ tokens of prior context. This is purely an evaluation-time technique — it does not change the model. + +## Export + +### Int6 + zstd-22 Compression + +All quantized weights are packed into int8 containers and compressed with zstandard at level 22. The int6 representation plus aggressive compression brings the full submission (model + code) to 15.1MB, under the 16MB cap. + +## Metrics + +|Metric|Value| +|-|-| +|**Post-quant sliding window val\_bpb**|**1.1556**| +|Post-quant sliding window val\_loss|1.9511| +|Post-quant standard val\_bpb|1.1891| +|Post-quant standard val\_loss|2.0077| +|Quantization gap (standard eval)|\~0.0001 BPB| +|Model parameters|22,368,840| +|Artifact size (int6+zstd-22)|15,878,809 bytes (15.1 MB)| +|Train steps completed|12,047| +|Train time|600s (10.0 min)| +|Sliding window eval time|75s| +|Peak GPU memory|11,340 MiB| + +## Configuration + +``` +VOCAB\_SIZE=1024 +NUM\_LAYERS=9 +MODEL\_DIM=512 +NUM\_HEADS=8 +NUM\_KV\_HEADS=4 +MLP\_MULT=3 +TIE\_EMBEDDINGS=1 +USE\_SMEARGATE=1 +TRAIN\_SEQ\_LEN=1024 +TRAIN\_BATCH\_TOKENS=524288 +LOGIT\_SOFTCAP=30.0 +ROPE\_BASE=10000.0 +QK\_GAIN\_INIT=1.5 +BIGRAM\_HASH\_BUCKETS=4096 +BIGRAM\_HASH\_DIM=128 +TIED\_EMBED\_LR=0.030 +MATRIX\_LR=0.020 +SCALAR\_LR=0.020 +MUON\_MOMENTUM=0.99 +MUON\_MOMENTUM\_WARMUP\_START=0.92 +MUON\_MOMENTUM\_WARMUP\_STEPS=1500 +MUON\_WEIGHT\_DECAY=0.01 +MUON\_BACKEND\_STEPS=5 +WARMDOWN\_ITERS=3000 +WARMUP\_STEPS=20 +EVAL\_STRIDE=64 +MAX\_WALLCLOCK\_SECONDS=600 +SEED=1337 +``` + +## Command + +```bash +RUN\_ID=smeargate\_orthoinit\_muonwd \\ +DATA\_PATH=./data/datasets/fineweb10B\_sp1024 \\ +TOKENIZER\_PATH=./data/tokenizers/fineweb\_1024\_bpe.model \\ +torchrun --standalone --nproc\_per\_node=8 train\_gpt.py +``` + +## Hardware + +8× NVIDIA H100 80GB HBM3 SXM (RunPod). + +## + diff --git a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/submission.json b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/submission.json new file mode 100644 index 0000000000..99426f02a1 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/submission.json @@ -0,0 +1,29 @@ +{ + "val_bpb": 1.1556, + "val_loss": 1.9511, + "eval_method": "sliding_window", + "eval_stride": 64, + "quantization": "int6", + "compression": "zstd-22", + "artifact_size_bytes": 15878809, + "model_params": 22368840, + "train_steps": 12047, + "train_time_seconds": 600, + "hardware": "8xH100-80GB-SXM", + "seed": 1337, + "techniques": [ + "SmearGate", + "Orthogonal Init", + "Muon Weight Decay", + "Int6 STE QAT", + "MLP 3x", + "Sliding Window Eval", + "Bigram Hash Embedding", + "U-Net Skip Connections", + "Muon Momentum Warmup", + "zstd-22 Compression" + ], + "post_quant_standard_val_bpb": 1.1891, + "post_quant_standard_val_loss": 2.0077, + "quantization_gap_bpb": 0.0001 +} diff --git a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train.log b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train.log new file mode 100644 index 0000000000..71fd53fb13 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train.log @@ -0,0 +1,1594 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + # Sliding window eval: stride controls how many tokens to slide between windows. + # Only the last `eval_stride` tokens per window are scored, giving each scored token + # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.01)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # BigramHash: inject token-pair context via a hash-table embedding. + bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 4096)) + bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) + + # SmearGate: blend each token's embedding with the previous token's + # via a tiny learned gate. Injects bigram context before the transformer. + use_smeargate = bool(int(os.environ.get("USE_SMEARGATE", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group["weight_decay"] + + # Decoupled weight decay: shrink weights toward zero before the + # gradient update. This keeps the weight distribution tight and + # well-centred, which (a) improves generalisation and (b) makes + # post-training int6 quantisation compress better — the tighter + # distribution fits into fewer quantisation buckets with less error. + if wd > 0: + for p in params: + p.mul_(1.0 - wd * lr) + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int = 64, +) -> tuple[float, float]: + """Sliding window evaluation: each scored token gets (seq_len - stride) context. + + Instead of chopping validation into non-overlapping 1024-token blocks (where + the first token in each block gets zero context), we slide a 1024-token window + by `stride` tokens at a time and only score the last `stride` tokens per window. + Every scored token sees 960+ tokens of context, dramatically improving BPB. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 # need 1 extra for final target + + # Number of windows: first window at offset 0, then slide by stride + num_windows = max((total_tokens - seq_len) // stride + 1, 1) + + # Distribute windows across ranks + win_start = (num_windows * rank) // world_size + win_end = (num_windows * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Eval batch: how many windows per forward pass. Tune for memory. + eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) + + base_model.eval() + with torch.inference_mode(): + window_list = list(range(win_start, win_end)) + num_batches = (len(window_list) + eval_batch - 1) // eval_batch + for batch_idx in range(num_batches): + batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] + bsz = len(batch_wins) + + # Build input: each window is val_tokens[w*stride : w*stride + seq_len] + inputs = torch.stack([ + val_tokens[w * stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64) # [B, seq_len] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] + + # Score only the last `stride` positions per window. + # logits[:, j, :] predicts the token AFTER position j in the input. + # So logits[:, -stride:, :] predicts tokens at input positions + # [seq_len - stride + 1, seq_len + 1) — which are the targets. + scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] + + # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] + targets = torch.stack([ + val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") + val_loss_sum += loss.to(torch.float64) + val_token_count += float(targets.numel()) + + # BPB: prev_ids are the input tokens at each scored position + prev_ids = torch.stack([ + val_tokens[w * stride + seq_len - stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + tgt_ids = targets + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if rank == 0 and (batch_idx + 1) % 100 == 0: + print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 + +INT6_QUANT_RANGE = 31 # int6: [-31, 31] +INT6_CLIP_Q = 0.9999984 + +# Tensors matching these patterns are stored as fp16 passthrough (not quantized). +# For weights without STE fake-quant protection (e.g. nn.Embedding, bigram hash table). +FP16_PASSTHROUGH_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "FP16_PASSTHROUGH_PATTERNS", + "tok_emb,bigram_hash", + ).split(",") + if pattern +) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Int6 per-row quantization: [-31, 31] range stored in int8 containers. + # The unused high bits compress extremely well with zstd. + clip_abs = ( + torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(INT6_QUANT_RANGE)).clamp_min(1.0 / float(INT6_QUANT_RANGE)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_QUANT_RANGE, INT6_QUANT_RANGE).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars still use int8 per-tensor scale (they're tiny). + clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # fp16 passthrough for tensors without STE protection (tok_emb, bigram_hash). + if any(pattern in name for pattern in FP16_PASSTHROUGH_PATTERNS): + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # During training, apply fake int6 quantization (STE) so the model learns to be robust + # to the post-training quantization that will be applied for the 16MB artifact. + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if self.training and w.ndim == 2: + # Fake int6 per-row quantization with Straight-Through Estimator: + # Forward uses quantized weights, backward passes gradients through as-is. + with torch.no_grad(): + w32 = w.float() + clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) + scale = clip_abs / INT6_QUANT_RANGE + w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) + w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE: value of w_q, gradient of w + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class BigramHash(nn.Module): + """Hash-table embedding for token bigrams, projected to model dim. + Maps (prev_token, cur_token) pairs via a simple hash to a learned embedding. + Gives the model cheap character-pair / bigram info before attention. + """ + def __init__(self, num_buckets: int, hash_dim: int, model_dim: int): + super().__init__() + self.num_buckets = num_buckets + self.table = nn.Embedding(num_buckets, hash_dim) + self.proj = CastedLinear(hash_dim, model_dim, bias=False) + self.proj._zero_init = True + nn.init.normal_(self.table.weight, std=0.01) + + def forward(self, input_ids: Tensor) -> Tensor: + bsz, seqlen = input_ids.shape + prev_ids = torch.cat([torch.zeros(bsz, 1, dtype=input_ids.dtype, device=input_ids.device), input_ids[:, :-1]], dim=1) + h = ((prev_ids.long() * 92821 + input_ids.long()) % self.num_buckets).long() + return self.proj(self.table(h)) + + +class SmearGate(nn.Module): + """Learned per-dimension gate blending each token's embedding with the + previous token's. Injects bigram (two-token) context directly into the + embedding layer *before* the transformer starts processing. + + Normally a transformer must discover token-pair relationships through + self-attention; SmearGate provides this signal for free at ~dim params. + + Technique originated by @unnir in parameter-golf PR #102/#135. + """ + def __init__(self, dim: int): + super().__init__() + # Initialise so sigmoid(gate) ≈ 0.95 → mostly pass-through at init, + # with a small amount of previous-token blending that the model can + # learn to increase or decrease per dimension. + self.gate = nn.Parameter(torch.full((dim,), 3.0, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate).to(dtype=x.dtype) + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return g * x + (1.0 - g) * x_prev + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_hash_buckets: int = 0, + bigram_hash_dim: int = 128, + use_smeargate: bool = True, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash_buckets > 0 else None + self.smeargate = SmearGate(model_dim) if use_smeargate else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + else: + # Orthogonal init: produces well-conditioned weight matrices + # whose singular values are all 1. This gives the model a + # better starting point — gradients flow more uniformly + # through orthogonal matrices, so early training steps are + # more informative, which matters when you only get ~12k + # steps in the 10-minute budget. + nn.init.orthogonal_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + if self.smeargate is not None: + x = self.smeargate(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits without computing loss. Used for sliding window eval.""" + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + if self.smeargate is not None: + x = self.smeargate(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_hash_buckets=args.bigram_hash_buckets, + bigram_hash_dim=args.bigram_hash_dim, + use_smeargate=args.use_smeargate, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + # Collect embedding-like params + embed_params = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.table.weight) + # bigram_hash.proj is a CastedLinear — its weight goes to Muon + matrix_params.append(base_model.bigram_hash.proj.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_weight_decay, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + # SmearGate gate is a 1D parameter → goes to scalar optimizer + if base_model.smeargate is not None: + scalar_params.append(base_model.smeargate.gate) + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_name = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_name = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" + with open(f"final_model.{quant_ext}", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + quant_ext = "int6.ptz" + with open(f"final_model.{quant_ext}", "rb") as f: + quant_blob_disk = f.read() + if HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_decompressed = dctx.decompress(quant_blob_disk) + else: + quant_decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval: gives each scored token (seq_len - stride) context. + if args.eval_stride > 0: + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding_window( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window_eval stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0( + f"final_sliding_window_eval_exact stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri Mar 20 02:59:26 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 34C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | 0 | +| N/A 30C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:07:00.0 Off | 0 | +| N/A 29C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:08:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:09:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 30C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:0C:00.0 Off | 0 | +| N/A 35C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 63083 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 63084 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 63085 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 63086 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 63087 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 63088 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 63089 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 63090 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:22368840 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9360 train_time:54ms step_avg:54.31ms +step:2/20000 train_loss:11.9110 train_time:95ms step_avg:47.40ms +step:3/20000 train_loss:6.9870 train_time:150ms step_avg:49.93ms +step:4/20000 train_loss:6.5188 train_time:207ms step_avg:51.71ms +step:5/20000 train_loss:7.0146 train_time:256ms step_avg:51.19ms +step:6/20000 train_loss:7.6320 train_time:312ms step_avg:52.04ms +step:7/20000 train_loss:6.7726 train_time:366ms step_avg:52.29ms +step:8/20000 train_loss:6.4472 train_time:424ms step_avg:52.99ms +step:9/20000 train_loss:6.2283 train_time:476ms step_avg:52.87ms +step:10/20000 train_loss:6.0399 train_time:533ms step_avg:53.32ms +step:200/20000 train_loss:2.8733 train_time:9951ms step_avg:49.75ms +step:400/20000 train_loss:2.3103 train_time:19855ms step_avg:49.64ms +step:600/20000 train_loss:2.5058 train_time:29786ms step_avg:49.64ms +step:800/20000 train_loss:2.2640 train_time:39699ms step_avg:49.62ms +step:1000/20000 train_loss:2.3549 train_time:49633ms step_avg:49.63ms +step:1200/20000 train_loss:2.3742 train_time:59568ms step_avg:49.64ms +step:1400/20000 train_loss:2.4250 train_time:69496ms step_avg:49.64ms +step:1600/20000 train_loss:2.0981 train_time:79428ms step_avg:49.64ms +step:1800/20000 train_loss:2.1951 train_time:89342ms step_avg:49.63ms +step:2000/20000 train_loss:2.2342 train_time:99282ms step_avg:49.64ms +step:2200/20000 train_loss:2.0520 train_time:109218ms step_avg:49.64ms +step:2400/20000 train_loss:2.1752 train_time:119147ms step_avg:49.64ms +step:2600/20000 train_loss:2.3816 train_time:129101ms step_avg:49.65ms +step:2800/20000 train_loss:2.2040 train_time:139034ms step_avg:49.65ms +step:3000/20000 train_loss:2.1976 train_time:148962ms step_avg:49.65ms +step:3200/20000 train_loss:2.1568 train_time:158936ms step_avg:49.67ms +step:3400/20000 train_loss:2.1293 train_time:168883ms step_avg:49.67ms +step:3600/20000 train_loss:2.0805 train_time:178828ms step_avg:49.67ms +step:3800/20000 train_loss:2.1861 train_time:188749ms step_avg:49.67ms +step:4000/20000 train_loss:2.1296 train_time:198681ms step_avg:49.67ms +step:4200/20000 train_loss:2.1472 train_time:208700ms step_avg:49.69ms +step:4400/20000 train_loss:2.0760 train_time:218620ms step_avg:49.69ms +step:4600/20000 train_loss:1.9327 train_time:228570ms step_avg:49.69ms +step:4800/20000 train_loss:2.2263 train_time:238524ms step_avg:49.69ms +step:5000/20000 train_loss:1.9942 train_time:248446ms step_avg:49.69ms +step:5200/20000 train_loss:2.1393 train_time:258372ms step_avg:49.69ms +step:5400/20000 train_loss:2.1549 train_time:268304ms step_avg:49.69ms +step:5600/20000 train_loss:2.1469 train_time:278242ms step_avg:49.69ms +step:5800/20000 train_loss:2.1108 train_time:288216ms step_avg:49.69ms +step:6000/20000 train_loss:2.1936 train_time:298205ms step_avg:49.70ms +step:6200/20000 train_loss:2.0551 train_time:308176ms step_avg:49.71ms +step:6400/20000 train_loss:2.1305 train_time:318155ms step_avg:49.71ms +step:6600/20000 train_loss:2.0871 train_time:328123ms step_avg:49.72ms +step:6800/20000 train_loss:2.1626 train_time:338112ms step_avg:49.72ms +step:7000/20000 train_loss:2.1948 train_time:348088ms step_avg:49.73ms +step:7200/20000 train_loss:2.1656 train_time:358079ms step_avg:49.73ms +step:7400/20000 train_loss:2.0888 train_time:368049ms step_avg:49.74ms +step:7600/20000 train_loss:1.9626 train_time:378011ms step_avg:49.74ms +step:7800/20000 train_loss:2.1138 train_time:387989ms step_avg:49.74ms +step:8000/20000 train_loss:2.0847 train_time:397958ms step_avg:49.74ms +step:8200/20000 train_loss:2.1603 train_time:407935ms step_avg:49.75ms +step:8400/20000 train_loss:2.1027 train_time:418010ms step_avg:49.76ms +step:8600/20000 train_loss:2.1065 train_time:427980ms step_avg:49.77ms +step:8800/20000 train_loss:2.0714 train_time:437943ms step_avg:49.77ms +step:9000/20000 train_loss:1.9954 train_time:447930ms step_avg:49.77ms +step:9200/20000 train_loss:2.0574 train_time:458197ms step_avg:49.80ms +step:9400/20000 train_loss:2.0975 train_time:468123ms step_avg:49.80ms +step:9600/20000 train_loss:2.1099 train_time:478053ms step_avg:49.80ms +step:9800/20000 train_loss:2.0284 train_time:487968ms step_avg:49.79ms +step:10000/20000 train_loss:2.0685 train_time:497890ms step_avg:49.79ms +step:10200/20000 train_loss:2.0115 train_time:507802ms step_avg:49.78ms +step:10400/20000 train_loss:2.0397 train_time:517716ms step_avg:49.78ms +step:10600/20000 train_loss:1.9194 train_time:527650ms step_avg:49.78ms +step:10800/20000 train_loss:2.1216 train_time:537617ms step_avg:49.78ms +step:11000/20000 train_loss:2.0445 train_time:547588ms step_avg:49.78ms +step:11200/20000 train_loss:1.9996 train_time:557571ms step_avg:49.78ms +step:11400/20000 train_loss:1.9753 train_time:567547ms step_avg:49.78ms +step:11600/20000 train_loss:1.9778 train_time:577534ms step_avg:49.79ms +step:11800/20000 train_loss:2.0026 train_time:587535ms step_avg:49.79ms +step:12000/20000 train_loss:1.9772 train_time:597520ms step_avg:49.79ms +step:12047/20000 val_loss:2.0079 val_bpb:1.1892 train_time:599974ms step_avg:49.80ms +stopping_early: wallclock_cap train_time:599974ms step:12047/20000 +peak memory allocated: 11340 MiB reserved: 11504 MiB +Serialized model: 87413210 bytes +Code size: 61629 bytes +Total submission size: 87474839 bytes +Serialized model int6+zstd-22: 15817180 bytes (payload:23609632 raw_torch:23655445 payload_ratio:3.70x) +Total submission size int6+zstd-22: 15878809 bytes +final_int6_roundtrip val_loss:2.0077 val_bpb:1.1891 eval_time:1553ms +final_int6_roundtrip_exact val_loss:2.00771951 val_bpb:1.18908459 +final_sliding_window_eval stride:64 val_loss:1.9511 val_bpb:1.1556 eval_time:75096ms +final_sliding_window_eval_exact stride:64 val_loss:1.95110428 val_bpb:1.15555486 diff --git a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train_gpt_v5.py b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train_gpt_v5.py new file mode 100644 index 0000000000..ed29085c15 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train_gpt_v5.py @@ -0,0 +1,1422 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + # Sliding window eval: stride controls how many tokens to slide between windows. + # Only the last `eval_stride` tokens per window are scored, giving each scored token + # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.01)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # BigramHash: inject token-pair context via a hash-table embedding. + bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 4096)) + bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) + + # SmearGate: blend each token's embedding with the previous token's + # via a tiny learned gate. Injects bigram context before the transformer. + use_smeargate = bool(int(os.environ.get("USE_SMEARGATE", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group["weight_decay"] + + # Decoupled weight decay: shrink weights toward zero before the + # gradient update. This keeps the weight distribution tight and + # well-centred, which (a) improves generalisation and (b) makes + # post-training int6 quantisation compress better — the tighter + # distribution fits into fewer quantisation buckets with less error. + if wd > 0: + for p in params: + p.mul_(1.0 - wd * lr) + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int = 64, +) -> tuple[float, float]: + """Sliding window evaluation: each scored token gets (seq_len - stride) context. + + Instead of chopping validation into non-overlapping 1024-token blocks (where + the first token in each block gets zero context), we slide a 1024-token window + by `stride` tokens at a time and only score the last `stride` tokens per window. + Every scored token sees 960+ tokens of context, dramatically improving BPB. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 # need 1 extra for final target + + # Number of windows: first window at offset 0, then slide by stride + num_windows = max((total_tokens - seq_len) // stride + 1, 1) + + # Distribute windows across ranks + win_start = (num_windows * rank) // world_size + win_end = (num_windows * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Eval batch: how many windows per forward pass. Tune for memory. + eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) + + base_model.eval() + with torch.inference_mode(): + window_list = list(range(win_start, win_end)) + num_batches = (len(window_list) + eval_batch - 1) // eval_batch + for batch_idx in range(num_batches): + batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] + bsz = len(batch_wins) + + # Build input: each window is val_tokens[w*stride : w*stride + seq_len] + inputs = torch.stack([ + val_tokens[w * stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64) # [B, seq_len] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] + + # Score only the last `stride` positions per window. + # logits[:, j, :] predicts the token AFTER position j in the input. + # So logits[:, -stride:, :] predicts tokens at input positions + # [seq_len - stride + 1, seq_len + 1) — which are the targets. + scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] + + # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] + targets = torch.stack([ + val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") + val_loss_sum += loss.to(torch.float64) + val_token_count += float(targets.numel()) + + # BPB: prev_ids are the input tokens at each scored position + prev_ids = torch.stack([ + val_tokens[w * stride + seq_len - stride : w * stride + seq_len] + for w in batch_wins + ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] + + tgt_ids = targets + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if rank == 0 and (batch_idx + 1) % 100 == 0: + print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 + +INT6_QUANT_RANGE = 31 # int6: [-31, 31] +INT6_CLIP_Q = 0.9999984 + +# Tensors matching these patterns are stored as fp16 passthrough (not quantized). +# For weights without STE fake-quant protection (e.g. nn.Embedding, bigram hash table). +FP16_PASSTHROUGH_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "FP16_PASSTHROUGH_PATTERNS", + "tok_emb,bigram_hash", + ).split(",") + if pattern +) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Int6 per-row quantization: [-31, 31] range stored in int8 containers. + # The unused high bits compress extremely well with zstd. + clip_abs = ( + torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(INT6_QUANT_RANGE)).clamp_min(1.0 / float(INT6_QUANT_RANGE)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_QUANT_RANGE, INT6_QUANT_RANGE).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars still use int8 per-tensor scale (they're tiny). + clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # fp16 passthrough for tensors without STE protection (tok_emb, bigram_hash). + if any(pattern in name for pattern in FP16_PASSTHROUGH_PATTERNS): + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # During training, apply fake int6 quantization (STE) so the model learns to be robust + # to the post-training quantization that will be applied for the 16MB artifact. + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if self.training and w.ndim == 2: + # Fake int6 per-row quantization with Straight-Through Estimator: + # Forward uses quantized weights, backward passes gradients through as-is. + with torch.no_grad(): + w32 = w.float() + clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) + scale = clip_abs / INT6_QUANT_RANGE + w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) + w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE: value of w_q, gradient of w + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class BigramHash(nn.Module): + """Hash-table embedding for token bigrams, projected to model dim. + Maps (prev_token, cur_token) pairs via a simple hash to a learned embedding. + Gives the model cheap character-pair / bigram info before attention. + """ + def __init__(self, num_buckets: int, hash_dim: int, model_dim: int): + super().__init__() + self.num_buckets = num_buckets + self.table = nn.Embedding(num_buckets, hash_dim) + self.proj = CastedLinear(hash_dim, model_dim, bias=False) + self.proj._zero_init = True + nn.init.normal_(self.table.weight, std=0.01) + + def forward(self, input_ids: Tensor) -> Tensor: + bsz, seqlen = input_ids.shape + prev_ids = torch.cat([torch.zeros(bsz, 1, dtype=input_ids.dtype, device=input_ids.device), input_ids[:, :-1]], dim=1) + h = ((prev_ids.long() * 92821 + input_ids.long()) % self.num_buckets).long() + return self.proj(self.table(h)) + + +class SmearGate(nn.Module): + """Learned per-dimension gate blending each token's embedding with the + previous token's. Injects bigram (two-token) context directly into the + embedding layer *before* the transformer starts processing. + + Normally a transformer must discover token-pair relationships through + self-attention; SmearGate provides this signal for free at ~dim params. + + Technique originated by @unnir in parameter-golf PR #102/#135. + """ + def __init__(self, dim: int): + super().__init__() + # Initialise so sigmoid(gate) ≈ 0.95 → mostly pass-through at init, + # with a small amount of previous-token blending that the model can + # learn to increase or decrease per dimension. + self.gate = nn.Parameter(torch.full((dim,), 3.0, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate).to(dtype=x.dtype) + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return g * x + (1.0 - g) * x_prev + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_hash_buckets: int = 0, + bigram_hash_dim: int = 128, + use_smeargate: bool = True, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_hash = BigramHash(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash_buckets > 0 else None + self.smeargate = SmearGate(model_dim) if use_smeargate else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + else: + # Orthogonal init: produces well-conditioned weight matrices + # whose singular values are all 1. This gives the model a + # better starting point — gradients flow more uniformly + # through orthogonal matrices, so early training steps are + # more informative, which matters when you only get ~12k + # steps in the 10-minute budget. + nn.init.orthogonal_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + if self.smeargate is not None: + x = self.smeargate(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits without computing loss. Used for sliding window eval.""" + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + if self.smeargate is not None: + x = self.smeargate(x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_hash_buckets=args.bigram_hash_buckets, + bigram_hash_dim=args.bigram_hash_dim, + use_smeargate=args.use_smeargate, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + # Collect embedding-like params + embed_params = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.table.weight) + # bigram_hash.proj is a CastedLinear — its weight goes to Muon + matrix_params.append(base_model.bigram_hash.proj.weight) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_weight_decay, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + # SmearGate gate is a 1D parameter → goes to scalar optimizer + if base_model.smeargate is not None: + scalar_params.append(base_model.smeargate.gate) + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_name = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_name = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" + with open(f"final_model.{quant_ext}", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + quant_ext = "int6.ptz" + with open(f"final_model.{quant_ext}", "rb") as f: + quant_blob_disk = f.read() + if HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_decompressed = dctx.decompress(quant_blob_disk) + else: + quant_decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval: gives each scored token (seq_len - stride) context. + if args.eval_stride > 0: + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding_window( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window_eval stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0( + f"final_sliding_window_eval_exact stride:{args.eval_stride} " + f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file