diff --git a/.gitignore b/.gitignore
index 3423c416a7..f399f386bf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,28 @@ data/manifest.json
 data/docs_selected.jsonl
 .mypy_cache/
 .venv
-logs/
\ No newline at end of file
+logs/
+final_model*
+*.ptz
+*.ptzst
+repo_dump.txt
+flash4.pdf
+
+# Local tooling (not part of submission)
+.claude/
+.mcp.json
+CLAUDE.md
+FINDINGS.md
+modal_final_run.py
+modal_smoke_test.py
+modal_sweep.py
+smoke_test.py
+test_*.py
+sweep_engine
+sweep_engine.go
+sweep_engine.py
+sweep_engine_linux
+sweep_v5.py
+sweep_configs.py
+sweep_results.json
+dump_repo.py
\ No newline at end of file
diff --git a/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/README.md b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/README.md
new file mode 100644
index 0000000000..48021d5157
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/README.md
@@ -0,0 +1,87 @@
+# 11L + Hadamard Rotation + VE128 + cuDNN SDPA (val_bpb: 1.1365)
+
+## Key Innovation: Hadamard Rotation for Int6 Quantization
+
+A Walsh-Hadamard rotation is applied to weight matrices before int6 per-row quantization. The orthogonal rotation spreads outlier values uniformly across all dimensions, improving zstd compressibility from 1.70x to 1.78x (1.77x in the final VE128 configuration) and reducing the quantization gap from 0.0093 to 0.0084 BPB. This is the first application of rotation-based quantization (QuIP-family) in this competition. No other open or merged PR uses this technique.
+
+The compression improvement frees 668KB of headroom within the 16MB artifact budget (44KB to 712KB), directly enabling the addition of Shared Value Embeddings (VE128), which did not fit in the earlier 44KB of headroom; the final configuration retains 530KB.
+
+Negative result: Full GPTQ (Hessian-calibrated quantization) provides zero additional benefit when combined with Hadamard rotation. The rotation already makes weight distributions sufficiently uniform for simple abs-max quantization. This confirms that Hadamard rotation and GPTQ are substitutes, not complements, at int6 precision.
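+
+A minimal sketch of the transform (illustrative; it mirrors `_hadamard_matrix` and `quantize_tensor_int6` in the bundled `train_gpt.py`, which picks the block size as the largest power of two dividing the column count). The normalized Sylvester Hadamard matrix is symmetric and orthogonal (H = H^T = H^-1), so dequantization simply applies the same rotation again:
+
+```python
+import math
+import torch
+
+def hadamard_matrix(n: int) -> torch.Tensor:
+    # Sylvester construction, normalized so that H @ H.T == I (n a power of two).
+    h = torch.tensor([[1.0]])
+    while h.size(0) < n:
+        h = torch.cat([torch.cat([h, h], 1), torch.cat([h, -h], 1)], 0) / math.sqrt(2)
+    return h
+
+def rotate_quantize_int6(w: torch.Tensor, block: int = 64):
+    # Block-diagonal Hadamard rotation over columns, then per-row abs-max int6.
+    rows, cols = w.shape
+    assert cols % block == 0
+    H = hadamard_matrix(block)
+    w_rot = (w.reshape(rows, -1, block) @ H).reshape(rows, cols)
+    scale = (w_rot.abs().amax(dim=1) / 31.0).clamp_min(1.0 / 31.0)  # int6 range [-32, 31]
+    q = torch.clamp(torch.round(w_rot / scale[:, None]), -32, 31).to(torch.int8)
+    return q, scale
+
+def dequantize_int6(q: torch.Tensor, scale: torch.Tensor, block: int = 64) -> torch.Tensor:
+    # Undo the per-row scaling, then rotate again: H is its own inverse.
+    rows, cols = q.shape
+    w_rot = q.float() * scale[:, None]
+    return (w_rot.reshape(rows, -1, block) @ hadamard_matrix(block)).reshape(rows, cols)
+```
+
+The rotation mixes every block of 64 columns, so a single outlier weight is smeared across its block; per-row abs-max scales shrink, the int6 codes occupy the range more evenly, and the resulting byte stream compresses better under zstd.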
+ +## Architecture + +- 11 transformer layers, 512-dim, 8 heads (4 KV heads, GQA) +- 3x MLP expansion with relu-squared activation +- Exclusive Self-Attention (XSA) on last 4 layers (GQA-aware) +- Partial RoPE (16/64 dims) +- Layer-norm scale factor 1/sqrt(layer_idx+1) +- U-Net skip connections (5 encoder, 6 decoder) +- SmearGate + BigramHash (2048 buckets, inner_dim=128) +- Shared Value Embedding (dim=128, layers 9 and 10, per-layer learned scales) +- cuDNN SDPA backend (FlashAttention 3 conditional fallback) +- Logit softcap 30.0, tied embeddings +- Orthogonal init with projection scaling by 1/sqrt(2*num_layers) + +## Training + +- Muon optimizer (matrix params): lr=0.025, momentum=0.99 (warmup 0.92 to 0.99 over 1500 steps), WD=0.04 +- AdamW (embeddings): lr=0.035, (scalars): lr=0.025, WD=0.04 +- Gradient clip: 0.3 +- Batch: 524,288 tokens/step, seq_len=2048 +- Warmdown: 3500 iterations (wall-clock based, cosine schedule) +- EMA decay=0.997 (continuous, every step) + +## Quantization + +- Int6 per-row with Hadamard rotation (block-diagonal Walsh-Hadamard on column dimension) +- Abs-max scaling (no percentile clipping, no GPTQ) +- Control tensors (scales, gates, VE scales) in FP16 +- zstd level 22 compression + +## Ablation + +| Config | Sliding BPB | Compression | Headroom | Quant Gap | +|--------|-------------|-------------|----------|-----------| +| Baseline (S5-7, no Hadamard, no VE) | 1.1372 | 1.70x | 44KB | 0.0093 | +| + Hadamard rotation | 1.1377 | 1.78x | 712KB | 0.0091 | +| + VE128 (enabled by headroom) | **1.1364** | 1.77x | 530KB | 0.0084 | +| + GPTQ on top of Hadamard | 1.1401+ | 1.73x | 201KB | 0.0088 | + +Hadamard rotation enables VE128 by freeing 668KB of artifact headroom. GPTQ adds no value when Hadamard is present. + +## Methodology: CPU Parameter Probe + +Hyperparameter selection was guided by a CPU-based parameter sweep engine (Go, 80-core Modal) that estimates sliding BPB from architecture configuration without GPU training. The probe uses three independent estimation approaches (K-Nearest Neighbor, bounded parametric, relative delta) calibrated against 15 GPU training runs across sessions 2-7. + +This is a directional methodology, not a high-precision predictor. The probe narrows the search space by eliminating configurations that are unlikely to fit artifact constraints or improve BPB, reducing the number of expensive GPU runs needed. In our workflow, it guided the path from 9L to 11L MLP1536, correctly flagged FA3 compression overflow and 12L artifact limits, and identified the embed_lr=0.035 configuration that produced our best result -- all confirmed by subsequent GPU training. + +The approach is exploratory and specific to our calibration data. It demonstrates how lightweight CPU-based pre-screening can complement GPU experimentation when compute is constrained, and could be adapted to other search spaces with different calibration sets. + +## Additional Findings + +- **Late QAT was dead code in all prior work**: The `CastedLinear.qat_enabled` class attribute change was shadowed by instance attributes set during init. QAT-STE never activated during training in any session. Removing the dead QAT guard from CastedLinear.forward() eliminated a torch.compile guard, reducing step time from 70.6ms to 65.7ms (7% throughput gain). +- **FA3 compression penalty**: FlashAttention 3 produces weight distributions that compress 1.8% worse with zstd-22 (1.67x vs 1.70x cuDNN). The throughput gain (65ms vs 70.6ms) does not compensate at MLP1536 due to artifact overflow. 
+- **FP16 control params**: Storing scale/gate tensors as FP16 instead of FP32 is lossless for eval (values are cast to bfloat16 in forward). Saves 50KB raw payload. +- **INT6 bigram projection**: Quantizing the 128x512 BigramHash projection to int6 (vs FP16 passthrough) improves zstd compression by 0.09x. The quantization noise is negligible for this small embedding projection. + +## Results (3 seeds) + +| Seed | Steps | Pre-quant | Sliding BPB | Artifact Bytes | Compression | +|------|-------|-----------|-------------|----------------|-------------| +| 1337 | 8098 | 1.1512 | 1.1364 | 15,618,718 | 1.75x | +| 42 | 8102 | 1.1513 | 1.1361 | 15,629,540 | 1.75x | +| 2024 | 7960 | 1.1521 | 1.1370 | 15,600,361 | 1.76x | + +**Mean: 1.1365 +/- 0.0005 BPB.** All artifacts under 16MB. 27,038,810 parameters. + +## Run + +```bash +NUM_LAYERS=11 MLP_MULT=3 XSA_LAYERS=4 ROPE_DIMS=16 LN_SCALE=1 \ +VE_DIM=128 VE_LAYERS=9,10 EMA_DECAY=0.997 \ +MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \ +WARMDOWN_ITERS=3500 USE_CUDNN_SDPA=1 \ +torchrun --nproc_per_node=8 train_gpt.py +``` + +Erick Aleman | EA Cognitive | www.eacognitive.com | github.com/eacognitive diff --git a/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/submission.json b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/submission.json new file mode 100644 index 0000000000..e16c42641e --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/submission.json @@ -0,0 +1,18 @@ +{ + "name": "11L + Hadamard Rotation + VE128 + cuDNN SDPA", + "author": "Erick Aleman", + "github_id": "eacognitive", + "date": "2026-03-24", + "val_bpb": 1.1365, + "pre_quant_val_bpb": 1.1515, + "bytes_total": 15629540, + "bytes_code": 52010, + "seed_results": { + "1337": {"val_bpb": 1.1364, "steps": 8098, "artifact_bytes": 15618718}, + "42": {"val_bpb": 1.1361, "steps": 8102, "artifact_bytes": 15629540}, + "2024": {"val_bpb": 1.1370, "steps": 7960, "artifact_bytes": 15600361} + }, + "mean_bpb": 1.1365, + "std_bpb": 0.0005, + "blurb": "First application of Walsh-Hadamard rotation for int6 quantization in parameter-constrained LLMs. Improves zstd compression from 1.70x to 1.76x, recovering 530KB of artifact headroom that enables Shared Value Embeddings (VE128). Demonstrates that Hadamard rotation and GPTQ are substitutes at int6 precision." +} diff --git a/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_gpt.py b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_gpt.py new file mode 100644 index 0000000000..cace894b63 --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_gpt.py @@ -0,0 +1,1365 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
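+
+Artifact round trip (sketch; `quantize_state_dict_int6`, `dequantize_state_dict_int6`,
+`_compress`, and `_decompress` are defined in this file, and `model` stands for any
+trained GPT instance):
+
+    obj, _ = quantize_state_dict_int6(model.state_dict())   # int6 + Hadamard rotation
+    buf = io.BytesIO(); torch.save(obj, buf)
+    blob = _compress(buf.getvalue())                         # zstd-22 if available, else zlib-9
+    restored = dequantize_state_dict_int6(
+        torch.load(io.BytesIO(_decompress(blob)), map_location="cpu"))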
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import uuid +import zlib +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard + def _compress(data: bytes) -> bytes: + return zstandard.ZstdCompressor(level=22).compress(data) + def _decompress(data: bytes) -> bytes: + return zstandard.ZstdDecompressor().decompress(data) + _COMPRESS_EXT = "ptzst" +except ImportError: + def _compress(data: bytes) -> bytes: + return zlib.compress(data, level=9) + def _decompress(data: bytes) -> bytes: + return zlib.decompress(data) + _COMPRESS_EXT = "ptz" + +try: + from flash_attn_interface import flash_attn_func as _fa3_func +except ImportError: + _fa3_func = None + + +class Hyperparameters: + data_path = os.environ.get( + "DATA_PATH", "./data/datasets/fineweb10B_sp1024" + ) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get( + "TOKENIZER_PATH", + "./data/tokenizers/fineweb_1024_bpe.model", + ) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + mu_mom = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + mu_mom_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + mu_mom_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + adam_weight_decay = float(os.environ.get("ADAM_WEIGHT_DECAY", 0.04)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + + xsa_layers = 
int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + cosine_warmdown = bool(int(os.environ.get("COSINE_WARMDOWN", "1"))) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + best_known_bpb = float(os.environ.get("BEST_KNOWN_BPB", 1.1520)) + regression_threshold = float(os.environ.get("REGRESSION_THRESHOLD", 0.005)) + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params: list[nn.Parameter], + lr: float, + momentum: float, + backend_steps: int, + weight_decay: float = 0.0, + ) -> None: + if ( + isinstance(params, list) + and len(params) > 0 + and isinstance(params[0], nn.Parameter) + ): + params = sorted( + params, key=lambda x: x.numel(), reverse=True + ) + defaults = dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + weight_decay=weight_decay, + ) + super().__init__(params, defaults) + + @torch.no_grad() + def step( + self, closure: Callable[[], Tensor] | None = None, + ) -> Tensor | None: + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + + pad_count = (-len(params)) % world_size + params_pad = params + [torch.empty_like(params[-1]) for _ in range(pad_count)] + for base_i in range(0, len(params_pad), world_size): + if base_i + rank < len(params): + p = params[base_i + rank] + if p.grad is None: + p.grad = torch.zeros_like(p) + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(p) + buf = state["momentum_buffer"] + buf.lerp_(p.grad, 1 - momentum) + wd = group.get("weight_decay", 0) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + update = zeropower_via_newtonschulz5(buf, steps=backend_steps) + update = update.to(p.dtype) + update *= max(1, p.size(-2) / p.size(-1)) ** 0.5 + p.add_(update, alpha=-lr) + if distributed: + chunk = params_pad[base_i:base_i + world_size] + dist.all_gather(chunk, params_pad[base_i + rank]) + + return loss + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + bytes_np = np.zeros((table_size,), dtype=np.int16) + space_np = np.zeros((table_size,), dtype=np.bool_) + boundary_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + boundary_np[token_id] = False + if sp.is_byte(token_id): + bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + space_np[token_id] = True + piece = piece[1:] + bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(bytes_np, dtype=torch.int16, device=device), + 
torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(boundary_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + ga_steps: int, + val_tokens: Tensor, + bytes_lut: Tensor, + space_lut: Tensor, + boundary_lut: Tensor, +) -> tuple[float, float]: + lbt = args.val_batch_size // (world_size * ga_steps) + if lbt < args.train_seq_len: + raise ValueError("VAL_BATCH_SIZE too small for world_size/seq_len") + lbs = lbt // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for bss in range(seq_start, seq_end, lbs): + bse = min(bss + lbs, seq_end) + raw_start = bss * args.train_seq_len + raw_end = bse * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True, + ) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (space_lut[tgt_ids] & ~boundary_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +CTRL_PATTERNS = tuple( + p for p in os.environ.get( + "CTRL_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales," + "resid_mix,resid_mixes,q_gain,skip_weight," + "skip_weights,smear_gate,ve_layer_scale", + ).split(",") if p +) +SMALL_NUMEL = 65_535 + + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def _hadamard_matrix(n: int) -> Tensor: + h = torch.tensor([[1.0]]) + while h.size(0) < n: + h = torch.cat([torch.cat([h, h], 1), torch.cat([h, -h], 1)], 0) / math.sqrt(2) + return h + + +def _hadamard_rotate(t: Tensor) -> Tensor: + rows, cols = t.shape + block = 1 + while block * 2 <= cols and cols % (block * 2) == 0: + block *= 2 + if block < 2: + return t + H = 
_hadamard_matrix(block).to(t.device, t.dtype) + return (t.reshape(rows, -1, block) @ H).reshape(rows, cols) + + +def quantize_tensor_int6(t: Tensor, bits: int = 6) -> tuple[Tensor, Tensor]: + max_val = 2 ** (bits - 1) - 1 + min_val = -(max_val + 1) + t32 = t.float() + if t32.ndim == 2: + t32 = _hadamard_rotate(t32) + row_max = t32.abs().amax(dim=1) + scale = (row_max / max_val).clamp_min(1.0 / max_val).to(torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), min_val, max_val).to(torch.int8) + return q.contiguous(), scale.contiguous() + abs_max = t32.abs().max().clamp_min(1e-12).item() + scale = torch.tensor(abs_max / max_val, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), min_val, max_val).to(torch.int8) + return q.contiguous(), scale.contiguous() + + +def quantize_state_dict_int6( + state_dict: dict[str, Tensor], +) -> tuple[dict[str, dict[str, Tensor] | dict[str, str]], dict[str, int]]: + w: dict[str, Tensor] = {} + m: dict[str, str] = {} + stats = dict.fromkeys( + ( + "param_count", "num_tensors", "num_int6_tensors", + "num_passthrough_tensors", + "baseline_tensor_bytes", "int6_payload_bytes", + ), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_passthrough_tensors"] += 1 + w[name] = t + m[name] = "raw" + stats["int6_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= SMALL_NUMEL or t.ndim < 2: + kept = t.to(dtype=torch.float16).contiguous() + w[name] = kept + m[name] = str(t.dtype).removeprefix("torch.") + stats["num_passthrough_tensors"] += 1 + stats["int6_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_int6_tensors"] += 1 + q, s = quantize_tensor_int6(t) + w[name + ".q"] = q + w[name + ".s"] = s + m[name] = str(t.dtype).removeprefix("torch.") + nbytes = tensor_nbytes(q) + tensor_nbytes(s) + stats["int6_payload_bytes"] += nbytes + return {"w": w, "m": m}, stats + + +def dequantize_state_dict_int6( + obj: dict[str, object], +) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + w = obj["w"] + m = obj["m"] + for name, dtype_str in m.items(): + if name + ".q" in w: + q = w[name + ".q"] + s = w[name + ".s"] + dtype = getattr(torch, dtype_str) + if s.ndim > 0: + sf = s.to(dtype=torch.float32) + expand = (q.shape[0], *([1] * (q.ndim - 1))) + dq = q.float() * sf.view(*expand) + if dq.ndim == 2: + dq = _hadamard_rotate(dq) + out[name] = dq.to(dtype=dtype).contiguous() + else: + sv = float(s.item()) + out[name] = (q.float() * sv).to(dtype).contiguous() + else: + out[name] = w[name].detach().to("cpu").contiguous() + return out + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining 
-= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__( + self, pattern: str, rank: int, world_size: int, device: torch.device, + ) -> None: + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch( + self, global_tokens: int, seq_len: int, ga_steps: int, + ) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * ga_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None) -> None: + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, input: Tensor) -> Tensor: + w = self.weight.to(input.dtype) + return F.linear(input, w, self.bias.to(input.dtype) if self.bias is not None else None) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + is_ctrl = any(p in name for p in CTRL_PATTERNS) + if (param.ndim < 2 or is_ctrl) and param.dtype != torch.float32: + param.data = param.data.float() + + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int) -> None: + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + self.proj = CastedLinear(ve_dim, kv_dim, bias=False) + nn.init.normal_(self.embed.weight, std=0.005) + + def forward(self, input_ids: Tensor) -> Tensor: + return self.proj(self.embed(input_ids)) + + +class BigramHash(nn.Module): + def __init__( + self, num_buckets: int, model_dim: int, inner_dim: int = 128, + ) -> None: + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, inner_dim) + self.proj = CastedLinear(inner_dim, model_dim, bias=False) + nn.init.normal_(self.emb.weight, std=0.005) + + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + idx = (prev * 36313 + input_ids * 27191) % self.num_buckets + return self.proj(self.emb(idx)) + + +class Rotary(nn.Module): + def __init__( + self, dim: int, base: float = 10000.0, rope_dims: int = 0, + ) -> None: + super().__init__() + rd = rope_dims if rope_dims > 0 else dim + freqs = torch.arange(0, rd, 2, dtype=torch.float32) / rd + inv_freq = 1.0 / (base ** freqs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.seq_len_cached = 0 + self.cos_cached: Tensor | None = None + self.sin_cached: Tensor | None = None + + def forward( + self, seq_len: int, device: torch.device, dtype: torch.dtype, + ) -> tuple[Tensor, Tensor]: + if ( + self.cos_cached is None + or self.sin_cached is None + or self.seq_len_cached != seq_len + or self.cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self.cos_cached = freqs.cos()[None, None, :, :] + self.sin_cached = freqs.sin()[None, None, :, :] + self.seq_len_cached = seq_len + return self.cos_cached.to(dtype=dtype), self.sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: 
Tensor) -> Tensor: + rd = cos.size(-1) * 2 + x_rope = x[..., :rd] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x[..., rd:]), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + use_xsa: bool = False, + rope_dims: int = 0, + ) -> None: + super().__init__() + assert dim % num_heads == 0 and num_heads % num_kv_heads == 0 + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + self._use_xsa = use_xsa + assert self.head_dim % 2 == 0 + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj.zero_init = True + self.proj.output_proj = True + self.q_gain = nn.Parameter(torch.full( + (num_heads,), qk_gain_init, dtype=torch.float32, + )) + self.rotary = Rotary( + self.head_dim, base=rope_base, rope_dims=rope_dims, + ) + self.attn_gate = CastedLinear(dim, num_heads, bias=False) + + def forward( + self, x: Tensor, ve_out: Tensor | None = None, + ) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + if ve_out is not None: + v = v + ve_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + v_orig = v + if _fa3_func is not None: + qt = q.transpose(1, 2) + kt = k.transpose(1, 2) + vt = v.transpose(1, 2) + y = _fa3_func(qt, kt, vt, causal=True).transpose(1, 2) + else: + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(rep, dim=1) + v = v.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + gate = torch.sigmoid(self.attn_gate(x)).transpose(1, 2).unsqueeze(-1) + y = y * gate + if self._use_xsa: + y = y.transpose(1, 2) # [B, S, H, D] + v_t = v_orig.transpose(1, 2) # [B, S, Hkv, D] + group_size = self.num_heads // self.num_kv_heads + y_grouped = y.reshape(bsz, seqlen, self.num_kv_heads, group_size, self.head_dim) + vn = F.normalize(v_t, dim=-1).unsqueeze(3) # [B, S, Hkv, 1, D] + dot = (y_grouped * vn).sum(-1, keepdim=True) + y = (y_grouped - dot * vn).reshape(bsz, seqlen, -1) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__( + self, dim: int, mlp_mult: int, mlp_hidden: int = 0, + ) -> None: + super().__init__() + hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj.zero_init = True + self.proj.output_proj = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + 
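+        # GQA: num_kv_heads must divide num_heads; each K/V head serves
+        # num_heads // num_kv_heads query heads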
num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_xsa: bool = False, + mlp_hidden: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + rope_dims: int = 0, + ) -> None: + super().__init__() + self.ln_scale_factor = ( + 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + ) + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, use_xsa=use_xsa, rope_dims=rope_dims, + ) + self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) + self.attn_scale = nn.Parameter( + torch.ones(dim, dtype=torch.float32), + ) + self.mlp_scale = nn.Parameter( + torch.ones(dim, dtype=torch.float32), + ) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float(), + ) + + def forward( + self, x: Tensor, x0: Tensor, ve_out: Tensor | None = None, + ) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s, ve_out=ve_out) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ) -> None: + super().__init__() + assert logit_softcap > 0.0 + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_layers_actual = num_layers + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHash(2048, model_dim, inner_dim=128) + self.smear_gate = nn.Parameter(torch.zeros(model_dim)) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + n_skip = min(self.num_encoder_layers, self.num_decoder_layers) + self.num_skip_weights = n_skip + self.skip_weights = nn.Parameter( + torch.ones(n_skip, model_dim, dtype=torch.float32), + ) + kv_dim = num_kv_heads * (model_dim // num_heads) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_str = os.environ.get("VE_LAYERS", "9,10") + self.ve_layer_set = set( + int(x) for x in ve_str.split(",") if x + ) + if self.ve_layer_set: + self.ve_shared = ValueEmbedding( + vocab_size, ve_dim, kv_dim, + ) + self.ve_layer_scales = nn.ParameterList([ + nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + for _ in self.ve_layer_set + ]) + n_xsa = int(os.environ.get("XSA_LAYERS", 3)) + xsa_set = ( + set(range(num_layers - n_xsa, num_layers)) + if n_xsa > 0 else set() + ) + mlp_h = int(os.environ.get("MLP_HIDDEN", 0)) + ln_sc = bool(int(os.environ.get("LN_SCALE", "1"))) + rope_d = int(os.environ.get("ROPE_DIMS", 16)) + self.blocks = nn.ModuleList([ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, + rope_base, qk_gain_init, + use_xsa=(i in xsa_set), + mlp_hidden=mlp_h, layer_idx=i, + ln_scale=ln_sc, rope_dims=rope_d, + ) + for i in range(num_layers) + ]) + self.final_norm = RMSNorm() + self._init_weights() + + def _init_weights(self) -> None: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear): + if getattr(module, "zero_init", False): + nn.init.zeros_(module.weight) + else: + nn.init.orthogonal_(module.weight, gain=1.0) + if getattr(module, 
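+                    # "output_proj" flags attention/MLP output projections;
+                    # their orthogonal init is scaled by 1/sqrt(2*num_layers) below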
"output_proj", False): + with torch.no_grad(): + module.weight.mul_(1.0 / (2 * self.num_layers_actual) ** 0.5) + + def _ve_for_layer( + self, + layer: int, + ve_base: Tensor | None, + ve_sorted: list[int], + dtype: torch.dtype, + ) -> Tensor | None: + if ve_base is None or layer not in self.ve_layer_set: + return None + idx = ve_sorted.index(layer) + scale = self.ve_layer_scales[idx].to(dtype=dtype) + return ve_base * scale + + def _embed(self, input_ids: Tensor) -> tuple[Tensor, Tensor]: + x = self.tok_emb(input_ids) + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + g = torch.sigmoid(self.smear_gate.to(x.dtype)) + x = x + g * (F.pad(x[:, :-1], (0, 0, 1, 0)) - x) + return x, x + + def _run_blocks( + self, x: Tensor, x0: Tensor, input_ids: Tensor, + ) -> Tensor: + ve_base = ( + self.ve_shared(input_ids) if self.ve_layer_set else None + ) + ve_sorted = sorted(self.ve_layer_set) if self.ve_layer_set else [] + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + ve = self._ve_for_layer(i, ve_base, ve_sorted, x.dtype) + x = self.blocks[i](x, x0, ve_out=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + li = self.num_encoder_layers + i + if skips: + sw = self.skip_weights[i].to(dtype=x.dtype) + x = x + sw[None, None, :] * skips.pop() + ve = self._ve_for_layer(li, ve_base, ve_sorted, x.dtype) + x = self.blocks[li](x, x0, ve_out=ve) + return self.final_norm(x) + + def forward( + self, input_ids: Tensor, target_ids: Tensor, + ) -> Tensor: + x, x0 = self._embed(input_ids) + x = self._run_blocks(x, x0, input_ids) + x = x.reshape(-1, x.size(-1)) + w = self.tok_emb.weight.to(x.dtype) + logits = self.logit_softcap * torch.tanh( + F.linear(x, w) / self.logit_softcap, + ) + return F.cross_entropy(logits.float(), target_ids.reshape(-1)) + + def per_token_loss( + self, input_ids: Tensor, target_ids: Tensor, + ) -> Tensor: + x, x0 = self._embed(input_ids) + x = self._run_blocks(x, x0, input_ids) + w = self.tok_emb.weight.to(x.dtype) + logits = self.logit_softcap * torch.tanh( + F.linear(x, w) / self.logit_softcap, + ) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + target_ids.reshape(-1), + reduction="none", + ) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + bytes_lut: Tensor, + space_lut: Tensor, + boundary_lut: Tensor, + stride: int = 64, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + all_starts = list(range(0, total_tokens - seq_len + 1, stride)) + rank_starts = all_starts[rank::world_size] + batch_size = max(1, min(args.val_batch_size // seq_len, 64)) + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + with torch.inference_mode(): + for i in range(0, len(rank_starts), batch_size): + batch_starts = rank_starts[i:i + batch_size] + bsz = len(batch_starts) + x = torch.stack( + [val_tokens[s:s + seq_len] for s in batch_starts], + ).to(device, torch.int64) + y = torch.stack( + [val_tokens[s + 1:s + seq_len + 1] for s in batch_starts], + ).to(device, torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + losses = base_model.per_token_loss(x, y).view(bsz, seq_len) + for j, s in enumerate(batch_starts): + score_from = 0 if s == 0 else (seq_len - stride) + scored = losses[j, 
score_from:] + val_loss_sum += scored.to(torch.float64).sum() + val_token_count += scored.numel() + tgt = y[j, score_from:] + prev = x[j, score_from:] + tbytes = bytes_lut[tgt].to(torch.int16) + tbytes += (space_lut[tgt] & ~boundary_lut[prev]).to(torch.int16) + val_byte_count += tbytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + assert world_size > 0 and 8 % world_size == 0, f"Invalid WORLD_SIZE={world_size}" + ga_steps = 8 // world_size + grad_scale = 1.0 / ga_steps + assert torch.cuda.is_available(), "CUDA required" + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + is_main = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import ( + enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp, + ) + use_cudnn = bool(int(os.environ.get("USE_CUDNN_SDPA", "1"))) + enable_cudnn_sdp(use_cudnn) + enable_flash_sdp(not use_cudnn) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if is_main: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not is_main: + return + if console: + print(msg, flush=True) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version} PyTorch {torch.__version__}", console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + assert args.tokenizer_path.endswith(".model"), "Requires .model tokenizer" + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + assert int(sp.vocab_size()) == args.vocab_size, ( + f"Vocab mismatch: {args.vocab_size} vs {sp.vocab_size()}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + bytes_lut, space_lut, boundary_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + n_val = val_tokens.numel() - 1 + log0( + f"data: {dataset_dir.name} " + f"train_shards:{actual_train_files} val_tokens:{n_val}" + ) + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tied_embed_init_std=args.tied_embed_init_std, + 
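+        # logits are softcapped in forward(): logit_softcap * tanh(logits / logit_softcap)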
logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + attn_backend = "FA3" if _fa3_func is not None else "SDPA" + log0( + f"torch.compile: fullgraph=True " + f"attention={attn_backend}+XSA mlp=relu2_3x" + ) + cmodel = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = ( + DDP(cmodel, device_ids=[local_rank], broadcast_buffers=False) + if distributed else cmodel + ) + + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = { + n: t.detach().float().clone() + for n, t in base_model.state_dict().items() + } + log0(f"EMA: initialized (decay={args.ema_decay})") + + bnp = list(base_model.blocks.named_parameters()) + mat_params = [ + p + for name, p in bnp + if p.ndim == 2 and not any(pattern in name for pattern in CTRL_PATTERNS) + ] + sc_params = [ + p + for name, p in bnp + if p.ndim < 2 or any(pattern in name for pattern in CTRL_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + sc_params.append(base_model.skip_weights) + sc_params.append(base_model.smear_gate) + for p in base_model.bigram.parameters(): + sc_params.append(p) + if base_model.ve_layer_set: + for p in base_model.ve_shared.parameters(): + sc_params.append(p) + for p in base_model.ve_layer_scales.parameters(): + sc_params.append(p) + token_lr = args.tied_embed_lr + opt_tok = torch.optim.AdamW( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_weight_decay, + fused=True, + ) + opt_muon = Muon( + mat_params, + lr=args.matrix_lr, + momentum=args.mu_mom, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_weight_decay, + ) + for group in opt_muon.param_groups: + group["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.AdamW( + [{"params": sc_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [opt_tok, opt_muon, opt_scalar] + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"params:{n_params} world:{world_size} accum:{ga_steps}") + log0( + f"lr: embed={token_lr} matrix={args.matrix_lr} " + f"scalar={args.scalar_lr}" + ) + log0( + f"batch:{args.train_batch_tokens} seq:{args.train_seq_len} " + f"iters:{args.iterations} warmup:{args.warmup_steps} " + f"wall:{args.max_wallclock_seconds:.0f}s" + ) + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + wall_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if wd_iters <= 0: + return 1.0 + if wall_ms is None: + wd_start = max(args.iterations - wd_iters, 0) + if wd_start <= step < args.iterations: + progress = 1.0 - (args.iterations - step) / max(wd_iters, 1) + if args.cosine_warmdown: + return 0.5 * (1.0 + math.cos(math.pi * progress)) + return 1.0 - progress + return 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = wd_iters * step_ms + rem_ms = max(wall_ms - elapsed_ms, 0.0) + if rem_ms <= wd_ms: + linear = rem_ms / max(wd_ms, 1e-9) + if args.cosine_warmdown: + progress = 1.0 - linear + return 0.5 * 
(1.0 + math.cos(math.pi * progress)) + return linear + return 1.0 + + wd_iters = args.warmdown_iters + + if args.warmup_steps > 0: + init_sd = { + n: t.detach().cpu().clone() + for n, t in base_model.state_dict().items() + } + init_opts = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + torch.cuda.synchronize() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(ga_steps): + if distributed: + model.require_backward_grad_sync = micro_step == ga_steps - 1 + x, y = train_loader.next_batch( + args.train_batch_tokens, args.train_seq_len, ga_steps, + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + should_log_warmup = ( + args.warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == args.warmup_steps + ) + if should_log_warmup: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + torch.cuda.synchronize() + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if name in init_sd: + param.data.copy_(init_sd[name].to(param.device, dtype=param.dtype)) + for name, buf in base_model.named_buffers(): + if name in init_sd: + buf.data.copy_(init_sd[name].to(buf.device, dtype=buf.dtype)) + for opt, state in zip(optimizers, init_opts, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + + train_ms = 0.0 + stop_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_step is not None and step >= stop_step) + + should_validate = ( + last_step + or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + ) + if should_validate: + torch.cuda.synchronize() + train_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, ga_steps, + val_tokens, bytes_lut, space_lut, boundary_lut, + ) + for mod in base_model.modules(): + if isinstance(mod, Rotary): + mod.seq_len_cached = 0 + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{train_ms:.0f}ms step_avg:{train_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_step is not None and step < args.iterations: + log0(f"stopping_early: step:{step}/{args.iterations} time:{train_ms:.0f}ms") + break + + elapsed_ms = train_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(ga_steps): + if distributed: + model.require_backward_grad_sync = micro_step == ga_steps - 1 + x, y = train_loader.next_batch( + args.train_batch_tokens, args.train_seq_len, ga_steps, + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= ga_steps + + frac = ( + min(step / args.mu_mom_warmup_steps, 1.0) + if args.mu_mom_warmup_steps > 0 else 1.0 + ) + mu_mom = ( + (1 - frac) * args.mu_mom_warmup_start + frac * args.mu_mom + ) + for group in opt_muon.param_groups: + group["momentum"] = mu_mom + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = 
group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + step += 1 + elapsed = train_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{elapsed:.0f}ms step_avg:{elapsed / step:.2f}ms" + ) + + reached_cap = wall_ms is not None and elapsed >= wall_ms + if distributed and wall_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_step is None and reached_cap: + stop_step = step + + alloc_mb = torch.cuda.max_memory_allocated() // 1024 // 1024 + resv_mb = torch.cuda.max_memory_reserved() // 1024 // 1024 + log0(f"peak_mem: {alloc_mb}MiB alloc {resv_mb}MiB reserved") + + if ema_state is not None: + log0(f"EMA: loading averaged weights (decay={args.ema_decay})") + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if name in ema_state: + param.data.copy_(ema_state[name].to(param.device, dtype=param.dtype)) + for name, buf in base_model.named_buffers(): + if name in ema_state: + buf.data.copy_(ema_state[name].to(buf.device, dtype=buf.dtype)) + ema_state = None + if distributed: + dist.barrier() + + torch.cuda.synchronize() + pre_q_loss, pre_q_bpb = eval_val( + args, model, rank, world_size, device, ga_steps, + val_tokens, bytes_lut, space_lut, boundary_lut, + ) + log0(f"pre_quant val_loss:{pre_q_loss:.4f} val_bpb:{pre_q_bpb:.4f}") + if pre_q_bpb > args.best_known_bpb + args.regression_threshold: + log0( + f"REGRESSION WARNING: pre_quant {pre_q_bpb:.4f} > " + f"best_known {args.best_known_bpb:.4f} + " + f"threshold {args.regression_threshold}" + ) + + quant_obj, _ = quantize_state_dict_int6(base_model.state_dict()) + artifact_name = f"final_model.int6.{_COMPRESS_EXT}" + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_blob = _compress(quant_buf.getvalue()) + qf_bytes = 0 + if is_main: + with open(artifact_name, "wb") as f: + f.write(quant_blob) + qf_bytes = os.path.getsize(artifact_name) + code_bytes = len(code.encode("utf-8")) + total = qf_bytes + code_bytes + payload_bytes = len(quant_buf.getvalue()) + ratio = payload_bytes / qf_bytes if qf_bytes > 0 else 0 + log0( + f"artifact: {qf_bytes}B model + {code_bytes}B code " + f"= {total}B (compression:{ratio:.2f}x)" + ) + if ratio < 1.6: + log0( + f"COMPRESSION ALERT: {ratio:.2f}x below 1.6x " + f"-- weight distributions may be degraded" + ) + if total > 16_000_000: + log0(f"FATAL: exceeds 16MB by {total - 16_000_000}B") + if distributed: + dist.destroy_process_group() + sys.exit(0) + else: + log0(f"headroom: {16_000_000 - total}B") + + if distributed: + dist.barrier() + with open(artifact_name, "rb") as f: + qblob = f.read() + raw = _decompress(qblob) + quant_state = torch.load(io.BytesIO(raw), map_location="cpu") + dequant_sd = dequantize_state_dict_int6(quant_state) + model_sd = base_model.state_dict() + missing = [k for k in model_sd if k not in dequant_sd] + if missing: + for k in missing: + 
dequant_sd[k] = model_sd[k].cpu() + base_model.load_state_dict(dequant_sd, strict=False) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, ga_steps, + val_tokens, bytes_lut, space_lut, boundary_lut, + ) + torch.cuda.synchronize() + q_eval_ms = 1000.0 * (time.perf_counter() - t_qeval) + log0( + f"final_int6_{_COMPRESS_EXT}_roundtrip " + f"val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{q_eval_ms:.0f}ms" + ) + + sw_val_bpb = q_val_bpb + if args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, bytes_lut, space_lut, boundary_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + sw_ms = 1000.0 * (time.perf_counter() - t_sw) + log0( + f"final_int6_{_COMPRESS_EXT}_sliding " + f"val_loss:{sw_val_loss:.4f} " + f"val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{sw_ms:.0f}ms" + ) + + if is_main: + avg_ms = train_ms / max(step, 1) + log0( + f"RUN_SUMMARY: steps={step} " + f"pre_quant_bpb={pre_q_bpb:.4f} " + f"post_quant_bpb={q_val_bpb:.4f} " + f"sliding_bpb={sw_val_bpb:.4f} " + f"artifact_bytes={qf_bytes} step_ms={avg_ms:.1f}" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_seed1337.log b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_seed1337.log new file mode 100644 index 0000000000..ce07f4dd54 --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_seed1337.log @@ -0,0 +1,97 @@ +logs/879c18aa-c3f4-4591-aa65-3e99c095fd71.txt +data: fineweb10B_sp1024 train_shards:80 val_tokens:62021632 +torch.compile: fullgraph=True attention=SDPA+XSA mlp=relu2_3x +EMA: initialized (decay=0.997) +params:27038810 world:8 accum:1 +lr: embed=0.035 matrix=0.025 scalar=0.025 +batch:524288 seq:2048 iters:20000 warmup:20 wall:600s +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9294 val_bpb:4.1040 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9302 train_time:35185ms step_avg:35185.09ms +step:2/20000 train_loss:8.1740 train_time:35254ms step_avg:17626.81ms +step:3/20000 train_loss:8.1965 train_time:35325ms step_avg:11775.12ms +step:4/20000 train_loss:8.2060 train_time:35397ms step_avg:8849.36ms +step:5/20000 train_loss:8.1728 train_time:35468ms step_avg:7093.69ms +step:6/20000 train_loss:8.4573 train_time:35541ms step_avg:5923.54ms +step:7/20000 train_loss:8.2027 train_time:35612ms step_avg:5087.46ms +step:8/20000 train_loss:8.2893 train_time:35681ms step_avg:4460.12ms +step:9/20000 train_loss:8.2407 train_time:35750ms step_avg:3972.28ms +step:10/20000 train_loss:8.1550 train_time:35819ms step_avg:3581.93ms +step:200/20000 train_loss:3.2141 train_time:49169ms step_avg:245.84ms +step:400/20000 train_loss:2.3816 train_time:63085ms step_avg:157.71ms +step:600/20000 train_loss:2.5192 train_time:77026ms step_avg:128.38ms +step:800/20000 train_loss:2.2491 train_time:90928ms step_avg:113.66ms 
+step:1000/20000 train_loss:2.3423 train_time:104842ms step_avg:104.84ms +step:1000/20000 val_loss:2.2995 val_bpb:1.3619 train_time:104847ms step_avg:104.85ms +step:1200/20000 train_loss:2.3621 train_time:118799ms step_avg:99.00ms +step:1400/20000 train_loss:2.3978 train_time:132738ms step_avg:94.81ms +step:1600/20000 train_loss:2.0624 train_time:146676ms step_avg:91.67ms +step:1800/20000 train_loss:2.1723 train_time:160613ms step_avg:89.23ms +step:2000/20000 train_loss:2.2001 train_time:174542ms step_avg:87.27ms +step:2000/20000 val_loss:2.1853 val_bpb:1.2943 train_time:174547ms step_avg:87.27ms +step:2200/20000 train_loss:2.0234 train_time:188491ms step_avg:85.68ms +step:2400/20000 train_loss:2.1471 train_time:202484ms step_avg:84.37ms +step:2600/20000 train_loss:2.3810 train_time:216425ms step_avg:83.24ms +step:2800/20000 train_loss:2.1886 train_time:230367ms step_avg:82.27ms +step:3000/20000 train_loss:2.1713 train_time:244308ms step_avg:81.44ms +step:3000/20000 val_loss:2.1426 val_bpb:1.2689 train_time:244313ms step_avg:81.44ms +step:3200/20000 train_loss:2.1435 train_time:258224ms step_avg:80.69ms +step:3400/20000 train_loss:2.1121 train_time:272163ms step_avg:80.05ms +step:3600/20000 train_loss:2.0538 train_time:286101ms step_avg:79.47ms +step:3800/20000 train_loss:2.1575 train_time:300021ms step_avg:78.95ms +step:4000/20000 train_loss:2.1247 train_time:313992ms step_avg:78.50ms +step:4000/20000 val_loss:2.1200 val_bpb:1.2556 train_time:313999ms step_avg:78.50ms +step:4200/20000 train_loss:2.1211 train_time:328046ms step_avg:78.11ms +step:4400/20000 train_loss:2.0647 train_time:342000ms step_avg:77.73ms +step:4600/20000 train_loss:1.9224 train_time:355948ms step_avg:77.38ms +step:4800/20000 train_loss:2.2042 train_time:369871ms step_avg:77.06ms +step:5000/20000 train_loss:1.9586 train_time:383797ms step_avg:76.76ms +step:5000/20000 val_loss:2.0971 val_bpb:1.2420 train_time:383803ms step_avg:76.76ms +step:5200/20000 train_loss:2.1175 train_time:397728ms step_avg:76.49ms +step:5400/20000 train_loss:2.1221 train_time:411657ms step_avg:76.23ms +step:5600/20000 train_loss:2.1074 train_time:425618ms step_avg:76.00ms +step:5800/20000 train_loss:2.0562 train_time:439573ms step_avg:75.79ms +step:6000/20000 train_loss:2.1282 train_time:453550ms step_avg:75.59ms +step:6000/20000 val_loss:2.0521 val_bpb:1.2154 train_time:453556ms step_avg:75.59ms +step:6200/20000 train_loss:1.9915 train_time:467519ms step_avg:75.41ms +step:6400/20000 train_loss:2.0542 train_time:481474ms step_avg:75.23ms +step:6600/20000 train_loss:2.0014 train_time:495454ms step_avg:75.07ms +step:6800/20000 train_loss:2.0479 train_time:509425ms step_avg:74.92ms +step:7000/20000 train_loss:2.0813 train_time:523398ms step_avg:74.77ms +step:7000/20000 val_loss:1.9836 val_bpb:1.1748 train_time:523403ms step_avg:74.77ms +step:7200/20000 train_loss:2.0387 train_time:537341ms step_avg:74.63ms +step:7400/20000 train_loss:1.9607 train_time:551305ms step_avg:74.50ms +step:7600/20000 train_loss:1.8286 train_time:565247ms step_avg:74.37ms +step:7800/20000 train_loss:1.9758 train_time:579195ms step_avg:74.26ms +step:8000/20000 train_loss:1.9411 train_time:593143ms step_avg:74.14ms +step:8000/20000 val_loss:1.9435 val_bpb:1.1511 train_time:593149ms step_avg:74.14ms +step:8098/20000 val_loss:1.9434 val_bpb:1.1510 train_time:600034ms step_avg:74.10ms +stopping_early: step:8098/20000 time:600034ms +peak_mem: 14900MiB alloc 14972MiB reserved +EMA: loading averaged weights (decay=0.997) +pre_quant val_loss:1.9438 val_bpb:1.1512 +REGRESSION 
WARNING: pre_quant 1.1512 > best_known 1.1414 + threshold 0.005 +artifact: 15567435B model + 51283B code = 15618718B (compression:1.75x) +headroom: 381282B +final_int6_ptzst_roundtrip val_loss:1.9582 val_bpb:1.1597 eval_time:2170ms +final_int6_ptzst_sliding val_loss:1.9188 val_bpb:1.1364 stride:64 eval_time:199279ms +RUN_SUMMARY: steps=8098 pre_quant_bpb=1.1512 post_quant_bpb=1.1597 sliding_bpb=1.1364 artifact_bytes=15567435 step_ms=74.1 diff --git a/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_seed2024.log b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_seed2024.log new file mode 100644 index 0000000000..418606239b --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_seed2024.log @@ -0,0 +1,94 @@ +logs/8fd8b7d5-7688-4339-b135-ee313135d7ce.txt +data: fineweb10B_sp1024 train_shards:80 val_tokens:62021632 +torch.compile: fullgraph=True attention=SDPA+XSA mlp=relu2_3x +EMA: initialized (decay=0.997) +params:27038810 world:8 accum:1 +lr: embed=0.035 matrix=0.025 scalar=0.025 +batch:524288 seq:2048 iters:20000 warmup:20 wall:600s +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9294 val_bpb:4.1040 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9302 train_time:32137ms step_avg:32136.90ms +step:2/20000 train_loss:8.1740 train_time:32202ms step_avg:16101.05ms +step:3/20000 train_loss:8.1965 train_time:32273ms step_avg:10757.82ms +step:4/20000 train_loss:8.2059 train_time:32342ms step_avg:8085.48ms +step:5/20000 train_loss:8.1722 train_time:32411ms step_avg:6482.13ms +step:6/20000 train_loss:8.4560 train_time:32479ms step_avg:5413.24ms +step:7/20000 train_loss:8.1999 train_time:32550ms step_avg:4649.93ms +step:8/20000 train_loss:8.2832 train_time:32621ms step_avg:4077.64ms +step:9/20000 train_loss:8.2225 train_time:32694ms step_avg:3632.62ms +step:10/20000 train_loss:8.1167 train_time:32762ms step_avg:3276.16ms +step:200/20000 train_loss:3.1681 train_time:47234ms step_avg:236.17ms +step:400/20000 train_loss:2.3496 train_time:63657ms step_avg:159.14ms +step:600/20000 train_loss:2.5073 train_time:79553ms step_avg:132.59ms +step:800/20000 train_loss:2.2398 train_time:96690ms step_avg:120.86ms +step:1000/20000 train_loss:2.3437 train_time:113380ms step_avg:113.38ms +step:1000/20000 val_loss:2.2937 val_bpb:1.3585 train_time:113386ms step_avg:113.39ms +step:1200/20000 train_loss:2.3642 train_time:127032ms step_avg:105.86ms +step:1400/20000 train_loss:2.3999 train_time:140758ms step_avg:100.54ms +step:1600/20000 train_loss:2.0578 train_time:154377ms step_avg:96.49ms +step:1800/20000 train_loss:2.1635 train_time:168177ms step_avg:93.43ms +step:2000/20000 train_loss:2.1966 train_time:181820ms step_avg:90.91ms +step:2000/20000 val_loss:2.1836 val_bpb:1.2932 train_time:181829ms step_avg:90.91ms +step:2200/20000 train_loss:2.0160 train_time:195547ms step_avg:88.89ms +step:2400/20000 train_loss:2.1438 train_time:213978ms step_avg:89.16ms +step:2600/20000 train_loss:2.3781 train_time:231842ms step_avg:89.17ms +step:2800/20000 train_loss:2.1869 train_time:245573ms step_avg:87.70ms +step:3000/20000 train_loss:2.1713 train_time:259511ms step_avg:86.50ms 
+step:3000/20000 val_loss:2.1400 val_bpb:1.2674 train_time:259518ms step_avg:86.51ms +step:3200/20000 train_loss:2.1434 train_time:273201ms step_avg:85.38ms +step:3400/20000 train_loss:2.1119 train_time:287085ms step_avg:84.44ms +step:3600/20000 train_loss:2.0505 train_time:300797ms step_avg:83.55ms +step:3800/20000 train_loss:2.1564 train_time:314500ms step_avg:82.76ms +step:4000/20000 train_loss:2.1241 train_time:328168ms step_avg:82.04ms +step:4000/20000 val_loss:2.1199 val_bpb:1.2555 train_time:328175ms step_avg:82.04ms +step:4200/20000 train_loss:2.1155 train_time:342247ms step_avg:81.49ms +step:4400/20000 train_loss:2.0627 train_time:355956ms step_avg:80.90ms +step:4600/20000 train_loss:1.9167 train_time:369615ms step_avg:80.35ms +step:4800/20000 train_loss:2.2035 train_time:383224ms step_avg:79.84ms +step:5000/20000 train_loss:1.9520 train_time:396874ms step_avg:79.37ms +step:5000/20000 val_loss:2.0917 val_bpb:1.2388 train_time:396881ms step_avg:79.38ms +step:5200/20000 train_loss:2.1133 train_time:410571ms step_avg:78.96ms +step:5400/20000 train_loss:2.1172 train_time:424250ms step_avg:78.56ms +step:5600/20000 train_loss:2.0995 train_time:438004ms step_avg:78.22ms +step:5800/20000 train_loss:2.0446 train_time:451834ms step_avg:77.90ms +step:6000/20000 train_loss:2.1197 train_time:465487ms step_avg:77.58ms +step:6000/20000 val_loss:2.0444 val_bpb:1.2108 train_time:465494ms step_avg:77.58ms +step:6200/20000 train_loss:1.9815 train_time:479161ms step_avg:77.28ms +step:6400/20000 train_loss:2.0438 train_time:492845ms step_avg:77.01ms +step:6600/20000 train_loss:1.9894 train_time:506712ms step_avg:76.77ms +step:6800/20000 train_loss:2.0419 train_time:520475ms step_avg:76.54ms +step:7000/20000 train_loss:2.0719 train_time:534194ms step_avg:76.31ms +step:7000/20000 val_loss:1.9760 val_bpb:1.1703 train_time:534202ms step_avg:76.31ms +step:7200/20000 train_loss:2.0334 train_time:547866ms step_avg:76.09ms +step:7400/20000 train_loss:1.9563 train_time:561538ms step_avg:75.88ms +step:7600/20000 train_loss:1.8287 train_time:575284ms step_avg:75.70ms +step:7800/20000 train_loss:1.9733 train_time:589045ms step_avg:75.52ms +step:7960/20000 val_loss:1.9450 val_bpb:1.1519 train_time:600027ms step_avg:75.38ms +stopping_early: step:7960/20000 time:600027ms +peak_mem: 14901MiB alloc 15126MiB reserved +EMA: loading averaged weights (decay=0.997) +pre_quant val_loss:1.9453 val_bpb:1.1521 +artifact: 15548351B model + 52010B code = 15600361B (compression:1.76x) +headroom: 399639B +final_int6_ptzst_roundtrip val_loss:1.9593 val_bpb:1.1604 eval_time:2138ms +final_int6_ptzst_sliding val_loss:1.9199 val_bpb:1.1370 stride:64 eval_time:198714ms +RUN_SUMMARY: steps=7960 pre_quant_bpb=1.1521 post_quant_bpb=1.1604 sliding_bpb=1.1370 artifact_bytes=15548351 step_ms=75.4 diff --git a/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_seed42.log b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_seed42.log new file mode 100644 index 0000000000..4045468bde --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_11L_Hadamard_VE128_cuDNN_1.1365/train_seed42.log @@ -0,0 +1,96 @@ +logs/d18f3726-d3f7-4a2f-bf7b-cf13400eaa0e.txt +data: fineweb10B_sp1024 train_shards:80 val_tokens:62021632 +torch.compile: fullgraph=True attention=SDPA+XSA mlp=relu2_3x +EMA: initialized (decay=0.997) +params:27038810 world:8 accum:1 +lr: embed=0.035 matrix=0.025 scalar=0.025 +batch:524288 seq:2048 iters:20000 warmup:20 wall:600s +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 
+warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9294 val_bpb:4.1040 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9302 train_time:41682ms step_avg:41681.93ms +step:2/20000 train_loss:8.1740 train_time:41747ms step_avg:20873.35ms +step:3/20000 train_loss:8.1965 train_time:41819ms step_avg:13939.69ms +step:4/20000 train_loss:8.2061 train_time:41888ms step_avg:10472.04ms +step:5/20000 train_loss:8.1732 train_time:41959ms step_avg:8391.89ms +step:6/20000 train_loss:8.4584 train_time:42028ms step_avg:7004.59ms +step:7/20000 train_loss:8.2048 train_time:42095ms step_avg:6013.55ms +step:8/20000 train_loss:8.2917 train_time:42166ms step_avg:5270.81ms +step:9/20000 train_loss:8.2476 train_time:42241ms step_avg:4693.50ms +step:10/20000 train_loss:8.1753 train_time:42312ms step_avg:4231.18ms +step:200/20000 train_loss:3.2650 train_time:55440ms step_avg:277.20ms +step:400/20000 train_loss:2.4006 train_time:69178ms step_avg:172.95ms +step:600/20000 train_loss:2.5356 train_time:82913ms step_avg:138.19ms +step:800/20000 train_loss:2.2550 train_time:96642ms step_avg:120.80ms +step:1000/20000 train_loss:2.3464 train_time:110340ms step_avg:110.34ms +step:1000/20000 val_loss:2.3023 val_bpb:1.3636 train_time:110347ms step_avg:110.35ms +step:1200/20000 train_loss:2.3672 train_time:124126ms step_avg:103.44ms +step:1400/20000 train_loss:2.4040 train_time:137867ms step_avg:98.48ms +step:1600/20000 train_loss:2.0601 train_time:151593ms step_avg:94.75ms +step:1800/20000 train_loss:2.1667 train_time:165296ms step_avg:91.83ms +step:2000/20000 train_loss:2.2038 train_time:179056ms step_avg:89.53ms +step:2000/20000 val_loss:2.1858 val_bpb:1.2945 train_time:179064ms step_avg:89.53ms +step:2200/20000 train_loss:2.0256 train_time:192725ms step_avg:87.60ms +step:2400/20000 train_loss:2.1463 train_time:206451ms step_avg:86.02ms +step:2600/20000 train_loss:2.3789 train_time:220144ms step_avg:84.67ms +step:2800/20000 train_loss:2.1876 train_time:233884ms step_avg:83.53ms +step:3000/20000 train_loss:2.1764 train_time:247627ms step_avg:82.54ms +step:3000/20000 val_loss:2.1417 val_bpb:1.2684 train_time:247634ms step_avg:82.54ms +step:3200/20000 train_loss:2.1392 train_time:261340ms step_avg:81.67ms +step:3400/20000 train_loss:2.1132 train_time:275031ms step_avg:80.89ms +step:3600/20000 train_loss:2.0513 train_time:288728ms step_avg:80.20ms +step:3800/20000 train_loss:2.1553 train_time:302382ms step_avg:79.57ms +step:4000/20000 train_loss:2.1296 train_time:316116ms step_avg:79.03ms +step:4000/20000 val_loss:2.1197 val_bpb:1.2554 train_time:316125ms step_avg:79.03ms +step:4200/20000 train_loss:2.1234 train_time:329910ms step_avg:78.55ms +step:4400/20000 train_loss:2.0654 train_time:343618ms step_avg:78.10ms +step:4600/20000 train_loss:1.9248 train_time:357282ms step_avg:77.67ms +step:4800/20000 train_loss:2.2076 train_time:370998ms step_avg:77.29ms +step:5000/20000 train_loss:1.9657 train_time:384722ms step_avg:76.94ms +step:5000/20000 val_loss:2.1007 val_bpb:1.2441 train_time:384731ms step_avg:76.95ms +step:5200/20000 train_loss:2.1237 train_time:398418ms step_avg:76.62ms +step:5400/20000 train_loss:2.1258 train_time:412136ms step_avg:76.32ms +step:5600/20000 train_loss:2.1142 train_time:425829ms step_avg:76.04ms +step:5800/20000 
train_loss:2.0558 train_time:439574ms step_avg:75.79ms +step:6000/20000 train_loss:2.1329 train_time:453260ms step_avg:75.54ms +step:6000/20000 val_loss:2.0569 val_bpb:1.2182 train_time:453267ms step_avg:75.54ms +step:6200/20000 train_loss:1.9960 train_time:466971ms step_avg:75.32ms +step:6400/20000 train_loss:2.0602 train_time:480686ms step_avg:75.11ms +step:6600/20000 train_loss:2.0068 train_time:494438ms step_avg:74.91ms +step:6800/20000 train_loss:2.0529 train_time:508148ms step_avg:74.73ms +step:7000/20000 train_loss:2.0868 train_time:521846ms step_avg:74.55ms +step:7000/20000 val_loss:1.9881 val_bpb:1.1774 train_time:521855ms step_avg:74.55ms +step:7200/20000 train_loss:2.0408 train_time:535543ms step_avg:74.38ms +step:7400/20000 train_loss:1.9650 train_time:549269ms step_avg:74.23ms +step:7600/20000 train_loss:1.8298 train_time:563006ms step_avg:74.08ms +step:7800/20000 train_loss:1.9748 train_time:576735ms step_avg:73.94ms +step:8000/20000 train_loss:1.9411 train_time:590414ms step_avg:73.80ms +step:8000/20000 val_loss:1.9438 val_bpb:1.1512 train_time:590424ms step_avg:73.80ms +step:8102/20000 val_loss:1.9436 val_bpb:1.1511 train_time:599981ms step_avg:74.05ms +stopping_early: step:8102/20000 time:599981ms +peak_mem: 14901MiB alloc 15072MiB reserved +EMA: loading averaged weights (decay=0.997) +pre_quant val_loss:1.9440 val_bpb:1.1513 +artifact: 15577530B model + 52010B code = 15629540B (compression:1.75x) +headroom: 370460B +final_int6_ptzst_roundtrip val_loss:1.9576 val_bpb:1.1594 eval_time:2158ms +final_int6_ptzst_sliding val_loss:1.9182 val_bpb:1.1361 stride:64 eval_time:198673ms +RUN_SUMMARY: steps=8102 pre_quant_bpb=1.1513 post_quant_bpb=1.1594 sliding_bpb=1.1361 artifact_bytes=15577530 step_ms=74.1 diff --git a/train_gpt.py b/train_gpt.py index 651beb2b89..cace894b63 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -12,11 +12,11 @@ import math import os import random -import subprocess import sys import time import uuid import zlib +from collections.abc import Callable from pathlib import Path import numpy as np @@ -27,75 +27,87 @@ from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap +try: + import zstandard + def _compress(data: bytes) -> bytes: + return zstandard.ZstdCompressor(level=22).compress(data) + def _decompress(data: bytes) -> bytes: + return zstandard.ZstdDecompressor().decompress(data) + _COMPRESS_EXT = "ptzst" +except ImportError: + def _compress(data: bytes) -> bytes: + return zlib.compress(data, level=9) + def _decompress(data: bytes) -> bytes: + return zlib.decompress(data) + _COMPRESS_EXT = "ptz" + +try: + from flash_attn_interface import flash_attn_func as _fa3_func +except ImportError: + _fa3_func = None + class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. 
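+    # Data paths are shard globs from the preprocessing pipeline. Every field
+    # in this class reads its default from the environment, so sweeps can
+    # override any hyperparameter (e.g. SEED, NUM_LAYERS) without code edits.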
- data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + data_path = os.environ.get( + "DATA_PATH", "./data/datasets/fineweb10B_sp1024" + ) train_files = os.path.join(data_path, "fineweb_train_*.bin") val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + tokenizer_path = os.environ.get( + "TOKENIZER_PATH", + "./data/tokenizers/fineweb_1024_bpe.model", + ) run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) seed = int(os.environ.get("SEED", 1337)) - # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - # Optimizer hyperparameters. 
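+    # Optimizer hyperparameters: Muon updates the 2-D block matrices; AdamW
+    # handles the tied embedding and the vector/scalar control parameters.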
- embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + mu_mom = float(os.environ.get("MUON_MOMENTUM", 0.99)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + mu_mom_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + mu_mom_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + adam_weight_decay = float(os.environ.get("ADAM_WEIGHT_DECAY", 0.04)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + + xsa_layers = int(os.environ.get("XSA_LAYERS", 4)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + cosine_warmdown = bool(int(os.environ.get("COSINE_WARMDOWN", "1"))) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + best_known_bpb = float(os.environ.get("BEST_KNOWN_BPB", 1.1520)) + regression_threshold = float(os.environ.get("REGRESSION_THRESHOLD", 0.005)) -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. 
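+    # Quintic Newton-Schulz iteration, X <- aX + b(XX^T)X + c(XX^T)^2 X,
+    # which drives the singular values of G toward 1 and approximately
+    # orthogonalizes the update; coefficients follow the Muon reference.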
a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() X /= X.norm() + eps @@ -110,14 +122,34 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) - class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + def __init__( + self, + params: list[nn.Parameter], + lr: float, + momentum: float, + backend_steps: int, + weight_decay: float = 0.0, + ) -> None: + if ( + isinstance(params, list) + and len(params) > 0 + and isinstance(params[0], nn.Parameter) + ): + params = sorted( + params, key=lambda x: x.numel(), reverse=True + ) + defaults = dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + weight_decay=weight_decay, ) + super().__init__(params, defaults) @torch.no_grad() - def step(self, closure=None): + def step( + self, closure: Callable[[], Tensor] | None = None, + ) -> Tensor | None: loss = None if closure is not None: with torch.enable_grad(): @@ -134,73 +166,57 @@ def step(self, closure=None): lr = group["lr"] momentum = group["momentum"] backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad + pad_count = (-len(params)) % world_size + params_pad = params + [torch.empty_like(params[-1]) for _ in range(pad_count)] + for base_i in range(0, len(params_pad), world_size): + if base_i + rank < len(params): + p = params[base_i + rank] + if p.grad is None: + p.grad = torch.zeros_like(p) state = self.state[p] if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) + state["momentum_buffer"] = torch.zeros_like(p) buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() + buf.lerp_(p.grad, 1 - momentum) + wd = group.get("weight_decay", 0) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + update = zeropower_via_newtonschulz5(buf, steps=backend_steps) + update = update.to(p.dtype) + update *= max(1, p.size(-2) / p.size(-1)) ** 0.5 + p.add_(update, alpha=-lr) + if distributed: + chunk = params_pad[base_i:base_i + world_size] + dist.all_gather(chunk, params_pad[base_i + rank]) return loss -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. 
-# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - def build_sentencepiece_luts( sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device ) -> tuple[Tensor, Tensor, Tensor]: sp_vocab_size = int(sp.vocab_size()) table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + bytes_np = np.zeros((table_size,), dtype=np.int16) + space_np = np.zeros((table_size,), dtype=np.bool_) + boundary_np = np.ones((table_size,), dtype=np.bool_) for token_id in range(sp_vocab_size): if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): continue - is_boundary_token_np[token_id] = False + boundary_np[token_id] = False if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 + bytes_np[token_id] = 1 continue piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True + if piece.startswith("\u2581"): + space_np[token_id] = True piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) + bytes_np[token_id] = len(piece.encode("utf-8")) return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + torch.tensor(bytes_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(boundary_np, dtype=torch.bool, device=device), ) @@ -208,7 +224,6 @@ def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
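+    # Concatenate every validation shard, then trim to a whole number of
+    # seq_len-token sequences; one extra token is kept so (x, y) pairs can
+    # be built by shifting.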
tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len if usable <= 0: @@ -222,23 +237,16 @@ def eval_val( rank: int, world_size: int, device: torch.device, - grad_accum_steps: int, + ga_steps: int, val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, + bytes_lut: Tensor, + space_lut: Tensor, + boundary_lut: Tensor, ) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len + lbt = args.val_batch_size // (world_size * ga_steps) + if lbt < args.train_seq_len: + raise ValueError("VAL_BATCH_SIZE too small for world_size/seq_len") + lbs = lbt // args.train_seq_len total_seqs = (val_tokens.numel() - 1) // args.train_seq_len seq_start = (total_seqs * rank) // world_size seq_end = (total_seqs * (rank + 1)) // world_size @@ -248,11 +256,13 @@ def eval_val( model.eval() with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + for bss in range(seq_start, seq_end, lbs): + bse = min(bss + lbs, seq_end) + raw_start = bss * args.train_seq_len + raw_end = bse * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True, + ) x = local[:-1].reshape(-1, args.train_seq_len) y = local[1:].reshape(-1, args.train_seq_len) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): @@ -262,8 +272,8 @@ def eval_val( val_token_count += batch_token_count prev_ids = x.reshape(-1) tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + token_bytes = bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (space_lut[tgt_ids] & ~boundary_lut[prev_ids]).to(dtype=torch.int16) val_byte_count += token_bytes.to(torch.float64).sum() if dist.is_available() and dist.is_initialized(): @@ -277,160 +287,126 @@ def eval_val( model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
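+# Worked example of the BPB reduction in eval_val above, using the seed-1337
+# final numbers from the logs: val_loss 1.9438 nats/token gives
+# bits_per_token = 1.9438 / ln(2) ~= 2.804; the measured val_bpb of 1.1512
+# then implies tokens_per_byte ~= 1.1512 / 2.804 ~= 0.411, i.e. about
+# 2.44 bytes per token on the validation split.
+
+# Export scheme replacing the previous int8+zlib path: 2-D weights get a
+# Walsh-Hadamard pre-rotation and per-row int6 abs-max quantization, small
+# and control tensors pass through in fp16, and the pickled payload is
+# compressed with zstd level 22 (zlib as a fallback when zstandard is absent).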
- -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern +CTRL_PATTERNS = tuple( + p for p in os.environ.get( + "CTRL_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales," + "resid_mix,resid_mixes,q_gain,skip_weight," + "skip_weights,smear_gate,ve_layer_scale", + ).split(",") if p ) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +SMALL_NUMEL = 65_535 + def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: +def _hadamard_matrix(n: int) -> Tensor: + h = torch.tensor([[1.0]]) + while h.size(0) < n: + h = torch.cat([torch.cat([h, h], 1), torch.cat([h, -h], 1)], 0) / math.sqrt(2) + return h + + +def _hadamard_rotate(t: Tensor) -> Tensor: + rows, cols = t.shape + block = 1 + while block * 2 <= cols and cols % (block * 2) == 0: + block *= 2 + if block < 2: + return t + H = _hadamard_matrix(block).to(t.device, t.dtype) + return (t.reshape(rows, -1, block) @ H).reshape(rows, cols) + + +def quantize_tensor_int6(t: Tensor, bits: int = 6) -> tuple[Tensor, Tensor]: + max_val = 2 ** (bits - 1) - 1 + min_val = -(max_val + 1) t32 = t.float() if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. 
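+# The normalized Sylvester construction in _hadamard_matrix is symmetric and
+# orthogonal, so the rotation is an involution: applying _hadamard_rotate a
+# second time restores the original tensor, which is how
+# dequantize_state_dict_int6 undoes it. Illustrative round-trip (a sketch,
+# not executed anywhere in this script):
+#
+#   w = torch.randn(512, 1536)
+#   q, s = quantize_tensor_int6(w)   # rotate, then per-row abs-max int6
+#   w_hat = _hadamard_rotate(q.float() * s.float()[:, None])
+#   # the orthogonal rotation preserves the error norm, so ||w - w_hat||
+#   # equals the quantization error norm in the rotated basis (RMS ~ s/sqrt(12))
+#
+# "int6" means values are clamped to [-32, 31]; storage stays int8, and the
+# narrowed value range is what zstd exploits downstream.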
- clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} + t32 = _hadamard_rotate(t32) + row_max = t32.abs().amax(dim=1) + scale = (row_max / max_val).clamp_min(1.0 / max_val).to(torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), min_val, max_val).to(torch.int8) + return q.contiguous(), scale.contiguous() + abs_max = t32.abs().max().clamp_min(1e-12).item() + scale = torch.tensor(abs_max / max_val, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), min_val, max_val).to(torch.int8) + return q.contiguous(), scale.contiguous() + + +def quantize_state_dict_int6( + state_dict: dict[str, Tensor], +) -> tuple[dict[str, dict[str, Tensor] | dict[str, str]], dict[str, int]]: + w: dict[str, Tensor] = {} + m: dict[str, str] = {} stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + ( + "param_count", "num_tensors", "num_int6_tensors", + "num_passthrough_tensors", + "baseline_tensor_bytes", "int6_payload_bytes", + ), 0, ) - for name, tensor in state_dict.items(): t = tensor.detach().to("cpu").contiguous() stats["param_count"] += int(t.numel()) stats["num_tensors"] += 1 stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) + stats["num_passthrough_tensors"] += 1 + w[name] = t + m[name] = "raw" + stats["int6_payload_bytes"] += tensor_nbytes(t) continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
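+        # Vectors, scalars, and anything at or below SMALL_NUMEL elements are
+        # kept directly in fp16 by the check just below; only larger 2-D
+        # weights take the int6 + Hadamard path.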
- if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) + if t.numel() <= SMALL_NUMEL or t.ndim < 2: + kept = t.to(dtype=torch.float16).contiguous() + w[name] = kept + m[name] = str(t.dtype).removeprefix("torch.") + stats["num_passthrough_tensors"] += 1 + stats["int6_payload_bytes"] += tensor_nbytes(kept) continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + stats["num_int6_tensors"] += 1 + q, s = quantize_tensor_int6(t) + w[name + ".q"] = q + w[name + ".s"] = s + m[name] = str(t.dtype).removeprefix("torch.") + nbytes = tensor_nbytes(q) + tensor_nbytes(s) + stats["int6_payload_bytes"] += nbytes + return {"w": w, "m": m}, stats + + +def dequantize_state_dict_int6( + obj: dict[str, object], +) -> dict[str, Tensor]: out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + w = obj["w"] + m = obj["m"] + for name, dtype_str in m.items(): + if name + ".q" in w: + q = w[name + ".q"] + s = w[name + ".s"] + dtype = getattr(torch, dtype_str) + if s.ndim > 0: + sf = s.to(dtype=torch.float32) + expand = (q.shape[0], *([1] * (q.ndim - 1))) + dq = q.float() * sf.view(*expand) + if dq.ndim == 2: + dq = _hadamard_rotate(dq) + out[name] = dq.to(dtype=dtype).contiguous() + else: + sv = float(s.item()) + out[name] = (q.float() * sv).to(dtype).contiguous() else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t + out[name] = w[name].detach().to("cpu").contiguous() return out -# ----------------------------- -# DATA LOADING -# ----------------------------- - def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" Tensor: class TokenStream: - # Reads shards sequentially and wraps around forever. The training loop therefore - # has deterministic, simple streaming behavior with no sampling or workers. 
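+    # Reads shards sequentially and wraps around forever, so the training
+    # loop sees deterministic streaming with no sampling and no workers.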
- def __init__(self, pattern: str): + def __init__(self, pattern: str) -> None: self.files = [Path(p) for p in sorted(glob.glob(pattern))] if not self.files: raise FileNotFoundError(f"No files found for pattern: {pattern}") @@ -475,16 +449,18 @@ def take(self, n: int) -> Tensor: class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + def __init__( + self, pattern: str, rank: int, world_size: int, device: torch.device, + ) -> None: self.rank = rank self.world_size = world_size self.device = device self.stream = TokenStream(pattern) - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) + def next_batch( + self, global_tokens: int, seq_len: int, ga_steps: int, + ) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * ga_steps) per_rank_span = local_tokens + 1 chunk = self.stream.take(per_rank_span * self.world_size) start = self.rank * per_rank_span @@ -493,12 +469,9 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): + def __init__(self, eps: float | None = None) -> None: super().__init__() self.eps = eps @@ -507,49 +480,83 @@ def forward(self, x: Tensor) -> Tensor: class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) + def forward(self, input: Tensor) -> Tensor: + w = self.weight.to(input.dtype) + return F.linear(input, w, self.bias.to(input.dtype) if self.bias is not None else None) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. 
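+    # Keep vectors, scalars, and CTRL_PATTERNS-matched control parameters in
+    # fp32 even when the model body runs in bf16.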
with torch.no_grad(): for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + is_ctrl = any(p in name for p in CTRL_PATTERNS) + if (param.ndim < 2 or is_ctrl) and param.dtype != torch.float32: param.data = param.data.float() +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int) -> None: + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + self.proj = CastedLinear(ve_dim, kv_dim, bias=False) + nn.init.normal_(self.embed.weight, std=0.005) + + def forward(self, input_ids: Tensor) -> Tensor: + return self.proj(self.embed(input_ids)) + + +class BigramHash(nn.Module): + def __init__( + self, num_buckets: int, model_dim: int, inner_dim: int = 128, + ) -> None: + super().__init__() + self.num_buckets = num_buckets + self.emb = nn.Embedding(num_buckets, inner_dim) + self.proj = CastedLinear(inner_dim, model_dim, bias=False) + nn.init.normal_(self.emb.weight, std=0.005) + + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0) + idx = (prev * 36313 + input_ids * 27191) % self.num_buckets + return self.proj(self.emb(idx)) + + class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): + def __init__( + self, dim: int, base: float = 10000.0, rope_dims: int = 0, + ) -> None: super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + rd = rope_dims if rope_dims > 0 else dim + freqs = torch.arange(0, rd, 2, dtype=torch.float32) / rd + inv_freq = 1.0 / (base ** freqs) self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None + self.seq_len_cached = 0 + self.cos_cached: Tensor | None = None + self.sin_cached: Tensor | None = None - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + def forward( + self, seq_len: int, device: torch.device, dtype: torch.dtype, + ) -> tuple[Tensor, Tensor]: if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device + self.cos_cached is None + or self.sin_cached is None + or self.seq_len_cached != seq_len + or self.cos_cached.device != device ): t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + self.cos_cached = freqs.cos()[None, None, :, :] + self.sin_cached = freqs.sin()[None, None, :, :] + self.seq_len_cached = seq_len + return self.cos_cached.to(dtype=dtype), self.sin_cached.to(dtype=dtype) def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + rd = cos.size(-1) * 2 + x_rope = x[..., :rd] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x[..., rd:]), dim=-1) class CausalSelfAttention(nn.Module): @@ -560,57 +567,83 @@ def __init__( 
num_kv_heads: int, rope_base: float, qk_gain_init: float, - ): + use_xsa: bool = False, + rope_dims: int = 0, + ) -> None: super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") + assert dim % num_heads == 0 and num_heads % num_kv_heads == 0 self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") + self._use_xsa = use_xsa + assert self.head_dim % 2 == 0 kv_dim = self.num_kv_heads * self.head_dim self.c_q = CastedLinear(dim, dim, bias=False) self.c_k = CastedLinear(dim, kv_dim, bias=False) self.c_v = CastedLinear(dim, kv_dim, bias=False) self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) + self.proj.zero_init = True + self.proj.output_proj = True + self.q_gain = nn.Parameter(torch.full( + (num_heads,), qk_gain_init, dtype=torch.float32, + )) + self.rotary = Rotary( + self.head_dim, base=rope_base, rope_dims=rope_dims, + ) + self.attn_gate = CastedLinear(dim, num_heads, bias=False) - def forward(self, x: Tensor) -> Tensor: + def forward( + self, x: Tensor, ve_out: Tensor | None = None, + ) -> Tensor: bsz, seqlen, dim = x.shape q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + if ve_out is not None: + v = v + ve_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + v_orig = v + if _fa3_func is not None: + qt = q.transpose(1, 2) + kt = k.transpose(1, 2) + vt = v.transpose(1, 2) + y = _fa3_func(qt, kt, vt, causal=True).transpose(1, 2) + else: + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(rep, dim=1) + v = v.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + gate = torch.sigmoid(self.attn_gate(x)).transpose(1, 2).unsqueeze(-1) + y = y * gate + if self._use_xsa: + y = y.transpose(1, 2) # [B, S, H, D] + v_t = v_orig.transpose(1, 2) # [B, S, Hkv, D] + group_size = self.num_heads // self.num_kv_heads + y_grouped = y.reshape(bsz, seqlen, self.num_kv_heads, group_size, self.head_dim) + vn = F.normalize(v_t, dim=-1).unsqueeze(3) # [B, S, Hkv, 1, D] + dot = (y_grouped * vn).sum(-1, keepdim=True) + y = (y_grouped - dot * vn).reshape(bsz, seqlen, -1) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): + def __init__( + self, dim: int, mlp_mult: int, mlp_hidden: int = 0, + ) -> None: super().__init__() - hidden = 
mlp_mult * dim + hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim self.fc = CastedLinear(dim, hidden, bias=False) self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True + self.proj.zero_init = True + self.proj.output_proj = True def forward(self, x: Tensor) -> Tensor: x = torch.relu(self.fc(x)) @@ -626,22 +659,42 @@ def __init__( mlp_mult: int, rope_base: float, qk_gain_init: float, - ): + use_xsa: bool = False, + mlp_hidden: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + rope_dims: int = 0, + ) -> None: super().__init__() + self.ln_scale_factor = ( + 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + ) self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, + qk_gain_init, use_xsa=use_xsa, rope_dims=rope_dims, + ) + self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) + self.attn_scale = nn.Parameter( + torch.ones(dim, dtype=torch.float32), + ) + self.mlp_scale = nn.Parameter( + torch.ones(dim, dtype=torch.float32), + ) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float(), + ) - def forward(self, x: Tensor, x0: Tensor) -> Tensor: + def forward( + self, x: Tensor, x0: Tensor, ve_out: Tensor | None = None, + ) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s, ve_out=ve_out) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) return x @@ -654,79 +707,194 @@ def __init__( num_heads: int, num_kv_heads: int, mlp_mult: int, - tie_embeddings: bool, tied_embed_init_std: float, logit_softcap: float, rope_base: float, qk_gain_init: float, - ): + ) -> None: super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings + assert logit_softcap > 0.0 self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap + self.num_layers_actual = num_layers self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHash(2048, model_dim, inner_dim=128) + self.smear_gate = nn.Parameter(torch.zeros(model_dim)) self.num_encoder_layers = num_layers // 2 self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] + n_skip = min(self.num_encoder_layers, self.num_decoder_layers) + self.num_skip_weights = n_skip + self.skip_weights = nn.Parameter( + torch.ones(n_skip, model_dim, dtype=torch.float32), + ) + kv_dim = num_kv_heads * (model_dim // num_heads) + ve_dim = 
int(os.environ.get("VE_DIM", 128)) + ve_str = os.environ.get("VE_LAYERS", "9,10") + self.ve_layer_set = set( + int(x) for x in ve_str.split(",") if x ) + if self.ve_layer_set: + self.ve_shared = ValueEmbedding( + vocab_size, ve_dim, kv_dim, + ) + self.ve_layer_scales = nn.ParameterList([ + nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + for _ in self.ve_layer_set + ]) + n_xsa = int(os.environ.get("XSA_LAYERS", 3)) + xsa_set = ( + set(range(num_layers - n_xsa, num_layers)) + if n_xsa > 0 else set() + ) + mlp_h = int(os.environ.get("MLP_HIDDEN", 0)) + ln_sc = bool(int(os.environ.get("LN_SCALE", "1"))) + rope_d = int(os.environ.get("ROPE_DIMS", 16)) + self.blocks = nn.ModuleList([ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, + rope_base, qk_gain_init, + use_xsa=(i in xsa_set), + mlp_hidden=mlp_h, layer_idx=i, + ln_scale=ln_sc, rope_dims=rope_d, + ) + for i in range(num_layers) + ]) self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True self._init_weights() def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) + if isinstance(module, nn.Linear): + if getattr(module, "zero_init", False): + nn.init.zeros_(module.weight) + else: + nn.init.orthogonal_(module.weight, gain=1.0) + if getattr(module, "output_proj", False): + with torch.no_grad(): + module.weight.mul_(1.0 / (2 * self.num_layers_actual) ** 0.5) + + def _ve_for_layer( + self, + layer: int, + ve_base: Tensor | None, + ve_sorted: list[int], + dtype: torch.dtype, + ) -> Tensor | None: + if ve_base is None or layer not in self.ve_layer_set: + return None + idx = ve_sorted.index(layer) + scale = self.ve_layer_scales[idx].to(dtype=dtype) + return ve_base * scale + + def _embed(self, input_ids: Tensor) -> tuple[Tensor, Tensor]: + x = self.tok_emb(input_ids) + self.bigram(input_ids) x = F.rms_norm(x, (x.size(-1),)) - x0 = x + g = torch.sigmoid(self.smear_gate.to(x.dtype)) + x = x + g * (F.pad(x[:, :-1], (0, 0, 1, 0)) - x) + return x, x + + def _run_blocks( + self, x: Tensor, x0: Tensor, input_ids: Tensor, + ) -> Tensor: + ve_base = ( + self.ve_shared(input_ids) if self.ve_layer_set else None + ) + ve_sorted = sorted(self.ve_layer_set) if self.ve_layer_set else [] skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. 
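+        # U-Net pattern: the encoder half pushes activations onto `skips`
+        # and the decoder half pops them in reverse, scaled by learned
+        # per-channel skip_weights; shared value embeddings are injected on
+        # the layers in ve_layer_set.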
for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) + ve = self._ve_for_layer(i, ve_base, ve_sorted, x.dtype) + x = self.blocks[i](x, x0, ve_out=ve) skips.append(x) for i in range(self.num_decoder_layers): + li = self.num_encoder_layers + i if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) + sw = self.skip_weights[i].to(dtype=x.dtype) + x = x + sw[None, None, :] * skips.pop() + ve = self._ve_for_layer(li, ve_base, ve_sorted, x.dtype) + x = self.blocks[li](x, x0, ve_out=ve) + return self.final_norm(x) + + def forward( + self, input_ids: Tensor, target_ids: Tensor, + ) -> Tensor: + x, x0 = self._embed(input_ids) + x = self._run_blocks(x, x0, input_ids) + x = x.reshape(-1, x.size(-1)) + w = self.tok_emb.weight.to(x.dtype) + logits = self.logit_softcap * torch.tanh( + F.linear(x, w) / self.logit_softcap, + ) + return F.cross_entropy(logits.float(), target_ids.reshape(-1)) + + def per_token_loss( + self, input_ids: Tensor, target_ids: Tensor, + ) -> Tensor: + x, x0 = self._embed(input_ids) + x = self._run_blocks(x, x0, input_ids) + w = self.tok_emb.weight.to(x.dtype) + logits = self.logit_softcap * torch.tanh( + F.linear(x, w) / self.logit_softcap, + ) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + target_ids.reshape(-1), + reduction="none", + ) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") +def eval_val_sliding( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + bytes_lut: Tensor, + space_lut: Tensor, + boundary_lut: Tensor, + stride: int = 64, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + all_starts = list(range(0, total_tokens - seq_len + 1, stride)) + rank_starts = all_starts[rank::world_size] + batch_size = max(1, min(args.val_batch_size // seq_len, 64)) + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + with torch.inference_mode(): + for i in range(0, len(rank_starts), batch_size): + batch_starts = rank_starts[i:i + batch_size] + bsz = len(batch_starts) + x = torch.stack( + [val_tokens[s:s + seq_len] for s in batch_starts], + ).to(device, torch.int64) + y = torch.stack( + [val_tokens[s + 1:s + seq_len + 1] for s in batch_starts], + ).to(device, torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + losses = base_model.per_token_loss(x, y).view(bsz, seq_len) + for j, s in enumerate(batch_starts): + score_from = 0 if s == 0 else (seq_len - stride) + scored = losses[j, score_from:] + val_loss_sum += scored.to(torch.float64).sum() + val_token_count += scored.numel() + tgt = y[j, score_from:] + prev = x[j, score_from:] + tbytes = bytes_lut[tgt].to(torch.int16) + tbytes += (space_lut[tgt] & ~boundary_lut[prev]).to(torch.int16) + val_byte_count += tbytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + 
dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) -# ----------------------------- -# TRAINING -# ----------------------------- def main() -> None: global zeropower_via_newtonschulz5 @@ -735,93 +903,73 @@ def main() -> None: args = Hyperparameters() zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") + assert world_size > 0 and 8 % world_size == 0, f"Invalid WORLD_SIZE={world_size}" + ga_steps = 8 // world_size + grad_scale = 1.0 / ga_steps + assert torch.cuda.is_available(), "CUDA required" device = torch.device("cuda", local_rank) torch.cuda.set_device(device) if distributed: dist.init_process_group(backend="nccl", device_id=device) dist.barrier() - master_process = rank == 0 + is_main = rank == 0 - # Fast math knobs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) + from torch.backends.cuda import ( + enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp, + ) + use_cudnn = bool(int(os.environ.get("USE_CUDNN_SDPA", "1"))) + enable_cudnn_sdp(use_cudnn) + enable_flash_sdp(not use_cudnn) enable_mem_efficient_sdp(False) enable_math_sdp(False) logfile = None - if master_process: + if is_main: os.makedirs("logs", exist_ok=True) logfile = f"logs/{args.run_id}.txt" print(logfile) def log0(msg: str, console: bool = True) -> None: - if not master_process: + if not is_main: return if console: - print(msg) + print(msg, flush=True) if logfile is not None: with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) log0(code, console=False) log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- + log0(f"Python {sys.version} PyTorch {torch.__version__}", console=False) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + assert args.tokenizer_path.endswith(".model"), 
"Requires .model tokenizer" sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) + assert int(sp.vocab_size()) == args.vocab_size, ( + f"Vocab mismatch: {args.vocab_size} vs {sp.vocab_size()}" + ) dataset_dir = Path(args.data_path).resolve() actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + bytes_lut, space_lut, boundary_lut = build_sentencepiece_luts( sp, args.vocab_size, device ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- + n_val = val_tokens.numel() - 1 + log0( + f"data: {dataset_dir.name} " + f"train_shards:{actual_train_files} val_tokens:{n_val}" + ) base_model = GPT( vocab_size=args.vocab_size, @@ -830,7 +978,6 @@ def log0(msg: str, console: bool = True) -> None: num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, logit_softcap=args.logit_softcap, rope_base=args.rope_base, @@ -840,78 +987,83 @@ def log0(msg: str, console: bool = True) -> None: if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ + attn_backend = "FA3" if _fa3_func is not None else "SDPA" + log0( + f"torch.compile: fullgraph=True " + f"attention={attn_backend}+XSA mlp=relu2_3x" + ) + cmodel = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = ( + DDP(cmodel, device_ids=[local_rank], broadcast_buffers=False) + if distributed else cmodel + ) + + ema_state: dict[str, Tensor] | None = None + if args.ema_decay > 0: + ema_state = { + n: t.detach().float().clone() + for n, t in base_model.state_dict().items() + } + log0(f"EMA: initialized (decay={args.ema_decay})") + + bnp = list(base_model.blocks.named_parameters()) + mat_params = [ p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + for name, p in bnp + if p.ndim == 2 and not any(pattern in name for pattern in CTRL_PATTERNS) ] - scalar_params = [ + sc_params = [ p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + for name, p in bnp + if p.ndim < 2 or any(pattern in name for pattern in CTRL_PATTERNS) ] if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = 
torch.optim.Adam( + sc_params.append(base_model.skip_weights) + sc_params.append(base_model.smear_gate) + for p in base_model.bigram.parameters(): + sc_params.append(p) + if base_model.ve_layer_set: + for p in base_model.ve_shared.parameters(): + sc_params.append(p) + for p in base_model.ve_layer_scales.parameters(): + sc_params.append(p) + token_lr = args.tied_embed_lr + opt_tok = torch.optim.AdamW( [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_weight_decay, fused=True, ) - optimizer_muon = Muon( - matrix_params, + opt_muon = Muon( + mat_params, lr=args.matrix_lr, - momentum=args.muon_momentum, + momentum=args.mu_mom, backend_steps=args.muon_backend_steps, + weight_decay=args.muon_weight_decay, ) - for group in optimizer_muon.param_groups: + for group in opt_muon.param_groups: group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + opt_scalar = torch.optim.AdamW( + [{"params": sc_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_weight_decay, fused=True, ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) + optimizers: list[torch.optim.Optimizer] = [opt_tok, opt_muon, opt_scalar] n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0(f"params:{n_params} world:{world_size} accum:{ga_steps}") log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + f"lr: embed={token_lr} matrix={args.matrix_lr} " + f"scalar={args.scalar_lr}" ) log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + f"batch:{args.train_batch_tokens} seq:{args.train_seq_len} " + f"iters:{args.iterations} warmup:{args.warmup_steps} " + f"wall:{args.max_wallclock_seconds:.0f}s" ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) @@ -919,109 +1071,139 @@ def zero_grad_all() -> None: for opt in optimizers: opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + wall_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: + if wd_iters <= 0: + return 1.0 + if wall_ms is None: + wd_start = max(args.iterations - wd_iters, 0) + if wd_start <= step < args.iterations: + progress = 1.0 - (args.iterations - 
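+                # progress runs 0 -> 1 across the warmdown window; the cosine
+                # branch maps it to a smooth 1 -> 0 multiplier (e.g. progress=0.5
+                # gives 0.5 * (1 + cos(pi/2)) = 0.5), while the linear branch
+                # decays straight to zero.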
step) / max(wd_iters, 1) + if args.cosine_warmdown: + return 0.5 * (1.0 + math.cos(math.pi * progress)) + return 1.0 - progress return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + wd_ms = wd_iters * step_ms + rem_ms = max(wall_ms - elapsed_ms, 0.0) + if rem_ms <= wd_ms: + linear = rem_ms / max(wd_ms, 1e-9) + if args.cosine_warmdown: + progress = 1.0 - linear + return 0.5 * (1.0 + math.cos(math.pi * progress)) + return linear + return 1.0 + + wd_iters = args.warmdown_iters - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + init_sd = { + n: t.detach().cpu().clone() + for n, t in base_model.state_dict().items() + } + init_opts = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] model.train() + torch.cuda.synchronize() for warmup_step in range(args.warmup_steps): zero_grad_all() - for micro_step in range(grad_accum_steps): + for micro_step in range(ga_steps): if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + model.require_backward_grad_sync = micro_step == ga_steps - 1 + x, y = train_loader.next_batch( + args.train_batch_tokens, args.train_seq_len, ga_steps, + ) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): warmup_loss = model(x, y) (warmup_loss * grad_scale).backward() for opt in optimizers: opt.step() zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + should_log_warmup = ( + args.warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == args.warmup_steps + ) + if should_log_warmup: log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + torch.cuda.synchronize() + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if name in init_sd: + param.data.copy_(init_sd[name].to(param.device, dtype=param.dtype)) + for name, buf in base_model.named_buffers(): + if name in init_sd: + buf.data.copy_(init_sd[name].to(buf.device, dtype=buf.dtype)) + for opt, state in zip(optimizers, init_opts, strict=True): opt.load_state_dict(state) zero_grad_all() if distributed: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - training_time_ms = 0.0 - stop_after_step: int | None = None + train_ms = 0.0 + stop_step: int | None = None torch.cuda.synchronize() t0 = time.perf_counter() step = 0 while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= 
stop_after_step) + last_step = step == args.iterations or (stop_step is not None and step >= stop_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + should_validate = ( + last_step + or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + ) if should_validate: torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) + train_ms += 1000.0 * (time.perf_counter() - t0) val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, + args, model, rank, world_size, device, ga_steps, + val_tokens, bytes_lut, space_lut, boundary_lut, ) + for mod in base_model.modules(): + if isinstance(mod, Rotary): + mod.seq_len_cached = 0 log0( f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + f"train_time:{train_ms:.0f}ms step_avg:{train_ms / max(step, 1):.2f}ms" ) torch.cuda.synchronize() t0 = time.perf_counter() if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) + if stop_step is not None and step < args.iterations: + log0(f"stopping_early: step:{step}/{args.iterations} time:{train_ms:.0f}ms") break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + elapsed_ms = train_ms + 1000.0 * (time.perf_counter() - t0) scale = lr_mul(step, elapsed_ms) zero_grad_all() train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): + for micro_step in range(ga_steps): if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + model.require_backward_grad_sync = micro_step == ga_steps - 1 + x, y = train_loader.next_batch( + args.train_batch_tokens, args.train_seq_len, ga_steps, + ) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): loss = model(x, y) train_loss += loss.detach() (loss * grad_scale).backward() - train_loss /= grad_accum_steps + train_loss /= ga_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum + frac = ( + min(step / args.mu_mom_warmup_steps, 1.0) + if args.mu_mom_warmup_steps > 0 else 1.0 + ) + mu_mom = ( + (1 - frac) * args.mu_mom_warmup_start + frac * args.mu_mom + ) + for group in opt_muon.param_groups: + group["momentum"] = mu_mom for opt in optimizers: for group in opt.param_groups: @@ -1033,90 +1215,147 @@ def lr_mul(step: int, elapsed_ms: float) -> float: opt.step() zero_grad_all() + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + elapsed = train_ms + 1000.0 * (time.perf_counter() - t0) should_log_train = ( args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + and (step <= 10 or step % args.train_log_every == 0 or stop_step 
is not None) ) if should_log_train: log0( f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + f"train_time:{elapsed:.0f}ms step_avg:{elapsed / step:.2f}ms" ) - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: + reached_cap = wall_ms is not None and elapsed >= wall_ms + if distributed and wall_ms is not None: reached_cap_tensor = torch.tensor(int(reached_cap), device=device) dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step + if stop_step is None and reached_cap: + stop_step = step + + alloc_mb = torch.cuda.max_memory_allocated() // 1024 // 1024 + resv_mb = torch.cuda.max_memory_reserved() // 1024 // 1024 + log0(f"peak_mem: {alloc_mb}MiB alloc {resv_mb}MiB reserved") + + if ema_state is not None: + log0(f"EMA: loading averaged weights (decay={args.ema_decay})") + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if name in ema_state: + param.data.copy_(ema_state[name].to(param.device, dtype=param.dtype)) + for name, buf in base_model.named_buffers(): + if name in ema_state: + buf.data.copy_(ema_state[name].to(buf.device, dtype=buf.dtype)) + ema_state = None + if distributed: + dist.barrier() - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + torch.cuda.synchronize() + pre_q_loss, pre_q_bpb = eval_val( + args, model, rank, world_size, device, ga_steps, + val_tokens, bytes_lut, space_lut, boundary_lut, ) + log0(f"pre_quant val_loss:{pre_q_loss:.4f} val_bpb:{pre_q_bpb:.4f}") + if pre_q_bpb > args.best_known_bpb + args.regression_threshold: + log0( + f"REGRESSION WARNING: pre_quant {pre_q_bpb:.4f} > " + f"best_known {args.best_known_bpb:.4f} + " + f"threshold {args.regression_threshold}" + ) - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. 
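The hunk below replaces the old raw-checkpoint plus int8+zlib path with the int6 pipeline: `quantize_state_dict_int6` / `dequantize_state_dict_int6` plus the `_compress` / `_decompress` helpers defined earlier in this file. As a minimal sketch of what a per-row int6 round-trip involves, assuming plain abs-max scaling (the names and the tolerance check here are illustrative, not the patch's actual helpers):

```python
# Illustrative only -- not the patch's quantize_state_dict_int6.
# Per-row abs-max scaling into the signed 6-bit range [-31, 31].
import torch

def int6_roundtrip(w: torch.Tensor) -> torch.Tensor:
    scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-12) / 31.0
    q = torch.round(w / scale).clamp_(-31, 31).to(torch.int8)  # int8 container
    return q.float() * scale  # dequantized weights

w = torch.randn(512, 512)
err = (int6_roundtrip(w) - w).abs().max()
# rounding error per element is at most half a quantization step
assert err <= (w.abs().amax(dim=1) / 31.0 / 2 + 1e-6).max()
```

The real pipeline then torch-saves the quantized object and compresses the serialized payload via `_compress` before writing `final_model.int6.{_COMPRESS_EXT}`.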
- - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_obj, _ = quantize_state_dict_int6(base_model.state_dict()) + artifact_name = f"final_model.int6.{_COMPRESS_EXT}" quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: + quant_blob = _compress(quant_buf.getvalue()) + qf_bytes = 0 + if is_main: + with open(artifact_name, "wb") as f: f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") + qf_bytes = os.path.getsize(artifact_name) code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + total = qf_bytes + code_bytes + payload_bytes = len(quant_buf.getvalue()) + ratio = payload_bytes / qf_bytes if qf_bytes > 0 else 0 log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + f"artifact: {qf_bytes}B model + {code_bytes}B code " + f"= {total}B (compression:{ratio:.2f}x)" ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + if ratio < 1.6: + log0( + f"COMPRESSION ALERT: {ratio:.2f}x below 1.6x " + f"-- weight distributions may be degraded" + ) + if total > 16_000_000: + log0(f"FATAL: exceeds 16MB by {total - 16_000_000}B") + if distributed: + dist.destroy_process_group() + sys.exit(0) + else: + log0(f"headroom: {16_000_000 - total}B") if distributed: dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + with open(artifact_name, "rb") as f: + qblob = f.read() + raw = _decompress(qblob) + quant_state = torch.load(io.BytesIO(raw), map_location="cpu") + dequant_sd = dequantize_state_dict_int6(quant_state) + model_sd = base_model.state_dict() + missing = [k for k in model_sd if k not in dequant_sd] + if missing: + for k in missing: + dequant_sd[k] = model_sd[k].cpu() + base_model.load_state_dict(dequant_sd, strict=False) torch.cuda.synchronize() t_qeval = time.perf_counter() q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, + args, model, rank, world_size, device, ga_steps, + val_tokens, bytes_lut, space_lut, boundary_lut, ) torch.cuda.synchronize() + q_eval_ms = 1000.0 * (time.perf_counter() - t_qeval) log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + f"final_int6_{_COMPRESS_EXT}_roundtrip " + f"val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{q_eval_ms:.0f}ms" ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + sw_val_bpb = q_val_bpb + if args.eval_stride < args.train_seq_len: + 
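+        # Sliding re-scoring only pays for itself when windows overlap; with
+        # stride == train_seq_len the windows tile the validation stream and
+        # this pass should match the chunked eval above, so it is skipped.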
torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, bytes_lut, space_lut, boundary_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + sw_ms = 1000.0 * (time.perf_counter() - t_sw) + log0( + f"final_int6_{_COMPRESS_EXT}_sliding " + f"val_loss:{sw_val_loss:.4f} " + f"val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{sw_ms:.0f}ms" + ) + + if is_main: + avg_ms = train_ms / max(step, 1) + log0( + f"RUN_SUMMARY: steps={step} " + f"pre_quant_bpb={pre_q_bpb:.4f} " + f"post_quant_bpb={q_val_bpb:.4f} " + f"sliding_bpb={sw_val_bpb:.4f} " + f"artifact_bytes={qf_bytes} step_ms={avg_ms:.1f}" + ) if distributed: dist.destroy_process_group()
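A closing note on the EMA path exercised above: `ema_state` is seeded from the model's state_dict, updated every step with `mul_(d).add_(t, alpha=1-d)`, and copied back into the live parameters before the pre-quantization eval. A self-contained illustration of that recurrence (the toy loop and names are hypothetical):

```python
# Minimal EMA-of-weights sketch; mirrors the patch's update rule but is not its code.
import torch

decay = 0.997  # example value; averaging horizon of roughly 1 / (1 - decay) ~ 333 steps
params = {"w": torch.randn(4, 4)}
ema = {k: v.clone() for k, v in params.items()}

for _ in range(100):                              # stand-in for training steps
    params["w"] += 0.01 * torch.randn(4, 4)       # stand-in for an optimizer step
    for k, v in params.items():
        ema[k].mul_(decay).add_(v, alpha=1.0 - decay)

# before evaluation/serialization, swap in the averaged weights
params["w"].copy_(ema["w"])
```

Because the copy-back happens before `quantize_state_dict_int6`, the artifact and every reported metric (pre-quant, round-trip, sliding) are measured on the averaged weights.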