diff --git a/records/track_10min_16mb/2026-03-20_BrnMwai_DepthRecurrence/README.md b/records/track_10min_16mb/2026-03-20_BrnMwai_DepthRecurrence/README.md new file mode 100644 index 0000000000..6a3ab83020 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_BrnMwai_DepthRecurrence/README.md @@ -0,0 +1,56 @@ +# Depth-Recurrent Transformer with Competitive Recipe + +**Author:** Brian Mwai (@brn-mwai) +**Score:** TBD (pending 8xH100 validation) + +## Architecture + +Two modes, controlled by `DEPTH_RECURRENCE` env var: + +### Mode A: Standard (11 unique layers) +- 11 layers, 512 dim, 3x MLP (hidden=1536) +- 8 attention heads, 4 KV heads (GQA) +- ~27M parameters + +### Mode B: Depth Recurrence +- 2 prelude + 1 shared (looped 7x) + 2 coda = 5 unique blocks, 11 effective +- Iteration embeddings tell shared block which pass it's on +- Freed parameters can go to wider model (640+ dim) + +## Techniques + +| Technique | Impact | Source | +|-----------|--------|--------| +| Int6 quantization + zstd-22 | ~30% size savings vs int8+zlib | Competition meta | +| 11 layers, MLP 3x | Funded by int6 savings | PR #162 | +| SmearGate | ~0.005 BPB from token blending | PR #135 | +| BigramHash | Token-pair context embeddings | PR #162 | +| Muon weight decay (0.03) | Quant-friendly weight distributions | PR #179 | +| SWA (last 50%) | Smoother weights, better quant | PR #162 | +| Sliding window eval (stride=64) | ~0.03 BPB free improvement | PR #56 | +| Orthogonal init | Better training dynamics | PR #135 | +| FP16 embedding passthrough | Most quant-sensitive tensor | PR #42 | +| Depth recurrence (optional) | Novel - share blocks, widen model | Original | + +## Usage + +Standard mode: +```bash +RUN_ID=standard_v1 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +Depth recurrence mode: +```bash +RUN_ID=depth_recur_v1 \ +DEPTH_RECURRENCE=1 \ +MODEL_DIM=640 \ +NUM_PRELUDE=2 \ +NUM_CODA=2 \ +RECURRENT_LOOPS=7 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Theoretical Motivation + +From MDL (Minimum Description Length) theory: a model with 3 unique layers looped 3 times has lower L(model) than 9 unique layers. The freed description bits go to better L(data|model). Combined with the competitive recipe, this should push BPB below the current frontier. diff --git a/records/track_10min_16mb/2026-03-20_BrnMwai_DepthRecurrence/submission.json b/records/track_10min_16mb/2026-03-20_BrnMwai_DepthRecurrence/submission.json new file mode 100644 index 0000000000..db436d59ea --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_BrnMwai_DepthRecurrence/submission.json @@ -0,0 +1,26 @@ +{ + "author": "Brian Mwai", + "github_id": "brn-mwai", + "val_bpb": null, + "artifact_bytes": null, + "code_bytes": null, + "total_bytes": null, + "gpu_config": "8xH100", + "wallclock_seconds": null, + "training_steps": null, + "techniques": [ + "int6_quantization", + "zstd22_compression", + "11_layers", + "mlp_3x", + "smear_gate", + "bigram_hash", + "muon_weight_decay", + "stochastic_weight_averaging", + "sliding_window_eval", + "orthogonal_init", + "fp16_embedding_passthrough", + "depth_recurrence_optional" + ], + "notes": "Depth-recurrent transformer with full competitive recipe. Optional depth recurrence mode shares blocks to free parameter budget for wider models." +} diff --git a/records/track_10min_16mb/2026-03-20_BrnMwai_DepthRecurrence/train_gpt.py b/records/track_10min_16mb/2026-03-20_BrnMwai_DepthRecurrence/train_gpt.py new file mode 100644 index 0000000000..d4a9e72f6c --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_BrnMwai_DepthRecurrence/train_gpt.py @@ -0,0 +1,1289 @@ +""" +Parameter Golf Submission: Brian Mwai (@brn-mwai) +Depth-Recurrent Transformer with Int6 + SmearGate + BigramHash + SWA + +Key innovations over baseline: +- Int6 per-row quantization + zstd-22 compression (saves ~30% vs int8+zlib) +- 11 layers with 3x MLP expansion (funded by int6 savings) +- SmearGate: learned token-blending gate per layer (~512 params/layer) +- BigramHash: hash-table bigram embeddings (~524K params) +- Muon weight decay for quantization-friendly weight distributions +- Stochastic Weight Averaging (SWA) over last 50% of training +- Sliding window evaluation (stride=64) for ~0.03 BPB free improvement +- Orthogonal initialization + muP-style output scaling +- FP16 embedding passthrough (most quant-sensitive tensor) +- Optional depth recurrence mode (shared blocks looped N times) +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import struct +import subprocess +import sys +import time +import uuid +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Compression backend: prefer zstandard, fallback to zlib +try: + import zstandard + def compress_bytes(data: bytes) -> bytes: + return zstandard.ZstdCompressor(level=22).compress(data) + def decompress_bytes(data: bytes) -> bytes: + return zstandard.ZstdDecompressor().decompress(data) + COMPRESS_TAG = "zstd22" +except ImportError: + import zlib + def compress_bytes(data: bytes) -> bytes: + return zlib.compress(data, level=9) + def decompress_bytes(data: bytes) -> bytes: + return zlib.decompress(data) + COMPRESS_TAG = "zlib9" + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape - competitive defaults + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.03)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # New competitive hyperparameters + muon_weight_decay = float(os.environ.get("MUON_WD", 0.03)) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + swa_interval = int(os.environ.get("SWA_INTERVAL", 200)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + use_smear_gate = bool(int(os.environ.get("USE_SMEAR_GATE", "1"))) + use_bigram_hash = bool(int(os.environ.get("USE_BIGRAM_HASH", "1"))) + bigram_table_size = int(os.environ.get("BIGRAM_TABLE_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # Depth recurrence config + depth_recurrence = bool(int(os.environ.get("DEPTH_RECURRENCE", "0"))) + num_prelude = int(os.environ.get("NUM_PRELUDE", 2)) + num_coda = int(os.environ.get("NUM_CODA", 2)) + recurrent_loops = int(os.environ.get("RECURRENT_LOOPS", 7)) + +# ----------------------------- +# INT6 QUANTIZATION +# ----------------------------- + +INT6_LEVELS = 63 # [-31, 32] mapped to [0, 63] +INT6_CLIP_PERCENTILE = 99.99984 +INT6_CLIP_Q = INT6_CLIP_PERCENTILE / 100.0 +INT6_PER_ROW_SCALE_DTYPE = torch.float16 + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear_gate,bigram", + ).split(",") if p +) +INT6_KEEP_FLOAT_NAME_PATTERNS = CONTROL_TENSOR_NAME_PATTERNS +INT6_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT6_KEEP_FLOAT_STORE_DTYPE = torch.float16 + +# Embedding names to keep as FP16 passthrough (most quantization-sensitive) +EMBED_PASSTHROUGH_PATTERNS = ("tok_emb",) + + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def pack_int6(values: Tensor) -> bytes: + """Pack int6 values [0, 63] into bytes (4 values -> 3 bytes). Vectorized.""" + flat = values.flatten().numpy().astype(np.uint8) + n = len(flat) + # Pad to multiple of 4 + pad = (4 - n % 4) % 4 + if pad: + flat = np.concatenate([flat, np.zeros(pad, dtype=np.uint8)]) + # Vectorized packing: reshape to (N/4, 4) then bitwise ops + grouped = flat.reshape(-1, 4) + a, b, c, d = grouped[:, 0], grouped[:, 1], grouped[:, 2], grouped[:, 3] + b0 = (a << 2) | (b >> 4) + b1 = ((b & 0x0F) << 4) | (c >> 2) + b2 = ((c & 0x03) << 6) | d + result = np.stack([b0, b1, b2], axis=1).flatten().astype(np.uint8) + return struct.pack(" Tensor: + """Unpack bytes back to int6 values [0, 63]. Vectorized.""" + n = struct.unpack("> 2 + v1 = ((r0 & 0x03) << 4) | (r1 >> 4) + v2 = ((r1 & 0x0F) << 2) | (r2 >> 6) + v3 = r2 & 0x3F + result = np.stack([v0, v1, v2, v3], axis=1).flatten().astype(np.uint8) + return torch.from_numpy(result[:n].copy()).to(torch.uint8) + + +def quantize_tensor_int6(t: Tensor) -> tuple[bytes, Tensor]: + """Quantize a float tensor to int6 with per-row or per-tensor scales.""" + t32 = t.float() + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + clipped = torch.clamp(t32, -clip_abs[:, None], clip_abs[:, None]) + scale = (clip_abs / 31.0).clamp_min(1.0 / 31.0) + q = torch.clamp(torch.round(clipped / scale[:, None]) + 32, 0, 63).to(torch.uint8) + return pack_int6(q), scale.to(INT6_PER_ROW_SCALE_DTYPE).contiguous() + else: + clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 31.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale) + 32, 0, 63).to(torch.uint8) + return pack_int6(q), scale + + +def dequantize_tensor_int6(packed: bytes, scale: Tensor, shape: tuple, dtype: torch.dtype) -> Tensor: + """Dequantize int6 packed bytes back to float tensor.""" + q = unpack_int6(packed).float() - 32.0 # Back to [-32, 31] + q = q.reshape(shape) + if scale.ndim > 0: + return (q * scale.float().view(shape[0], *([1] * (len(shape) - 1)))).to(dtype) + else: + return (q * scale.float().item()).to(dtype) + + +def quantize_state_dict_int6(state_dict: dict[str, Tensor]): + """Quantize model state dict using int6 with FP16 embedding passthrough.""" + quantized_packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + shapes: dict[str, tuple] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = {"param_count": 0, "num_tensors": 0} + + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + + # FP16 passthrough for embeddings + if any(p in name for p in EMBED_PASSTHROUGH_PATTERNS): + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + passthrough[name] = t.to(INT6_KEEP_FLOAT_STORE_DTYPE).contiguous() + else: + passthrough[name] = t + continue + + if not t.is_floating_point(): + passthrough[name] = t + continue + + # Small/control tensors kept as fp16 + if t.numel() <= INT6_KEEP_FLOAT_MAX_NUMEL or any(p in name for p in INT6_KEEP_FLOAT_NAME_PATTERNS): + if any(p in name for p in INT6_KEEP_FLOAT_NAME_PATTERNS): + passthrough[name] = t.to(INT6_KEEP_FLOAT_STORE_DTYPE).contiguous() + elif t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + passthrough[name] = t.to(INT6_KEEP_FLOAT_STORE_DTYPE).contiguous() + else: + passthrough[name] = t + continue + + # Int6 quantization for large float tensors + packed, s = quantize_tensor_int6(t) + quantized_packed[name] = packed + scales[name] = s + shapes[name] = tuple(t.shape) + dtypes[name] = str(t.dtype).removeprefix("torch.") + + obj = { + "__quant_format__": "int6_per_row_v1", + "quantized_packed": quantized_packed, + "scales": scales, + "shapes": shapes, + "dtypes": dtypes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int6(obj: dict) -> dict[str, Tensor]: + """Reconstruct state dict from int6 quantized format.""" + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + + for name in obj["quantized_packed"]: + packed = obj["quantized_packed"][name] + scale = obj["scales"][name] + shape = obj["shapes"][name] + dtype = getattr(torch, obj["dtypes"][name]) + out[name] = dequantize_tensor_int6(packed, scale, shape, dtype).contiguous() + + for name, t in obj["passthrough"].items(): + out_t = t.detach().cpu().contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + + return out + + +# ----------------------------- +# MUON OPTIMIZER (with weight decay) +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + weight_decay = group["weight_decay"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + # Decoupled weight decay + if weight_decay > 0: + p.mul_(1.0 - lr * weight_decay) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Sliding window evaluation for better BPB.""" + stride = args.eval_stride + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Sliding window: each window is seq_len tokens, stride forward by `stride` + # Only score the last `stride` tokens (or all if stride >= seq_len) + score_len = min(stride, seq_len) + window_starts = list(range(0, total_tokens - seq_len, stride)) + + # Distribute windows across ranks + my_starts = window_starts[rank::world_size] + + model.eval() + with torch.inference_mode(): + for ws in my_starts: + x = val_tokens[ws : ws + seq_len].to(device=device, dtype=torch.int64).unsqueeze(0) + y = val_tokens[ws + 1 : ws + seq_len + 1].to(device=device, dtype=torch.int64).unsqueeze(0) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + # Get per-token loss for the scored positions + logits = model.module.forward_logits(x) if hasattr(model, 'module') else model.forward_logits(x) + # Only score last `score_len` tokens + logits_scored = logits[:, -score_len:, :].reshape(-1, logits.size(-1)) + targets_scored = y[:, -score_len:].reshape(-1) + loss = F.cross_entropy(logits_scored.float(), targets_scored, reduction="sum") + + val_loss_sum += loss.to(torch.float64) + val_token_count += score_len + + # Byte counting for BPB + prev_ids = x[:, -score_len:].reshape(-1) + tgt_ids = y[:, -score_len:].reshape(-1) + # Handle the prev token for BPB byte boundary check + if score_len < seq_len: + prev_ids_for_boundary = x[0, seq_len - score_len - 1 : seq_len - 1] + else: + prev_ids_for_boundary = F.pad(x[0, :-1], (1, 0)) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids_for_boundary]).to(torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# Fallback standard eval (non-sliding, for speed during training) +def eval_val_standard( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError("VAL_BATCH_SIZE too small") + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class SmearGate(nn.Module): + """Learned per-dimension gate blending current token with previous token.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.full((dim,), -2.0)) # Init sigmoid(-2) ~ 0.12 + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate).to(x.dtype)[None, None, :] + x_prev = F.pad(x[:, :-1, :], (0, 0, 1, 0)) + return (1 - g) * x + g * x_prev + + +class BigramHash(nn.Module): + """Hash-table bigram embeddings for token-pair context.""" + def __init__(self, table_size: int, embed_dim: int, model_dim: int): + super().__init__() + self.table = nn.Embedding(table_size, embed_dim) + self.proj = CastedLinear(embed_dim, model_dim, bias=False) if embed_dim != model_dim else None + self.table_size = table_size + nn.init.normal_(self.table.weight, std=0.01) + + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0)) + h = (prev.long() * 2654435761 + input_ids.long()) % self.table_size + x = self.table(h) + if self.proj is not None: + x = self.proj(x) + return x + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Repeat KV heads for GQA compatibility (works on all PyTorch versions) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(rep, dim=1) + v = v.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + mlp_mult: int, rope_base: float, qk_gain_init: float, + use_smear_gate: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.smear = SmearGate(dim) if use_smear_gate else None + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + if self.smear is not None: + x = self.smear(x) + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + use_smear_gate: bool = False, + use_bigram_hash: bool = False, + bigram_table_size: int = 4096, + bigram_dim: int = 128, + depth_recurrence: bool = False, + num_prelude: int = 2, + num_coda: int = 2, + recurrent_loops: int = 7, + ): + super().__init__() + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.depth_recurrence = depth_recurrence + self.tok_emb = nn.Embedding(vocab_size, model_dim) + + # BigramHash + self.bigram_hash = BigramHash(bigram_table_size, bigram_dim, model_dim) if use_bigram_hash else None + + if depth_recurrence: + # Depth recurrence: prelude + shared (looped) + coda + self.num_prelude = num_prelude + self.num_coda = num_coda + self.recurrent_loops = recurrent_loops + num_unique = num_prelude + 1 + num_coda # 1 shared block + effective_depth = num_prelude + recurrent_loops + num_coda + self.num_encoder_layers = effective_depth // 2 + self.num_decoder_layers = effective_depth - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Create unique blocks + self.prelude_blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, use_smear_gate) + for _ in range(num_prelude) + ]) + self.shared_block = Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, use_smear_gate) + self.coda_blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, use_smear_gate) + for _ in range(num_coda) + ]) + # Iteration embeddings for the shared block + self.iter_embed = nn.Parameter(torch.randn(recurrent_loops, model_dim) * 0.01) + self.blocks = None # Not used in depth recurrence mode + else: + # Standard unique layers + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, use_smear_gate) + for _ in range(num_layers) + ]) + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + + # Orthogonal init for linear layers, zero init for output projections + for module in self.modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + else: + nn.init.orthogonal_(module.weight) + + def _get_effective_blocks(self) -> list: + """Return the sequence of (block, is_shared, iter_idx) for forward pass.""" + if not self.depth_recurrence: + return [(b, False, -1) for b in self.blocks] + result = [] + for b in self.prelude_blocks: + result.append((b, False, -1)) + for i in range(self.recurrent_loops): + result.append((self.shared_block, True, i)) + for b in self.coda_blocks: + result.append((b, False, -1)) + return result + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Forward pass returning logits (for sliding window eval).""" + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + blocks_info = self._get_effective_blocks() + effective_depth = len(blocks_info) + enc_layers = effective_depth // 2 + dec_layers = effective_depth - enc_layers + + for i in range(enc_layers): + block, is_shared, iter_idx = blocks_info[i] + if is_shared and iter_idx >= 0: + x = x + self.iter_embed[iter_idx].to(dtype=x.dtype)[None, None, :] + x = block(x, x0) + skips.append(x) + + skip_idx = 0 + for i in range(dec_layers): + if skips and skip_idx < self.num_skip_weights: + x = x + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + skip_idx += 1 + block, is_shared, iter_idx = blocks_info[enc_layers + i] + if is_shared and iter_idx >= 0: + x = x + self.iter_embed[iter_idx].to(dtype=x.dtype)[None, None, :] + x = block(x, x0) + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # Distributed + CUDA setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + + # Tokenizer + validation setup + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + dataset_dir = Path(args.data_path).resolve() + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece") + + # Model setup + model_kwargs = dict( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + use_smear_gate=args.use_smear_gate, + use_bigram_hash=args.use_bigram_hash, + bigram_table_size=args.bigram_table_size, + bigram_dim=args.bigram_dim, + depth_recurrence=args.depth_recurrence, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + recurrent_loops=args.recurrent_loops, + ) + base_model = GPT(**model_kwargs).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer setup + if base_model.blocks is not None: + block_named_params = list(base_model.blocks.named_parameters()) + else: + # Depth recurrence mode + block_named_params = ( + list(base_model.prelude_blocks.named_parameters()) + + list(base_model.shared_block.named_parameters()) + + list(base_model.coda_blocks.named_parameters()) + ) + + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if hasattr(base_model, 'iter_embed'): + scalar_params.append(base_model.iter_embed) + + # BigramHash params go to Adam + bigram_params = [] + if base_model.bigram_hash is not None: + bigram_params = list(base_model.bigram_hash.parameters()) + # Remove from matrix/scalar if they ended up there + bigram_param_ids = {id(p) for p in bigram_params} + matrix_params = [p for p in matrix_params if id(p) not in bigram_param_ids] + scalar_params = [p for p in scalar_params if id(p) not in bigram_param_ids] + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, + ) + optimizer_muon = Muon( + matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=args.muon_weight_decay, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, + ) + optimizers.insert(1, optimizer_head) + + if bigram_params: + optimizer_bigram = torch.optim.Adam( + [{"params": bigram_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, + ) + optimizers.append(optimizer_bigram) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"depth_recurrence:{args.depth_recurrence} smear_gate:{args.use_smear_gate} bigram_hash:{args.use_bigram_hash}") + log0(f"num_layers:{args.num_layers} model_dim:{args.model_dim} mlp_mult:{args.mlp_mult}") + log0(f"muon_wd:{args.muon_weight_decay} swa_start_frac:{args.swa_start_frac} eval_stride:{args.eval_stride}") + log0(f"seed:{args.seed}") + + # Data loader & model warmup + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup + if args.warmup_steps > 0: + initial_model_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # SWA state + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Main training loop + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + # During training, use fast standard eval; final eval uses sliding window + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val_standard( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + + # SWA collection + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + ms_per_step = approx_training_time_ms / max(step, 1) + if max_wallclock_ms: + estimated_total_steps = max_wallclock_ms / max(ms_per_step, 1.0) + else: + estimated_total_steps = float(args.iterations) + swa_start_step = int(estimated_total_steps * args.swa_start_frac) + if step >= swa_start_step and step % args.swa_interval == 0: + current_state = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()} + if swa_state is None: + swa_state = current_state + swa_count = 1 + else: + swa_count += 1 + for k in swa_state: + swa_state[k].lerp_(current_state[k], 1.0 / swa_count) + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms") + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + # Apply SWA weights if collected + if swa_state is not None and swa_count > 0: + log0(f"swa: applying averaged weights from {swa_count} checkpoints") + base_model.load_state_dict(swa_state, strict=True) + + # Pre-quantization eval with sliding window + if master_process: + log0("running sliding window eval (pre-quant)...") + pre_q_loss, pre_q_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0(f"pre_quant_sliding val_loss:{pre_q_loss:.4f} val_bpb:{pre_q_bpb:.4f}") + + # Serialization + roundtrip validation + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + log0(f"Raw model: {os.path.getsize('final_model.pt')} bytes") + + quant_obj, quant_stats = quantize_state_dict_int6(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_bytes(quant_raw) + if master_process: + artifact_name = f"final_model.int6.pt{COMPRESS_TAG}" + with open(artifact_name, "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + total_bytes = quant_file_bytes + code_bytes + log0(f"Int6+{COMPRESS_TAG}: model={quant_file_bytes} code={code_bytes} total={total_bytes}") + log0(f"size_check: {'PASS' if total_bytes <= 16_000_000 else 'FAIL'} ({total_bytes}/16000000)") + + # Roundtrip validation + if distributed: + dist.barrier() + artifact_name = f"final_model.int6.pt{COMPRESS_TAG}" + with open(artifact_name, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_bytes(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int6(quant_state), strict=True) + + torch.cuda.synchronize() + t_qeval = time.perf_counter() + + # Final sliding window eval after quantization roundtrip + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0(f"final_int6_{COMPRESS_TAG}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms") + log0(f"final_int6_{COMPRESS_TAG}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()