diff --git a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA097_seq2048_Int5MLP_MuonWD04/README.md b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA097_seq2048_Int5MLP_MuonWD04/README.md new file mode 100644 index 0000000000..da3da1fe4f --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA097_seq2048_Int5MLP_MuonWD04/README.md @@ -0,0 +1,93 @@ +# 11L + XSA4 + EMA(0.997) + seq2048 + Int5-MLP + MuonWD=0.04 + LateK-FP16 + +**val_bpb: 1.1361** (avg over seeds 1337/42/123, sliding window stride=64, post int5/int6+zstd-22 quantization roundtrip) +Hardware: 8×H100-80GB-SXM5 | Steps: ~9126 avg | Wallclock: 600s | Artifact: ~15.79MB + +## Approach + +This submission stacks eight techniques on the baseline, building directly on PRs #135, #162, #264, #265, and #287. All changes are incremental over those prior submissions — no new external dependencies are required. + +### 1. Per-Row Int6 Quantization + Int5-MLP + zstd-22 (PR #264, extended) + +Attention weight matrices are quantized to int6 (symmetric, 63 levels, per-row scaling). MLP weights (`c_fc`, `c_proj`) use int5 (32 levels, [-16, 15]) instead of int6, which compresses ~15% better at the cost of a small precision penalty. This saves approximately 1.9MB relative to uniform int6, which directly funds the 11th transformer layer while staying within the 16MB artifact limit. Straight-through estimator (STE) fake-quantization is applied during training for both int6 and int5. The tied token embedding is kept in fp16. Compressed with zstandard at level 22. + +### 2. Late-K FP16 on Final Layer + +The key projection (`c_k.weight`) of the last transformer block is kept in fp16 rather than int6. This avoids quantization noise in the most context-sensitive attention keys, at a cost of ~131KB — a favorable quality/size trade-off. + +### 3. 11-Layer U-Net Architecture + +The model uses 11 transformer blocks (5 encoder + 6 decoder) with U-Net skip connections, up from the 9-layer baseline. The 11th layer becomes feasible due to the byte savings from int5-MLP quantization. + +### 4. Exclusive Self-Attention (XSA) on Last 4 Layers (PR #265) + +XSA projects each value vector out of the attention output before the projection layer, preventing the model from trivially attending to each token's own value. Applied to the deepest 4 layers (`xsa_last_n=4`). Implemented as an efficient GQA-aware subtraction with no repeat_interleave. Confirmed neutral vs 3 layers at seq_len=2048 with EMA (three independent runs). The implementation closely follows PR #265. + +### 5. EMA Weight Averaging, decay=0.997 (PR #287) + +An exponential moving average of all model parameters is maintained throughout training: `ema = 0.997 * ema + 0.003 * param`. The EMA weights are substituted into the model before the final quantization and evaluation. This replaces Stochastic Weight Averaging (SWA), which was found to exceed the 16MB artifact budget when combined with 11 layers and seq2048 batch tokens. EMA adds zero artifact size (the shadow copy is discarded after training) and provides smoother, better-regularized weights for quantization. Based on PR #287 (which uses the same decay=0.997). + +### 6. Sequence Length 2048 + +Training sequence length extended from 1024 to 2048. This increases gradient token count per step (same batch token budget), giving each update richer long-range context. The warmdown schedule is tuned to 2000 steps (rather than the 3500 default) to match the step budget at seq2048 training speed. + +### 7. 
SmearGate + BigramHash(2048) + OrthoInit (PR #135) + +Three techniques from PR #135 are included unchanged: +- **SmearGate**: a learned per-dimension sigmoid gate that blends each token embedding with the preceding token, injecting lightweight bigram context at the embedding layer. +- **BigramHash**: a 2048-bucket hash table (dim=64, projected to 512) that maps consecutive token pairs to learned embeddings via `xor(36313*t[i], 27191*t[i-1]) % 2047`. The bucket count is reduced to 2048 (from the PR #135 default of 4096) to fit the artifact budget with 11 layers. The embedding dimension is 64 rather than 128, maintaining the ~0.5MB FP16 footprint. +- **OrthoInit**: all large (≥64×64) weight matrices are initialized with `nn.init.orthogonal_(gain=1.0)`. Output projections are additionally scaled by `1/sqrt(2*num_layers)` following muP conventions. + +### 8. Muon Optimizer with Weight Decay = 0.04 (PR #162, tuned) + +PR #162 introduced decoupled weight decay for Muon. We tune the decay from the PR #162 default: sweep over {0.01, 0.02, 0.04, 0.1} shows 0.04 as optimal (+0.0006 BPB over 0.02; 0.1 catastrophic). Muon momentum is fixed at 0.95 with a warmup from 0.85 over 500 steps. RoPE base is 500K (confirmed −0.0036 BPB over the default 10K). + +## Hyperparameters + +| Parameter | Value | +|---|---| +| num_layers | 11 | +| model_dim | 512 | +| num_heads | 8 | +| num_kv_heads | 4 (GQA) | +| mlp_mult | 3 (hidden=1536) | +| train_seq_len | 2048 | +| train_batch_tokens | 524,288 | +| warmdown_iters | 2000 | +| warmup_steps | 20 | +| matrix_lr | 0.02 | +| scalar_lr | 0.02 | +| tied_embed_lr | 0.03 | +| muon_momentum | 0.95 (warmup from 0.85 over 500 steps) | +| muon_weight_decay | 0.04 | +| rope_base | 500,000 | +| eval_stride | 64 | +| ema_decay | 0.997 | +| xsa_last_n | 4 | +| bigram_vocab_size | 2048 | +| bigram_dim | 64 | +| late_k_fp16_layers | 1 | +| use_int5_mlp | True | +| use_int6_qat | True | +| quantization | int5 (MLP) + int6 (attn) + fp16 (embed, last c_k) | +| compression | zstd level 22 | + +## Results + +| Seed | val_bpb | steps | wallclock | +|---|---|---|---| +| 42 | 1.1370 | 9122 | 600s | +| 123 | 1.1354 | 9136 | 600s | +| 1337 | **1.1358** | 9120 | 600s | + +**Average val_bpb: 1.1361** (3 seeds) +## Technique Attribution + +| Technique | Source | +|---|---| +| SmearGate, BigramHash, OrthoInit, tied-embed LR, matrix/scalar LR | PR #135 | +| Muon weight decay | PR #162 | +| Int5-MLP, Int6 QAT, Late-K FP16 | PR #264 | +| Exclusive Self-Attention (XSA) | PR #265 | +| EMA weight averaging (decay=0.997) | PR #287 | +| Warmdown tuning for seq2048, MuonWD=0.04, BigramVocab=2048 | this submission | diff --git a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA097_seq2048_Int5MLP_MuonWD04/submission.json b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA097_seq2048_Int5MLP_MuonWD04/submission.json new file mode 100644 index 0000000000..086db1fc68 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA097_seq2048_Int5MLP_MuonWD04/submission.json @@ -0,0 +1,35 @@ +{ + "author": "Siddarth Venkatraman", + "github_id": "HyperPotatoNeo", + "name": "11L + XSA4 + EMA(0.997) + seq2048 + Int5-MLP + MuonWD=0.04 + LateK-FP16 + BigramHash(2048)", + "blurb": "11-layer transformer with Exclusive Self-Attention (XSA) on last 4 layers, EMA (decay=0.997) replacing SWA for post-training weight averaging, seq_len=2048, Int5 quantization for MLP weights (freeing budget for 11th layer), Int6 for attention, Late-K FP16 on final layer c_k, BigramHash(2048 buckets, dim=64), SmearGate, OrthoInit (PR 
#135), Muon weight decay=0.04 (tuned from PR #162 default of 0.02), warmdown=2000 steps (tuned for seq2048), RoPE base=500K. Sliding window evaluation with stride=64. 3-seed average val_bpb=1.1361 (seeds 1337/42/123), ~9126 steps avg in 600s.", + "date": "2026-03-21T02:00:00Z", + "val_bpb": 1.1361, + "seeds": [ + 42, + 123, + 1337 + ], + "seed_results": { + "1337": { + "val_bpb": 1.135778, + "steps": 9120, + "wallclock_seconds": 600 + }, + "42": { + "val_bpb": 1.13702666, + "steps": 9122, + "wallclock_seconds": 600 + }, + "123": { + "val_bpb": 1.13537826, + "steps": 9136, + "wallclock_seconds": 600 + } + }, + "step_stop": 9120, + "wallclock_seconds": 600, + "bytes_total": 15826438, + "bytes_model_int8_zstd": 15789000, + "hardware": "8xH100-80GB-SXM5" +} diff --git a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA097_seq2048_Int5MLP_MuonWD04/train_gpt.py b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA097_seq2048_Int5MLP_MuonWD04/train_gpt.py new file mode 100644 index 0000000000..e8b4f41836 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA097_seq2048_Int5MLP_MuonWD04/train_gpt.py @@ -0,0 +1,1779 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +import zstandard +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + # Sliding window stride for final eval (64 = overlapping windows scoring last 64 tokens each). + # Set to seq_len (default) for non-overlapping eval during training. + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + # LoRA TTT: per-sequence test-time training during final quantized eval. 
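+    # Off by default (LORA_TTT=0); opt in by setting the environment variable.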
+ # Adapts rank-8 LoRA on c_q/c_v per sequence, 1 Adam step on prefix tokens. + # Adapters excluded from artifact (initialized fresh at eval time). + lora_ttt = bool(int(os.environ.get("LORA_TTT", "0"))) + lora_ttt_rank = int(os.environ.get("LORA_TTT_RANK", "8")) + lora_ttt_lr = float(os.environ.get("LORA_TTT_LR", "0.01")) + lora_ttt_prefix_len = int(os.environ.get("LORA_TTT_PREFIX_LEN", "512")) # tokens used for adaptation + lora_ttt_steps = int(os.environ.get("LORA_TTT_STEPS", "1")) # Adam steps per document + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2000)) # 2000 optimal for seq_len=2048 (-0.0003 vs 3000) + warmdown_lr_floor = float(os.environ.get("WARMDOWN_LR_FLOOR", 0.1)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 500000.0)) # RoPE base: 500000 confirmed best (10000→50000→500000, -0.0036 on shared champion) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) # PR #135 exact: 0.03 confirmed -0.0037 BPB + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) # PR #135 exact: 0.02 confirmed -0.0037 BPB + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) # PR #135 exact: 0.02 confirmed -0.0037 BPB + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) # momentum=0.95 CONFIRMED CHAMPION 1.1542 BPB (vs 0.99=1.1550) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", "0.04")) # Muon weight decay: 0.04 confirmed optimal (+0.0006 BPB vs 0.02; PR #162 introduced WD) + normuon = bool(int(os.environ.get("NORMUON", "0"))) # NorMuon per-row second-moment normalization + swa_every = int(os.environ.get("SWA_EVERY", 0)) # SWA: collect checkpoint every N warmdown steps (0=off) + use_int6_qat = bool(int(os.environ.get("USE_INT6_QAT", "1"))) # Int6 QAT: fake-quantize attn+MLP weights during training + use_int5_mlp = bool(int(os.environ.get("USE_INT5_MLP", "1"))) # Int5-MLP: use 5-bit quant for MLP weights (saves ~1.9MB via better compression, PR #264) + late_k_fp16_layers = int(os.environ.get("LATE_K_FP16_LAYERS", "1")) # Late-K FP16: keep c_k in FP16 for final N layers (1 = last layer only) + ortho_init = bool(int(os.environ.get("ORTHO_INIT", "1"))) # OrthoInit: orthogonal init for large 2D weight matrices (PR #135) + bigram_vocab_size = 
int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) # BigramHash: 2048 buckets (fits budget with 11L + seq2048; PR #135) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 64)) # BigramHash embedding dim (64=~0.5MB FP16, fits 16MB budget) + use_smear_gate = bool(int(os.environ.get("SMEAR_GATE", "1"))) # SmearGate: blend token embedding with prev-token embedding (PR #135) + xsa_last_n = int(os.environ.get("XSA_LAST_N", "4")) # XSA: apply Exclusive Self Attention to deepest N layers (PR #265; 4 confirmed neutral vs 3 with seq2048+EMA) + # EMA: exponential moving average of weights (PR #287 = 1.1280 BPB uses decay=0.997). + # At each step: ema = decay * ema + (1-decay) * param. Use EMA weights for quantization. + # Different from SWA (dead end): EMA tracks all steps with exponential decay, stays near trained model. + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", "0.997")) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, normalize_rows: bool = False, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, normalize_rows=normalize_rows, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + normalize_rows = group["normalize_rows"] + weight_decay = group.get("weight_decay", 0.0) + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
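+                    # The max(1, rows/cols)**0.5 factor keeps the orthogonalized update's RMS roughly
+                    # independent of aspect ratio: tall matrices are scaled up by sqrt(rows/cols), others by 1.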
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + # NorMuon: per-row second-moment normalization (arxiv 2510.05491) + if normalize_rows and g.ndim == 2: + if "row_v" not in state: + state["row_v"] = torch.zeros(g.size(0), device=g.device, dtype=torch.float32) + row_sq = g.float().pow(2).mean(dim=1) # [rows] squared Frobenius per row / cols + state["row_v"].mul_(0.95).add_(row_sq, alpha=0.05) + row_scale = state["row_v"].rsqrt().clamp(max=1e4).to(g.dtype) + g = g * row_scale.unsqueeze(1) + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + if weight_decay > 0.0: + p.mul_(1.0 - lr * weight_decay) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
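+    # One extra token is kept past the last full sequence so shifted (x, y) pairs can be built downstream.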
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Sliding window evaluation: gives each predicted token ~(seq_len - stride) tokens of context. + + Only applied for the final quantized eval (not during training). + stride=64 means each window slides by 64, scoring only the last 64 tokens. 
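+    Each scored token therefore sees at least seq_len - stride (1984 at the defaults) tokens of left context.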
+ Tokens in positions 0..seq_len-stride-1 are used as context only. + """ + seq_len = args.train_seq_len + stride = args.eval_stride + if stride >= seq_len: + # Fall back to standard non-overlapping eval + raise ValueError(f"eval_stride={stride} >= seq_len={seq_len}, use eval_val instead") + + # Total scoreable positions: all positions from index seq_len-1 to end of tokens + # Window i covers tokens[i*stride : i*stride+seq_len], scores tokens[i*stride+seq_len-stride : i*stride+seq_len] + total_tokens = val_tokens.numel() + # Number of full windows + num_windows = (total_tokens - seq_len) // stride + # Divide windows across ranks + win_start = (num_windows * rank) // world_size + win_end = (num_windows * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Process windows in batches of 32 for memory efficiency + sw_batch_size = 32 + + base_model.eval() + with torch.inference_mode(): + for batch_start in range(win_start, win_end, sw_batch_size): + batch_end = min(batch_start + sw_batch_size, win_end) + # Build input batch: each row is a full seq_len window + windows = [] + for w in range(batch_start, batch_end): + tok_start = w * stride + tok_end = tok_start + seq_len + 1 # +1 for targets + windows.append(val_tokens[tok_start:tok_end]) + batch = torch.stack(windows).to(device=device, dtype=torch.int64, non_blocking=True) + x = batch[:, :-1] # [B, seq_len] + y = batch[:, 1:] # [B, seq_len] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.get_logits(x) # [B, seq_len, V] + # Score only the last `stride` positions (they have full context) + logits_tail = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, V] + y_tail = y[:, -stride:].reshape(-1) # [B*stride] + token_losses = F.cross_entropy(logits_tail.float(), y_tail, reduction="none") # [B*stride] + val_loss_sum += token_losses.to(torch.float64).sum() + val_token_count += float(token_losses.numel()) + # BPB byte count for scored tokens + prev_tail = x[:, -stride:].reshape(-1) + token_bytes = base_bytes_lut[y_tail].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[y_tail] & ~is_boundary_token_lut[prev_tail]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# LORA TTT SUPPORT +# ----------------------------- +# Per-sequence test-time training with LoRA adapters on c_q, c_v in each layer. +# Adapters are initialized fresh (A=zeros, B=rand*0.01) and excluded from the artifact. +# At eval: for each val sequence, adapt on prefix tokens (1 Adam step), score suffix tokens. 
+ +class LoRALinear(nn.Module): + """Wraps a frozen linear layer with a trainable LoRA delta: W_eff = W + B @ A.""" + def __init__(self, base_linear: nn.Linear, rank: int, device: torch.device): + super().__init__() + out_dim, in_dim = base_linear.weight.shape + self.base = base_linear + # Standard LoRA init: A~N(0, 0.01), B=0 → initial delta=0; B gets nonzero gradient after 1 step + self.lora_A = nn.Parameter(torch.randn(rank, in_dim, device=device, dtype=torch.float32) * 0.01) + self.lora_B = nn.Parameter(torch.zeros(out_dim, rank, device=device, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + base_out = self.base(x) + # LoRA delta: x @ A^T @ B^T (in bfloat16 context, keep A/B in float32 for stability) + lora_out = F.linear(F.linear(x.float(), self.lora_A), self.lora_B).to(x.dtype) + return base_out + lora_out + + +def apply_lora_ttt_to_model(model: nn.Module, rank: int, device: torch.device) -> list[LoRALinear]: + """Replace c_q, c_k, c_v, proj in each attention block with LoRALinear wrappers.""" + lora_wrappers = [] + for block in model.blocks: + attn = block.attn + lora_q = LoRALinear(attn.c_q, rank, device) + lora_k = LoRALinear(attn.c_k, rank, device) + lora_v = LoRALinear(attn.c_v, rank, device) + lora_p = LoRALinear(attn.proj, rank, device) + attn.c_q = lora_q + attn.c_k = lora_k + attn.c_v = lora_v + attn.proj = lora_p + lora_wrappers.extend([lora_q, lora_k, lora_v, lora_p]) + return lora_wrappers + + +def reset_lora_adapters(lora_wrappers: list[LoRALinear]) -> None: + """Re-initialize LoRA adapters for a fresh sequence. Standard LoRA: A~N(0,0.01), B=0.""" + with torch.no_grad(): + for lw in lora_wrappers: + nn.init.normal_(lw.lora_A, mean=0.0, std=0.01) + lw.lora_B.zero_() + + +def eval_val_sliding_lora_ttt( + args: "Hyperparameters", + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Document-level LoRA TTT eval. Fits in 10-minute eval budget. + + Treats each doc_len-token chunk as a document (~30K docs from 62M tokens). + For each document: + 1. Reset LoRA adapters (A~N(0,0.01), B=0) + 2. 1 Adam step on first prefix_len tokens (adaptation) + 3. Score remaining tokens with stride=64 sliding window using adapted model + Total TTT steps = num_docs / world_size ≈ ~3750 per GPU — completes in ~30s. + Score-before-adapt within document prevents data leakage (prefix ≠ scored tokens). + Adapters NOT saved to artifact. 
+ """ + seq_len = args.train_seq_len + stride = args.eval_stride + lora_rank = args.lora_ttt_rank + lora_lr = args.lora_ttt_lr + prefix_len = args.lora_ttt_prefix_len # tokens for adaptation (default 512) + ttt_steps = args.lora_ttt_steps # Adam steps per document (default 1) + doc_len = seq_len * 2 # 2048: treat as one document + + base_model.eval() + lora_wrappers = apply_lora_ttt_to_model(base_model, lora_rank, device) + lora_params = [lw.lora_A for lw in lora_wrappers] + [lw.lora_B for lw in lora_wrappers] + lora_set = set(id(p) for p in lora_params) + for p in base_model.parameters(): + p.requires_grad_(id(p) in lora_set) + + total_tokens = val_tokens.numel() + num_docs = total_tokens // doc_len + doc_start = (num_docs * rank) // world_size + doc_end = (num_docs * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + val_tokens_gpu = val_tokens.to(device=device, dtype=torch.int64) + lora_opt = torch.optim.Adam(lora_params, lr=lora_lr, betas=(0.9, 0.95), eps=1e-8) + + for doc_idx in range(doc_start, doc_end): + tok_off = doc_idx * doc_len + + # ── ADAPT: ttt_steps Adam steps on prefix tokens ── + if prefix_len > 0: + x_pre = val_tokens_gpu[tok_off : tok_off + prefix_len].unsqueeze(0) + y_pre = val_tokens_gpu[tok_off + 1 : tok_off + prefix_len + 1].unsqueeze(0) + reset_lora_adapters(lora_wrappers) + for _ in range(ttt_steps): + lora_opt.zero_grad() + with torch.enable_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + adapt_loss = base_model(x_pre, y_pre) + adapt_loss.backward() + lora_opt.step() + + # ── SCORE: batched sliding window over document suffix with adapted LoRA ── + first_w = max(0, (prefix_len - seq_len + stride) // stride) + last_w = doc_len // stride # exclusive + if first_w >= last_w: + continue + # Filter to windows where both x and y are fully within bounds + n_total = val_tokens_gpu.numel() + win_indices = [ + w for w in range(first_w, last_w) + if tok_off + w * stride + seq_len + 1 <= n_total + ] + if not win_indices: + continue + xs = torch.stack([ + val_tokens_gpu[tok_off + w * stride : tok_off + w * stride + seq_len] + for w in win_indices + ]) # [W, seq_len] + ys = torch.stack([ + val_tokens_gpu[tok_off + w * stride + 1 : tok_off + w * stride + seq_len + 1] + for w in win_indices + ]) # [W, seq_len] + W = xs.size(0) + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits_batch = base_model.get_logits(xs) # [W, seq_len, V] + logits_tail = logits_batch[:, -stride:, :].float() # [W, stride, V] + y_tail = ys[:, -stride:].reshape(-1) # [W*stride] + prev_tail = xs[:, -stride:].reshape(-1) # [W*stride] + token_losses = F.cross_entropy( + logits_tail.reshape(W * stride, -1), y_tail, reduction="none" + ) + val_loss_sum += token_losses.to(torch.float64).sum() + val_token_count += W * stride + token_bytes = base_bytes_lut[y_tail].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[y_tail] & ~is_boundary_token_lut[prev_tail]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / 
math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.eval() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# INT6 QAT SUPPORT +# ----------------------------- +# Straight-Through Estimator (STE) fake quantization for Int6 (64 levels, -31..+31). +# Applied during forward pass of CastedLinear when USE_INT6_QAT=1. +# Gradient flows through unchanged (STE); weights get rounded to grid during forward only. + +INT6_LEVELS = 31 # symmetric: -31..+31 gives 63 levels (+0), same as int6 range +INT5_LEVELS = 15 # symmetric: -16..+15 gives 32 levels, stored as int8 but compresses better (PR #264) + +def fake_quantize_ste(w: torch.Tensor) -> torch.Tensor: + """Per-row fake-quantize weight to Int6 range. STE: gradient passes through unchanged.""" + w32 = w.float() + scale = w32.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / INT6_LEVELS + w_q = (w32 / scale).round().clamp(-INT6_LEVELS, INT6_LEVELS) * scale + # STE: in forward use quantized value; in backward gradient passes straight through + return w + (w_q - w).detach() + +def fake_quantize_ste_int5(w: torch.Tensor) -> torch.Tensor: + """Per-row fake-quantize weight to Int5 range. STE: gradient passes through unchanged. + Values clamped to [-16, 15] (32 levels) — stored as int8 but compresses ~15% better (PR #264).""" + w32 = w.float() + scale = w32.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / INT5_LEVELS + w_q = (w32 / scale).round().clamp(-INT5_LEVELS - 1, INT5_LEVELS) * scale + return w + (w_q - w).detach() + + +# Global flag (set in main after args are parsed) +_USE_INT6_QAT: bool = False +# Global flag: use int5 for MLP weights (c_fc and c_proj) instead of int6 (PR #264) +_USE_INT5_MLP_QAT: bool = False +_INT5_MLP_WEIGHT_NAMES: set[str] = set() + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,bigram", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +# When KEEP_EMBED_FP16=1 (default), keep tok_emb.weight as fp16 instead of int8 for ~-0.006 BPB improvement. 
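+# With vocab 1024 and model_dim 512 this keeps 1024*512 fp16 values, roughly 1MB before zstd compression.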
+INT8_KEEP_EMBED_FP16 = bool(int(os.environ.get("KEEP_EMBED_FP16", "1"))) +INT8_KEEP_EMBED_FP16_PATTERNS = ("tok_emb",) if INT8_KEEP_EMBED_FP16 else () +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_float_tensor_int6(t: Tensor) -> tuple[Tensor, Tensor]: + """Per-row Int6 quantization: symmetric, 63 levels (-31..+31), stored as int8 container. + Saves ~25% vs int8 because scale range is 31 vs 127, zlib compresses residuals better. + Only applied to large 2D matrices (attention + MLP weights, not embeddings). + """ + t32 = t.float() + assert t32.ndim == 2, "int6 only for 2D matrices" + scale = t32.abs().amax(dim=1).clamp(min=1e-8) / INT6_LEVELS + q = (t32 / scale[:, None]).round().clamp(-INT6_LEVELS, INT6_LEVELS).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_float_tensor_int5(t: Tensor) -> tuple[Tensor, Tensor]: + """Per-row Int5 quantization: values in [-16, 15] (32 levels), stored as int8 container. + Compresses ~15% better than int6 via smaller value range. Applied to MLP weights (PR #264). 
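+    In this config that is 22 matrices (11 layers x fc/proj, 1536x512 or its transpose), the bulk of the non-embedding weights.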
+ Trades ~0.001-0.003 BPB quality loss for smaller compressed artifact.""" + t32 = t.float() + assert t32.ndim == 2, "int5 only for 2D matrices" + scale = t32.abs().amax(dim=1).clamp(min=1e-8) / INT5_LEVELS + q = (t32 / scale[:, None]).round().clamp(-INT5_LEVELS - 1, INT5_LEVELS).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +# Global flag: use int6 for large matrix weights in the export (set from args.use_int6_qat) +_EXPORT_INT6: bool = False +# Names of weight tensors that should use int6 (populated after model is created in main()) +_INT6_WEIGHT_NAMES: set[str] = set() +# Names of c_k weights in final N layers to keep as FP16 (Late-K FP16 from PR #114) +_LATE_K_FP16_NAMES: set[str] = set() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # FP16 embedding: keep tok_emb.weight as fp16 to reduce quantization error. + if INT8_KEEP_EMBED_FP16_PATTERNS and any(p in name for p in INT8_KEEP_EMBED_FP16_PATTERNS): + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Late-K FP16: keep c_k weights in final N layers as fp16 (not int6/int8). 
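+        # Default config: one [256, 512] matrix, so fp16 adds roughly 131KB over an int8 container (the README's ~131KB figure).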
+ if _LATE_K_FP16_NAMES and name in _LATE_K_FP16_NAMES: + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Int5 export: use 5-bit quantization for MLP weight matrices (PR #264, compresses better) + if _INT5_MLP_WEIGHT_NAMES and t.ndim == 2 and name in _INT5_MLP_WEIGHT_NAMES: + q, s = quantize_float_tensor_int5(t) + qmeta[name] = {"scheme": "int5_per_row", "axis": 0, "levels": INT5_LEVELS} + # Int6 export: use 6-bit quantization for large attention+MLP weight matrices + elif _EXPORT_INT6 and t.ndim == 2 and name in _INT6_WEIGHT_NAMES: + q, s = quantize_float_tensor_int6(t) + qmeta[name] = {"scheme": "int6_per_row", "axis": 0, "levels": INT6_LEVELS} + else: + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
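+    # Example: with WORLD_SIZE=8 and TRAIN_BATCH_TOKENS=524288 (grad_accum_steps=1), each call hands every rank 65536(+1) tokens, i.e. 32 sequences of length 2048.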
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table. + Injects bigram-level info into the residual stream at negligible compute cost. + Inspired by modded-nanogpt record #62. From PR #135 (1.1539 BPB).""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod # BOS position + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # When _USE_INT6_QAT is enabled, fake-quantize weights to int6 grid before each matmul. + # _skip_int6: set True on c_k layers in final N layers (Late-K FP16) to keep them at higher precision. + # _use_int5: set True on MLP layers (c_fc/c_proj) when USE_INT5_MLP=1 (PR #264). + _skip_int6: bool = False + _use_int5: bool = False + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if _USE_INT6_QAT and w.ndim == 2 and w.numel() > 65536 and not self._skip_int6: + if _USE_INT5_MLP_QAT and self._use_int5: + w = fake_quantize_ste_int5(w) + else: + w = fake_quantize_ste(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
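+    # This run constructs Rotary with rope_base=500000 (README: -0.0036 BPB vs the 10000 default).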
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + self.use_xsa = False # set True on deep layers by GPT.__init__ (PR #265) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape. + y: [B, T, H, D], v: [B, T, Hkv, D]. No repeat_interleave — free view. 
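+        Per head this removes the component of y along the token's own normalized value: y_perp = y - (y . v_hat) * v_hat.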
(PR #265)""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2) # [B, T, H, D] + if self.use_xsa: + v_t = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa_efficient(y, v_t) + y = y.contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ortho_init: bool = False, + bigram_vocab_size: int = 0, + bigram_dim: int = 64, + use_smear_gate: bool = False, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.ortho_init = ortho_init + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights 
= nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self._num_layers = num_layers + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + # SmearGate: learned per-dim gate blending current token with prev token (PR #135) + self.smear_gate = nn.Parameter(torch.zeros(model_dim)) if use_smear_gate else None + # XSA: enable Exclusive Self Attention on the deepest xsa_last_n layers (PR #265) + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif self.ortho_init and module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + # μP-style: scale output projections by 1/sqrt(2*num_layers) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * self._num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + # SmearGate: blend current token embedding with previous token (PR #135) + if self.smear_gate is not None: + g = torch.sigmoid(self.smear_gate.to(dtype=x.dtype)) + x_prev = torch.cat([torch.zeros_like(x[:, :1, :]), x[:, :-1, :]], dim=1) + x = (1.0 - g) * x + g * x_prev + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def get_logits(self, input_ids: Tensor) -> Tensor: + """Return per-token logits [B, T, V] without computing loss. 
Used for sliding window eval.""" + B, T = input_ids.shape + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + # SmearGate: blend current token embedding with previous token (PR #135) + if self.smear_gate is not None: + g = torch.sigmoid(self.smear_gate.to(dtype=x.dtype)) + x_prev = torch.cat([torch.zeros_like(x[:, :1, :]), x[:, :-1, :]], dim=1) + x = (1.0 - g) * x + g * x_prev + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) # [B, T, D] + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) # [B, T, V] + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(True) # FA3 via cuDNN: faster on H100 (~49ms/step vs 53ms for flash) + enable_flash_sdp(False) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + 
torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ortho_init=args.ortho_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + use_smear_gate=args.use_smear_gate, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + + # Int6 QAT: enable fake-quantization in CastedLinear.forward() and collect weight names for export. + global _USE_INT6_QAT, _EXPORT_INT6, _INT6_WEIGHT_NAMES, _LATE_K_FP16_NAMES + if args.use_int6_qat: + _USE_INT6_QAT = True + _EXPORT_INT6 = True + # Late-K FP16: identify c_k weight names in the final N layers to keep as FP16. + # These skip fake-quantization during training and are stored as FP16 (not int6) at export. + _LATE_K_FP16_NAMES = set() + if args.late_k_fp16_layers > 0: + num_layers = args.num_layers + late_start = num_layers - args.late_k_fp16_layers + for name, p in base_model.named_parameters(): + if ".attn.c_k.weight" in name: + # name format: blocks.N.attn.c_k.weight + parts = name.split(".") + try: + layer_idx = int(parts[1]) + if layer_idx >= late_start: + _LATE_K_FP16_NAMES.add(name) + except (ValueError, IndexError): + pass + # Set _skip_int6 on the corresponding CastedLinear modules + for name, module in base_model.named_modules(): + if isinstance(module, CastedLinear): + weight_name = name + ".weight" + if weight_name in _LATE_K_FP16_NAMES: + module._skip_int6 = True + log0(f"late_k_fp16:enabled layers:{args.late_k_fp16_layers} weight_tensors:{len(_LATE_K_FP16_NAMES)}") + # Collect names of large 2D weight tensors in attention + MLP (not embeddings, not late-K FP16). 
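+ # (The fake-quantization itself happens in CastedLinear.forward(), defined earlier in the file.
+ # A minimal sketch, assuming per-row symmetric absmax scaling for the int6 path:
+ #     scale = w.abs().amax(dim=-1, keepdim=True) / 31        # int6: 63 levels in [-31, 31]
+ #     w_fq  = (w / scale).round().clamp(-31, 31) * scale     # straight-through in backward
+ # with the int5-MLP path using the narrower [-16, 15] range. The names collected next are the
+ # large attention/MLP tensors eligible for int6 at export; MLP names are moved to the int5 set
+ # just below when int5-MLP is enabled.)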
+ _INT6_WEIGHT_NAMES = { + name + for name, p in base_model.named_parameters() + if p.ndim == 2 + and p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL + and not any(pat in name for pat in INT8_KEEP_EMBED_FP16_PATTERNS + INT8_KEEP_FLOAT_FP32_NAME_PATTERNS) + and name not in _LATE_K_FP16_NAMES + } + log0(f"int6_qat:enabled weight_tensors:{len(_INT6_WEIGHT_NAMES)}") + + # Int5-MLP QAT: use 5-bit quantization for MLP weights (c_fc, c_proj) — PR #264 + global _USE_INT5_MLP_QAT, _INT5_MLP_WEIGHT_NAMES + if args.use_int6_qat and args.use_int5_mlp: + _USE_INT5_MLP_QAT = True + # Identify MLP weight names (mlp.fc.weight and mlp.proj.weight for all blocks) + _INT5_MLP_WEIGHT_NAMES = { + name + for name, p in base_model.named_parameters() + if p.ndim == 2 + and p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL + and (".mlp.fc." in name or ".mlp.proj." in name) + and name not in _LATE_K_FP16_NAMES + } + # Remove int5-MLP names from int6 set (they get int5 instead) + _INT6_WEIGHT_NAMES -= _INT5_MLP_WEIGHT_NAMES + # Set _use_int5=True on the corresponding CastedLinear modules + for mod_name, module in base_model.named_modules(): + if isinstance(module, CastedLinear): + if (mod_name + ".weight") in _INT5_MLP_WEIGHT_NAMES: + module._use_int5 = True + log0(f"int5_mlp:enabled weight_tensors:{len(_INT5_MLP_WEIGHT_NAMES)}") + + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + # BigramHash embed weight goes to AdamW at tied_embed_lr; proj weight goes to Muon + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + scalar_params.append(base_model.bigram.scale) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + # SmearGate: 1D param, train with Adam at scalar_lr + if base_model.smear_gate is not None: + scalar_params.append(base_model.smear_gate) + optimizer_tok = torch.optim.Adam( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + normalize_rows=args.normuon, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if 
base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=True flash=False mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0(f"normuon:{args.normuon} eval_stride:{args.eval_stride} warmdown_iters:{args.warmdown_iters} warmdown_lr_floor:{args.warmdown_lr_floor} keep_embed_fp16:{INT8_KEEP_EMBED_FP16} swa_every:{args.swa_every} use_int6_qat:{args.use_int6_qat} mlp_mult:{args.mlp_mult} late_k_fp16_layers:{args.late_k_fp16_layers} lora_ttt:{args.lora_ttt} lora_ttt_rank:{args.lora_ttt_rank} lora_ttt_lr:{args.lora_ttt_lr} lora_ttt_prefix_len:{args.lora_ttt_prefix_len}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + floor = args.warmdown_lr_floor + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + if warmdown_start <= step < args.iterations: + frac = (args.iterations - step) / max(args.warmdown_iters, 1) + return floor + (1.0 - floor) * frac + return 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + if remaining_ms <= warmdown_ms: + frac = remaining_ms / max(warmdown_ms, 1e-9) + return floor + (1.0 - floor) * frac + return 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init.
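+ # Restoring matters because these warmup steps also populate Adam moment buffers and Muon
+ # momentum; reloading the saved optimizer state (and re-creating the data loader afterwards)
+ # ensures the timed run starts from the true init rather than a partially-trained state.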
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + # SWA state: accumulate model params during warmdown phase. + swa_param_sums: dict[str, Tensor] | None = None + swa_count = 0 + swa_last_collect_step = -1 + + # EMA state: exponential moving average of all model parameters (PR #287 decay=0.997). + # Updated every step after optimizer.step(). Weights copied to model before quantization. 
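+ # Update rule (applied after every optimizer step below): ema <- decay * ema + (1 - decay) * param.
+ # With decay = 0.997 the effective averaging horizon is roughly 1 / (1 - decay) ≈ 333 steps, so the
+ # EMA is dominated by the tail of the run while earlier iterates decay away exponentially.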
+ ema_params: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_params = {n: p.detach().float().clone() for n, p in base_model.named_parameters()} + log0(f"ema:enabled decay:{args.ema_decay} params:{sum(p.numel() for p in ema_params.values())}") + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # EMA: update exponential moving average after each optimizer step. + if ema_params is not None: + decay = args.ema_decay + with torch.no_grad(): + for n, p in base_model.named_parameters(): + ema_params[n].mul_(decay).add_(p.detach().float(), alpha=1.0 - decay) + + # SWA: collect params during warmdown phase every swa_every steps. 
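+ # Note: when EMA is enabled, the EMA weights are copied into the model *after* any SWA averaging
+ # in the post-training section, so EMA takes precedence; SWA collection is kept only for
+ # configurations that run with swa_every > 0 and EMA disabled.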
+ if args.swa_every > 0 and scale < 1.0 and (step - swa_last_collect_step) >= args.swa_every: + if swa_param_sums is None: + swa_param_sums = {n: p.detach().float().clone() for n, p in base_model.named_parameters()} + else: + for n, p in base_model.named_parameters(): + swa_param_sums[n].add_(p.detach().float()) + swa_count += 1 + swa_last_collect_step = step + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # SWA: if we collected any checkpoints, replace model weights with their average. + if swa_param_sums is not None and swa_count > 1: + log0(f"swa: averaging {swa_count} checkpoints collected during warmdown") + with torch.no_grad(): + for n, p in base_model.named_parameters(): + avg = swa_param_sums[n].div_(swa_count).to(dtype=p.dtype) + p.copy_(avg) + # Sync averaged weights across ranks. + if distributed: + for p in base_model.parameters(): + dist.broadcast(p.data, src=0) + elif args.swa_every > 0: + log0(f"swa: only {swa_count} checkpoints collected, skipping averaging (need >1)") + + # EMA: replace model weights with EMA weights before quantization (PR #287). + # EMA provides smoother weights than the final SGD/Muon iterate, improving QAT compression. + if ema_params is not None: + log0(f"ema: loading EMA weights (decay={args.ema_decay}) into model before quantization") + with torch.no_grad(): + for n, p in base_model.named_parameters(): + p.copy_(ema_params[n].to(dtype=p.dtype)) + # Sync EMA weights across ranks (all ranks must use same weights for consistent quant). + if distributed: + for p in base_model.parameters(): + dist.broadcast(p.data, src=0) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
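+ # The round trip below quantizes, zstd-compresses, writes, re-reads, decompresses and dequantizes
+ # the weights before the final eval, so the reported roundtrip val_bpb reflects exactly what the
+ # submitted artifact can reproduce.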
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.lora_ttt: + log0(f"lora_ttt:enabled rank:{args.lora_ttt_rank} lr:{args.lora_ttt_lr} prefix_len:{args.lora_ttt_prefix_len} steps:{args.lora_ttt_steps} targets:c_q,c_k,c_v,proj") + q_val_loss, q_val_bpb = eval_val_sliding_lora_ttt( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + elif args.eval_stride < args.train_seq_len: + q_val_loss, q_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + else: + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()
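+ # Usage notes (sketch, not part of the scored pipeline; file and launcher names are assumptions):
+ # - Launch with any launcher that sets RANK/WORLD_SIZE/LOCAL_RANK, e.g.:
+ #     torchrun --standalone --nproc_per_node=8 train.py
+ # - Reloading the compressed artifact elsewhere mirrors the roundtrip above:
+ #     import io, zstandard, torch
+ #     blob = open("final_model.int8.ptz", "rb").read()
+ #     state = torch.load(io.BytesIO(zstandard.ZstdDecompressor().decompress(blob)), map_location="cpu")
+ #     model.load_state_dict(dequantize_state_dict_int8(state), strict=True)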