diff --git a/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/README.md b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/README.md
new file mode 100644
index 0000000000..08c71725cc
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/README.md
@@ -0,0 +1,111 @@
+# 11L Low-Rank Q192
+
+## Results
+
+| Seed | Steps | step\_avg | val\_bpb (sliding) | Artifact Size |
+|------|-------|-----------|--------------------|---------------|
+| 1337 | 7732 | 77.6ms | 1.1548 | 14,747,273 |
+| 42 | 7821 | 77.1ms | 1.1552 | 14,939,593 |
+| 113 | 7597 | 79.0ms | 1.1575 | 14,676,072 |
+| **Mean** | — | — | **1.1558** | — |
+
+All runs use a clean compile cache (`rm -rf ~/.cache/torch/inductor_cache/`) and zstd-22 compression.
+
+**vs official record** (SlidingWindow\_FP16Emb\_10L\_MuonWD\_OvertoneInit, val\_bpb=1.1748):
+- Improvement: **-0.0190 bpb** / **-0.031 nats**
+- One-sided t-test: t≈22, df=2, **p < 0.001**
+
+## How to run
+
+```bash
+NUM_LAYERS=11 WEIGHT_DECAY=0.038 SEED=1337 torchrun --nproc_per_node=8 train_gpt.py
+```
+
+## Key Techniques
+
+1. **Low-rank Q factorization (r=192)**: the Q projection is factored as `c_q_down (512->192)` + `c_q_up (192->512)`. Q matrices have extreme condition numbers (100M+) and effective rank 89-114 out of 192, confirming that rank 192 is sufficient. The factored representation is also more int6-quantization-friendly: low-rank structure compresses better, reducing the fp32-to-int6 gap.
+
+2. **11 transformer layers** with encoder-decoder skip connections (5 encoder + 6 decoder), spending the parameter savings from low-rank Q.
+
+3. **Int6 per-row quantization + zstd-22 compression** for MLP and attention weights. Scalar/control parameters stay in float (they are never int-quantized).
+
+4. **Sliding window evaluation** (stride=64) for the final score.
+
+## Architecture
+
+- 11 layers, model\_dim=512, 8 attention heads, 4 KV heads (GQA)
+- MLP 3x (hidden=1536), relu-squared activation
+- Low-rank Q: `c_q_down (512->192)` + `c_q_up (192->512)` per layer
+- Tied embeddings, vocab=1024, logit\_softcap=30
+- RoPE (base=10000), RMSNorm, residual mixing
+
+## Motivation
+
+The standard Q projection uses a full 512x512 matrix, but the Muon optimizer pushes all singular values toward 1, while Q naturally wants to operate in a lower-rank subspace (effective rank 89-114/192). Factoring Q to rank 192 makes this structure explicit, which has two benefits: (1) the model trains at the same per-step quality, and (2) the factored weights quantize significantly better under int6 per-row quantization because low-rank matrices have higher information density per element.
+
+## Experiments That Didn't Work
+
+The following ideas were explored through weight analysis and experiments but did not improve val\_bpb:
+
+**1. Legendre resid\_mix initialization**
+
+After training, the `resid_mix` parameters (mix0 and mix1) show a clear depth-dependent pattern: mix0 has a Block-0 outlier with a U-shape for deeper layers, and mix1 follows a strong U-shape with negative values in the middle (embedding subtraction). These patterns fit well to 4th-order Legendre polynomials. I tried initializing resid\_mix with integer Legendre coefficients at half-scale (interpolating between the standard init and the Legendre target) to give the optimizer a warm start; see the sketch below. However, experiments showed no measurable improvement — Adam converges the scalar resid\_mix parameters to their targets within ~200 steps regardless of initialization. The Legendre shape is correct, but the optimizer doesn't need help finding it.
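+A minimal sketch of the tested warm start (the coefficient values below are hypothetical placeholders; the real integer coefficients came from fitting the trained runs):
+
+```python
+import numpy as np
+import torch
+from numpy.polynomial import legendre
+
+def legendre_resid_mix_init(num_layers, dim, coeffs_mix0, coeffs_mix1, alpha=0.5):
+    # Evaluate the fitted 4th-order Legendre series over normalized depth in [-1, 1].
+    depth = np.linspace(-1.0, 1.0, num_layers)
+    target0 = legendre.legval(depth, coeffs_mix0)  # fitted per-layer mix0 profile
+    target1 = legendre.legval(depth, coeffs_mix1)  # fitted per-layer mix1 profile
+    inits = []
+    for i in range(num_layers):
+        # Standard init is mix0=1, mix1=0; alpha=0.5 moves halfway toward the
+        # Legendre target (the "half-scale" warm start described above).
+        mix0 = (1 - alpha) * 1.0 + alpha * float(target0[i])
+        mix1 = (1 - alpha) * 0.0 + alpha * float(target1[i])
+        inits.append(torch.stack((torch.full((dim,), mix0), torch.full((dim,), mix1))))
+    return inits  # one (2, dim) tensor per Block.resid_mix
+
+# Hypothetical P0..P4 coefficients, for illustration only:
+mixes = legendre_resid_mix_init(11, 512, [1, 0, 0, 0, 1], [0, -2, 0, 0, 3])
+```
+
+**2. 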
Content-Dependent Pre-Rotation on Block 0's MLP** + +I introduced a content-dependent 2D rotation before each MLP's first linear layer (fully-connected, 512→1536): a small projection (`angle_proj`, 512→32, zero-init) computes 32 rotation angles from the input, then rotates 32 pairs of input dimensions before feeding into the fully-connected layer. This adds SwiGLU-like content-dependent feature mixing at only 1% parameter cost (16K params per layer) and without sacrificing MLP width — unlike SwiGLU which requires a third gate matrix (50% more MLP params, forcing MLP 3x down to 2x in a 16MB budget). + +Experiments confirmed the rotation is genuinely useful: the model learned large rotation angles (~146°, not small perturbations), used near-full effective rank (29-30/32 pairs independently active), and achieved higher per-step quality in later training compared to the baseline. Further analysis showed Block 0 learned the strongest rotation (row\_norm=2.55, focusing on raw embedding dimensions 359, 156, 423...) while deeper layers learned weaker rotation (row\_norm~1.1, sharing contextual dimensions 4, 12, 23...), suggesting Block 0 benefits most. However, `torch.compile` generates separate kernels for the rotation operations (cos/sin + concatenation), adding ~9 seconds of fixed compilation overhead — theoretically ~2ms of compute inflated to 9ms by graph-level inefficiency. I also tested a Block-0-only variant to reduce overhead, but the fixed compilation cost remained. In a 600-second budget, this overhead costs ~100 training steps, which negated the per-step quality gain. + +This remains a promising direction: content-dependent rotation provides norm-preserving, information-lossless feature mixing (det(R)=1) as a near-free alternative to gating mechanisms in parameter-constrained settings. The bottleneck is purely at the operator compilation level, not the method itself. + +**3. Depth-attention residual (AttnRes) architecture** + +Inspired by Moonshot AI's [Attention Residuals](https://arxiv.org/abs/2603.15031), I explored replacing the standard residual stream with a depth-attention mechanism: each layer's input is `emb + depth_attn(δ₀..δ_{i-1})` where `depth_attn` uses learned position bias (Legendre polynomials) and content-based routing over all previous layers' outputs. The motivation was (a) selective delta combination for better gradient flow, and (b) quantization error suppression (softmax weights sum to 1, reducing error accumulation). + +However, attention residual turns out to be counterproductive in small, dense models like this one. In Kimi-K2's MoE setting, attention residual helps route across sparsely-activated experts. In our dense 512-dim model, it actually suppresses the residual stream: softmax constrains routing weights to be non-negative and sum-to-1, but weight analysis of the baseline's `resid_mix` revealed the optimal depth routing requires negative weights (middle layers subtract embedding with mix1 ≈ -4) and non-normalized weights — patterns that softmax fundamentally cannot express. Additionally, unnormalized Values in depth attention caused block polarization where only 3 of 9 blocks remained active (67% of parameters wasted). The simple `resid_mix` mechanism (`h = mix0 * x + mix1 * x0` with unconstrained per-dim scalars) is strictly more expressive and naturally achieves the same quantization error reduction (Σ A\_i² = 2.20, 24% of standard residual) without any architectural overhead. 
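+To make the expressiveness gap concrete, here is a toy illustration (not the experiment code; the numbers are made up) of the constraint that sinks softmax routing in this dense setting:
+
+```python
+import torch
+
+x, x0 = torch.randn(512), torch.randn(512)   # current stream and raw embedding
+
+# resid_mix: unconstrained scalars (per-dim in the real model); e.g. mix1 ~ -4
+# in middle layers.
+h_resid_mix = 1.3 * x + (-4.0) * x0          # negative and non-normalized: allowed
+
+# AttnRes-style softmax routing over the same two streams:
+w = torch.softmax(torch.tensor([2.0, -1.0]), dim=0)
+h_softmax = w[0] * x + w[1] * x0             # w >= 0 and w.sum() == 1, so this can
+                                             # never subtract x0 or upscale the sum
+```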
+
+
+## Future Directions
+
+**Lower-rank Q (r=128 or adaptive per-layer rank)**: r=128 captured 94-97% of the spectral energy but degraded quality past the acceptable threshold. An adaptive scheme — wider rank in deep layers (where attn\_scale peaks) and narrower in shallow layers — could push further.
+
+**Better compilation for Pre-Rotation**: The content-dependent rotation achieved higher per-step quality but lost to compilation overhead. The core implementation is minimal:
+
+```python
+class MLP(nn.Module):
+    def __init__(self, dim, mlp_mult, n_rot_pairs=32):
+        ...
+        self.angle_proj = CastedLinear(dim, n_rot_pairs, bias=False)  # zero-init
+
+    def forward(self, x):
+        r = self.n_rot_pairs
+        angles = self.angle_proj(x)  # [B, T, 32] content-dependent angles
+        cos_a, sin_a = angles.cos(), angles.sin()
+        x1, x2 = x[..., :r], x[..., r:2*r]
+        x = torch.cat([
+            x1 * cos_a + x2 * sin_a,   # first component of each rotated pair
+            -x1 * sin_a + x2 * cos_a,  # second component of each rotated pair
+            x[..., 2*r:]               # unchanged dims
+        ], dim=-1)
+        return self.proj(torch.relu(self.fc(x)).square())
+```
+
+A custom Triton kernel fusing `angle_proj -> cos/sin -> rotation -> fc` into a single pass, or improvements in `torch.compile`'s handling of trigonometric operations within compiled graphs, would eliminate the 9-second overhead and make this technique viable. The method provides SwiGLU-like content-dependent feature mixing at 1% parameter cost with zero information loss (det(R)=1), making it particularly suited to parameter-constrained or inference-optimized settings.
+
+## Training Configuration
+
+```
+NUM_LAYERS=11
+WEIGHT_DECAY=0.038
+TRAIN_SEQ_LEN=1024
+TRAIN_BATCH_TOKENS=524288
+EVAL_STRIDE=64
+MLP_MULT=3.0
+MODEL_DIM=512
+MATRIX_LR=0.02
+SCALAR_LR=0.02
+TIED_EMBED_LR=0.03
+MUON_MOMENTUM=0.99
+WARMDOWN_ITERS=3000
+```
diff --git a/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/submission.json b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/submission.json
new file mode 100644
index 0000000000..7684ea7c15
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/submission.json
@@ -0,0 +1,16 @@
+{
+  "author": "JayCheng113",
+  "github_id": "JayCheng113",
+  "name": "11L Low-Rank Q192",
+  "blurb": "11-layer transformer with factored Q projection (rank 192). Low-rank Q structure improves int6 quantization friendliness. Mean sliding val_bpb=1.1558 across 3 seeds.",
+  "date": "2026-03-20",
+  "val_loss": 1.9498,
+  "val_bpb": 1.1548,
+  "pre_quant_val_loss": 1.9913,
+  "pre_quant_val_bpb": 1.1793,
+  "step_stop": 7732,
+  "wallclock_seconds": 600,
+  "bytes_total": 14747213,
+  "bytes_model_int6_zstd": 14688652,
+  "bytes_code": 58561
+}
diff --git a/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_gpt.py b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_gpt.py
new file mode 100644
index 0000000000..e0c64b319f
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_gpt.py
@@ -0,0 +1,1380 @@
+"""
+train_gpt.py — Submission: wider MLP + low-rank Q + int6 quantization + sliding window eval.
+Changes from baseline: low-rank Q (r=192), MLP_MULT=3.0, int6 per-row on MLP+attn weights,
+zstd-22 compression, sliding window eval at stride=64 with batched forward_logits.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard +except ImportError: + import subprocess, sys + subprocess.check_call([sys.executable, "-m", "pip", "install", "zstandard"]) + import zstandard + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.environ.get("TRAIN_FILES", os.path.join(data_path, "fineweb_train_*.bin")) + val_files = os.environ.get("VAL_FILES", os.path.join(data_path, "fineweb_val_*.bin")) + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + # SWA: number of checkpoints to average during warmdown (0 = disabled) + swa_checkpoints = int(os.environ.get("SWA_CHECKPOINTS", 0)) + # Weight decay (applied to matrix params only) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
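+                    # Newton-Schulz yields a near-orthogonal update, so per-element
+                    # RMS shrinks for tall matrices; this factor keeps the update
+                    # scale roughly shape-independent (per the Muon reference).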
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            wd = group["weight_decay"]
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                if wd > 0:
+                    p.mul_(1.0 - lr * wd)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in the
+# embeddings, since the embedding and unembedding matrices (2 * d_model * d_vocab
+# parameters) can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our
+# validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods
+# to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since
+# screwing this up might unjustly improve your score.
+
+def build_sentencepiece_luts(
+    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
+) -> tuple[Tensor, Tensor, Tensor]:
+    sp_vocab_size = int(sp.vocab_size())
+    table_size = max(sp_vocab_size, vocab_size)
+    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
+    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
+    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
+    for token_id in range(sp_vocab_size):
+        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
+            continue
+        is_boundary_token_np[token_id] = False
+        if sp.is_byte(token_id):
+            base_bytes_np[token_id] = 1
+            continue
+        piece = sp.id_to_piece(token_id)
+        if piece.startswith("▁"):
+            has_leading_space_np[token_id] = True
+            piece = piece[1:]
+        base_bytes_np[token_id] = len(piece.encode("utf-8"))
+    return (
+        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
+        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
+        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
+    )
+
+
+def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
+    files = [Path(p) for p in sorted(glob.glob(pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {pattern}")
+    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
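+    # Shards are concatenated in sorted filename order; the extra trailing token
+    # below lets callers build shifted (x, y) pairs from one contiguous buffer.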
+    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
+    usable = ((tokens.numel() - 1) // seq_len) * seq_len
+    if usable <= 0:
+        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
+    return tokens[: usable + 1]
+
+
+def eval_val(
+    args: Hyperparameters,
+    model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    grad_accum_steps: int,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+) -> tuple[float, float]:
+    # Validation computes two metrics:
+    # - val_loss: token cross-entropy (natural log)
+    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
+    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
+    if local_batch_tokens < args.train_seq_len:
+        raise ValueError(
+            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
+            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
+            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
+        )
+    local_batch_seqs = local_batch_tokens // args.train_seq_len
+    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
+    seq_start = (total_seqs * rank) // world_size
+    seq_end = (total_seqs * (rank + 1)) // world_size
+    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
+    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
+
+    model.eval()
+    with torch.inference_mode():
+        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
+            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
+            raw_start = batch_seq_start * args.train_seq_len
+            raw_end = batch_seq_end * args.train_seq_len + 1
+            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
+            x = local[:-1].reshape(-1, args.train_seq_len)
+            y = local[1:].reshape(-1, args.train_seq_len)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it
+# to low-bit integers and compressing the result. We can then decompress the model and run
+# it at higher precision for evaluation, after coming in under the size limit.
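+# This submission's shipped artifact uses the int6 per-row + zstd-22 path (see
+# mixed_quantize_int6 below); the int8 helpers here are the reference export
+# format, kept as a fallback.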
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +MIXED_KEEP_FLOAT_PATTERNS = tuple( + pattern + for pattern in os.environ.get("MIXED_KEEP_FLOAT_PATTERNS", "").split(",") + if pattern +) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
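+        # 65,536 elements is ~128 KiB at fp16; with model_dim=512 this catches
+        # every vector and control parameter, leaving only the big matrices.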
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        q, s = quantize_float_tensor(t)
+        if s.ndim > 0:
+            qmeta[name] = {"scheme": "per_row", "axis": 0}
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            # Broadcast the saved row scale back across trailing dimensions.
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    # Assumed shard layout (standard for this pipeline): a 256-entry int32 header
+    # with the token count at index 2, followed by uint16 token ids.
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(2 * num_tokens), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Round-robins over shard files, exposing an endless stream of tokens.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
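+    # Tables are rebuilt when seq_len or device changes; dtype is cast per call,
+    # so bf16 and fp32 callers share one fp32 cache.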
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + q_rank = 192 # Low-rank Q factorization + self.c_q_down = CastedLinear(dim, q_rank, bias=False) + self.c_q_up = CastedLinear(q_rank, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q_up(self.c_q_down(x)).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
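+        # With 11 layers: encoder blocks 0-4 push their outputs; decoder block 5+i
+        # (i < 5) adds skip_weights[i] * (encoder block 4-i output); the deepest
+        # decoder block gets no skip.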
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else wlen - stride + scored_nll = nll[i, 
s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in MIXED_KEEP_FLOAT_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + 
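+    # Compile the Newton-Schulz orthogonalizer once up front; Muon invokes it for
+    # every 2-D parameter on every optimizer step.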
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + 
tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.weight_decay, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if 
max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + # SWA state: collect checkpoints during warmdown for averaging + swa_states: list[dict[str, Tensor]] = [] + swa_next_collect_idx = 0 # which SWA checkpoint to collect next + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = 
train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints at evenly-spaced intervals during warmdown + if args.swa_checkpoints > 0 and scale < 1.0 and swa_next_collect_idx < args.swa_checkpoints: + # Determine collection points: evenly spaced through warmdown + # scale goes from 1.0 -> 0.0 during warmdown; collect at scale thresholds + threshold = 1.0 - (swa_next_collect_idx + 1) / (args.swa_checkpoints + 1) + if scale <= threshold: + swa_states.append({k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()}) + swa_next_collect_idx += 1 + log0(f"SWA checkpoint {swa_next_collect_idx}/{args.swa_checkpoints} collected at step {step} (scale={scale:.4f})") + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
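+        # Without this all-reduce, ranks could disagree about stopping and then
+        # hang in the next collective op.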
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # SWA: average collected checkpoints with final model + if args.swa_checkpoints > 0 and len(swa_states) > 0: + log0(f"SWA: averaging {len(swa_states)} checkpoints with final model") + final_sd = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()} + all_states = swa_states + [final_sd] + n = len(all_states) + averaged_sd = {} + for key in final_sd: + stacked = torch.stack([s[key].float() for s in all_states]) + averaged_sd[key] = stacked.mean(dim=0).to(final_sd[key].dtype) + base_model.load_state_dict(averaged_sd, strict=True) + log0(f"SWA: loaded averaged model from {n} snapshots") + del swa_states, all_states + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + # int6 mixed quantization + zstd-22 compression (exact diagnostic path) + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+zstd: {quant_file_bytes} bytes") + log0(f"Total submission size int6+zstd: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, 
+ val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_seed113.log b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_seed113.log new file mode 100644 index 0000000000..344a403b41 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_seed113.log @@ -0,0 +1,101 @@ +W0320 13:12:05.807000 393384 torch/distributed/run.py:803] +W0320 13:12:05.807000 393384 torch/distributed/run.py:803] ***************************************** +W0320 13:12:05.807000 393384 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0320 13:12:05.807000 393384 torch/distributed/run.py:803] ***************************************** +logs/a676fd00-bcd0-41ee-8a37-3b54b0fbda1e.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:25 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25780824 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:113 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9341 val_bpb:4.1068 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9342 train_time:52ms step_avg:51.87ms +step:2/20000 train_loss:12.0872 train_time:116ms step_avg:57.87ms +step:3/20000 train_loss:7.1554 train_time:187ms step_avg:62.32ms +step:4/20000 train_loss:6.3958 train_time:258ms step_avg:64.38ms +step:5/20000 train_loss:6.7339 train_time:357ms step_avg:71.48ms +step:6/20000 train_loss:7.3945 train_time:430ms step_avg:71.63ms +step:7/20000 train_loss:6.6739 train_time:503ms step_avg:71.92ms +step:8/20000 train_loss:6.4261 train_time:579ms step_avg:72.33ms +step:9/20000 train_loss:6.1323 train_time:659ms step_avg:73.22ms +step:10/20000 train_loss:6.0566 train_time:752ms step_avg:75.23ms +step:200/20000 train_loss:2.7485 train_time:15595ms step_avg:77.97ms +step:400/20000 train_loss:2.2761 train_time:31034ms step_avg:77.58ms +step:600/20000 train_loss:2.4928 train_time:45784ms step_avg:76.31ms +step:800/20000 train_loss:2.2434 train_time:60442ms step_avg:75.55ms +step:1000/20000 train_loss:2.3453 train_time:75745ms step_avg:75.75ms +step:1000/20000 val_loss:2.3010 val_bpb:1.3628 train_time:75800ms step_avg:75.80ms +step:1200/20000 train_loss:2.3584 train_time:91141ms step_avg:75.95ms +step:1400/20000 train_loss:2.4117 train_time:107395ms step_avg:76.71ms +step:1600/20000 train_loss:2.0888 train_time:123162ms step_avg:76.98ms +step:1800/20000 train_loss:2.1780 train_time:138391ms step_avg:76.88ms +step:2000/20000 train_loss:2.2199 train_time:155269ms step_avg:77.63ms +step:2000/20000 val_loss:2.2052 val_bpb:1.3061 train_time:155327ms step_avg:77.66ms +step:2200/20000 train_loss:2.0359 train_time:170828ms step_avg:77.65ms +step:2400/20000 train_loss:2.1602 train_time:187139ms step_avg:77.97ms +step:2600/20000 train_loss:2.3749 train_time:202388ms step_avg:77.84ms +step:2800/20000 train_loss:2.2009 train_time:217326ms step_avg:77.62ms +step:3000/20000 train_loss:2.1882 train_time:232154ms step_avg:77.38ms +step:3000/20000 val_loss:2.1563 val_bpb:1.2771 train_time:232185ms step_avg:77.40ms +step:3200/20000 train_loss:2.1505 train_time:247350ms step_avg:77.30ms +step:3400/20000 train_loss:2.1239 train_time:262890ms step_avg:77.32ms +step:3600/20000 train_loss:2.0761 train_time:279007ms step_avg:77.50ms +step:3800/20000 train_loss:2.1850 train_time:297109ms step_avg:78.19ms +step:4000/20000 train_loss:2.1287 train_time:315876ms 
step_avg:78.97ms +step:4000/20000 val_loss:2.1354 val_bpb:1.2647 train_time:315903ms step_avg:78.98ms +step:4200/20000 train_loss:2.1415 train_time:337779ms step_avg:80.42ms +step:4400/20000 train_loss:2.0821 train_time:355710ms step_avg:80.84ms +step:4600/20000 train_loss:1.9352 train_time:372008ms step_avg:80.87ms +step:4800/20000 train_loss:2.0866 train_time:388079ms step_avg:80.85ms +step:5000/20000 train_loss:2.0980 train_time:403919ms step_avg:80.78ms +step:5000/20000 val_loss:2.1083 val_bpb:1.2487 train_time:403955ms step_avg:80.79ms +step:5200/20000 train_loss:2.1898 train_time:419480ms step_avg:80.67ms +step:5400/20000 train_loss:2.2271 train_time:434172ms step_avg:80.40ms +step:5600/20000 train_loss:2.1093 train_time:448958ms step_avg:80.17ms +step:5800/20000 train_loss:2.1738 train_time:464410ms step_avg:80.07ms +step:6000/20000 train_loss:2.1385 train_time:479152ms step_avg:79.86ms +step:6000/20000 val_loss:2.0682 val_bpb:1.2249 train_time:479177ms step_avg:79.86ms +step:6200/20000 train_loss:2.0304 train_time:494605ms step_avg:79.78ms +step:6400/20000 train_loss:2.0243 train_time:509871ms step_avg:79.67ms +step:6600/20000 train_loss:1.9125 train_time:524563ms step_avg:79.48ms +step:6800/20000 train_loss:1.8548 train_time:539763ms step_avg:79.38ms +step:7000/20000 train_loss:2.0892 train_time:554543ms step_avg:79.22ms +step:7000/20000 val_loss:2.0205 val_bpb:1.1966 train_time:554571ms step_avg:79.22ms +step:7200/20000 train_loss:1.9187 train_time:570326ms step_avg:79.21ms +step:7400/20000 train_loss:2.0173 train_time:585123ms step_avg:79.07ms +step:7597/20000 val_loss:1.9934 val_bpb:1.1806 train_time:599712ms step_avg:78.94ms +stopping_early: wallclock_cap train_time:599712ms step:7597/20000 +peak memory allocated: 14126 MiB reserved: 14414 MiB +Serialized model: 102120077 bytes +Code size: 58621 bytes +Serialized model int6+zstd: 14617451 bytes +Total submission size int6+zstd: 14676072 bytes +final_int6_roundtrip val_loss:2.0127 val_bpb:1.1921 eval_time:36047ms +final_int6_roundtrip_exact val_loss:2.01274800 val_bpb:1.19206274 +final_int6_sliding_window val_loss:1.9543 val_bpb:1.1575 stride:64 eval_time:62314ms +final_int6_sliding_window_exact val_loss:1.95433511 val_bpb:1.15746881 diff --git a/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_seed1337.log b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_seed1337.log new file mode 100644 index 0000000000..43caa0e0be --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_seed1337.log @@ -0,0 +1,102 @@ +W0320 12:22:20.021000 247704 torch/distributed/run.py:803] +W0320 12:22:20.021000 247704 torch/distributed/run.py:803] ***************************************** +W0320 12:22:20.021000 247704 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0320 12:22:20.021000 247704 torch/distributed/run.py:803] ***************************************** +logs/8b825a80-8d67-44e0-bd61-1577aa7d02a3.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:25 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25780824 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9357 train_time:53ms step_avg:52.73ms +step:2/20000 train_loss:11.9914 train_time:132ms step_avg:66.13ms +step:3/20000 train_loss:7.2170 train_time:203ms step_avg:67.64ms +step:4/20000 train_loss:6.4856 train_time:274ms step_avg:68.49ms +step:5/20000 train_loss:6.9281 train_time:346ms step_avg:69.14ms +step:6/20000 train_loss:7.5301 train_time:438ms step_avg:72.94ms +step:7/20000 train_loss:6.7127 train_time:511ms step_avg:72.99ms +step:8/20000 train_loss:6.4408 train_time:588ms step_avg:73.48ms +step:9/20000 train_loss:6.2927 train_time:657ms step_avg:73.01ms +step:10/20000 train_loss:6.1660 train_time:750ms step_avg:75.03ms +step:200/20000 train_loss:2.7541 train_time:15770ms step_avg:78.85ms +step:400/20000 train_loss:2.2747 train_time:31729ms step_avg:79.32ms +step:600/20000 train_loss:2.4850 train_time:46923ms step_avg:78.20ms +step:800/20000 train_loss:2.2448 train_time:61913ms step_avg:77.39ms +step:1000/20000 train_loss:2.3405 train_time:77374ms step_avg:77.37ms +step:1000/20000 val_loss:2.2974 val_bpb:1.3606 train_time:77401ms step_avg:77.40ms +step:1200/20000 train_loss:2.3548 train_time:92594ms step_avg:77.16ms +step:1400/20000 train_loss:2.4106 train_time:108454ms step_avg:77.47ms +step:1600/20000 train_loss:2.0832 train_time:124083ms step_avg:77.55ms +step:1800/20000 train_loss:2.1725 train_time:139098ms step_avg:77.28ms +step:2000/20000 train_loss:2.2185 train_time:154560ms step_avg:77.28ms +step:2000/20000 val_loss:2.2013 val_bpb:1.3037 train_time:154588ms step_avg:77.29ms +step:2200/20000 train_loss:2.0355 train_time:169444ms step_avg:77.02ms +step:2400/20000 train_loss:2.1564 train_time:185367ms step_avg:77.24ms +step:2600/20000 train_loss:2.3684 train_time:200395ms step_avg:77.08ms +step:2800/20000 train_loss:2.1986 train_time:215788ms step_avg:77.07ms +step:3000/20000 train_loss:2.1842 train_time:230789ms step_avg:76.93ms +step:3000/20000 val_loss:2.1529 val_bpb:1.2751 train_time:230815ms step_avg:76.94ms +step:3200/20000 train_loss:2.1507 train_time:246341ms step_avg:76.98ms +step:3400/20000 train_loss:2.1210 train_time:261595ms step_avg:76.94ms +step:3600/20000 train_loss:2.0764 train_time:277525ms step_avg:77.09ms +step:3800/20000 train_loss:2.1798 train_time:293834ms step_avg:77.32ms +step:4000/20000 train_loss:2.1234 train_time:309437ms 
step_avg:77.36ms +step:4000/20000 val_loss:2.1330 val_bpb:1.2633 train_time:309470ms step_avg:77.37ms +step:4200/20000 train_loss:2.1396 train_time:328794ms step_avg:78.28ms +step:4400/20000 train_loss:2.0819 train_time:344090ms step_avg:78.20ms +step:4600/20000 train_loss:1.9344 train_time:359440ms step_avg:78.14ms +step:4800/20000 train_loss:2.0949 train_time:374026ms step_avg:77.92ms +step:5000/20000 train_loss:2.1020 train_time:389444ms step_avg:77.89ms +step:5000/20000 val_loss:2.1135 val_bpb:1.2517 train_time:389557ms step_avg:77.91ms +step:5200/20000 train_loss:2.1902 train_time:404618ms step_avg:77.81ms +step:5400/20000 train_loss:2.2333 train_time:419390ms step_avg:77.66ms +step:5600/20000 train_loss:2.1116 train_time:434105ms step_avg:77.52ms +step:5800/20000 train_loss:2.1777 train_time:449330ms step_avg:77.47ms +step:6000/20000 train_loss:2.1472 train_time:464075ms step_avg:77.35ms +step:6000/20000 val_loss:2.0744 val_bpb:1.2286 train_time:464101ms step_avg:77.35ms +step:6200/20000 train_loss:2.0317 train_time:479978ms step_avg:77.42ms +step:6400/20000 train_loss:2.0279 train_time:495573ms step_avg:77.43ms +step:6600/20000 train_loss:1.9203 train_time:510214ms step_avg:77.31ms +step:6800/20000 train_loss:1.8615 train_time:525466ms step_avg:77.27ms +step:7000/20000 train_loss:2.0924 train_time:540543ms step_avg:77.22ms +step:7000/20000 val_loss:2.0275 val_bpb:1.2008 train_time:540571ms step_avg:77.22ms +step:7200/20000 train_loss:1.9248 train_time:556857ms step_avg:77.34ms +step:7400/20000 train_loss:2.0248 train_time:571807ms step_avg:77.27ms +step:7600/20000 train_loss:2.0251 train_time:587801ms step_avg:77.34ms +step:7732/20000 val_loss:1.9912 val_bpb:1.1793 train_time:599933ms step_avg:77.59ms +stopping_early: wallclock_cap train_time:599933ms step:7732/20000 +peak memory allocated: 14126 MiB reserved: 14412 MiB +Serialized model: 102120077 bytes +Code size: 58621 bytes +Serialized model int6+zstd: 14688652 bytes +Total submission size int6+zstd: 14747273 bytes +final_int6_roundtrip val_loss:2.0082 val_bpb:1.1894 eval_time:39761ms +final_int6_roundtrip_exact val_loss:2.00817276 val_bpb:1.18935303 +final_int6_sliding_window val_loss:1.9498 val_bpb:1.1548 stride:64 eval_time:64830ms +final_int6_sliding_window_exact val_loss:1.94977911 val_bpb:1.15477050 diff --git a/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_seed42.log b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_seed42.log new file mode 100644 index 0000000000..524ee74ebd --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_11L_LowRankQ192/train_seed42.log @@ -0,0 +1,103 @@ +W0320 12:38:27.016000 296322 torch/distributed/run.py:803] +W0320 12:38:27.016000 296322 torch/distributed/run.py:803] ***************************************** +W0320 12:38:27.016000 296322 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0320 12:38:27.016000 296322 torch/distributed/run.py:803] ***************************************** +logs/868a73ac-b6d0-435c-83d2-386a872b21fe.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:25 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25780824 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9376 val_bpb:4.1088 train_time:0ms step_avg:0.03ms +step:1/20000 train_loss:6.9380 train_time:52ms step_avg:51.50ms +step:2/20000 train_loss:12.0555 train_time:116ms step_avg:57.80ms +step:3/20000 train_loss:7.1732 train_time:186ms step_avg:61.98ms +step:4/20000 train_loss:6.4499 train_time:260ms step_avg:64.91ms +step:5/20000 train_loss:6.8776 train_time:347ms step_avg:69.44ms +step:6/20000 train_loss:7.5138 train_time:423ms step_avg:70.42ms +step:7/20000 train_loss:6.7857 train_time:494ms step_avg:70.64ms +step:8/20000 train_loss:6.4578 train_time:567ms step_avg:70.91ms +step:9/20000 train_loss:6.2392 train_time:660ms step_avg:73.32ms +step:10/20000 train_loss:6.1404 train_time:792ms step_avg:79.22ms +step:200/20000 train_loss:2.7325 train_time:15391ms step_avg:76.95ms +step:400/20000 train_loss:2.2745 train_time:30264ms step_avg:75.66ms +step:600/20000 train_loss:2.4825 train_time:44781ms step_avg:74.64ms +step:800/20000 train_loss:2.2450 train_time:59261ms step_avg:74.08ms +step:1000/20000 train_loss:2.3403 train_time:74275ms step_avg:74.28ms +step:1000/20000 val_loss:2.2993 val_bpb:1.3618 train_time:74313ms step_avg:74.31ms +step:1200/20000 train_loss:2.3576 train_time:89199ms step_avg:74.33ms +step:1400/20000 train_loss:2.4108 train_time:105119ms step_avg:75.08ms +step:1600/20000 train_loss:2.0880 train_time:120415ms step_avg:75.26ms +step:1800/20000 train_loss:2.1808 train_time:135152ms step_avg:75.08ms +step:2000/20000 train_loss:2.2198 train_time:150441ms step_avg:75.22ms +step:2000/20000 val_loss:2.2028 val_bpb:1.3046 train_time:150467ms step_avg:75.23ms +step:2200/20000 train_loss:2.0362 train_time:165028ms step_avg:75.01ms +step:2400/20000 train_loss:2.1600 train_time:180263ms step_avg:75.11ms +step:2600/20000 train_loss:2.3667 train_time:195445ms step_avg:75.17ms +step:2800/20000 train_loss:2.1928 train_time:210437ms step_avg:75.16ms +step:3000/20000 train_loss:2.1888 train_time:225334ms step_avg:75.11ms +step:3000/20000 val_loss:2.1538 val_bpb:1.2756 train_time:225362ms step_avg:75.12ms +step:3200/20000 train_loss:2.1495 train_time:240392ms step_avg:75.12ms +step:3400/20000 train_loss:2.1234 train_time:255628ms step_avg:75.18ms +step:3600/20000 train_loss:2.0763 train_time:270912ms step_avg:75.25ms +step:3800/20000 train_loss:2.1809 train_time:286285ms step_avg:75.34ms +step:4000/20000 train_loss:2.1300 train_time:301993ms 
step_avg:75.50ms +step:4000/20000 val_loss:2.1337 val_bpb:1.2637 train_time:302021ms step_avg:75.51ms +step:4200/20000 train_loss:2.1408 train_time:321345ms step_avg:76.51ms +step:4400/20000 train_loss:2.0795 train_time:337500ms step_avg:76.70ms +step:4600/20000 train_loss:1.9294 train_time:352253ms step_avg:76.58ms +step:4800/20000 train_loss:2.1019 train_time:366658ms step_avg:76.39ms +step:5000/20000 train_loss:2.1077 train_time:381884ms step_avg:76.38ms +step:5000/20000 val_loss:2.1193 val_bpb:1.2552 train_time:381913ms step_avg:76.38ms +step:5200/20000 train_loss:2.1972 train_time:396744ms step_avg:76.30ms +step:5400/20000 train_loss:2.2427 train_time:411378ms step_avg:76.18ms +step:5600/20000 train_loss:2.1154 train_time:425930ms step_avg:76.06ms +step:5800/20000 train_loss:2.1850 train_time:440791ms step_avg:76.00ms +step:6000/20000 train_loss:2.1559 train_time:455789ms step_avg:75.96ms +step:6000/20000 val_loss:2.0807 val_bpb:1.2323 train_time:455816ms step_avg:75.97ms +step:6200/20000 train_loss:2.0371 train_time:472232ms step_avg:76.17ms +step:6400/20000 train_loss:2.0378 train_time:488773ms step_avg:76.37ms +step:6600/20000 train_loss:1.9219 train_time:505266ms step_avg:76.56ms +step:6800/20000 train_loss:1.8688 train_time:520901ms step_avg:76.60ms +step:7000/20000 train_loss:2.0993 train_time:536122ms step_avg:76.59ms +step:7000/20000 val_loss:2.0315 val_bpb:1.2032 train_time:536157ms step_avg:76.59ms +step:7200/20000 train_loss:1.9259 train_time:552252ms step_avg:76.70ms +step:7400/20000 train_loss:2.0249 train_time:567482ms step_avg:76.69ms +step:7600/20000 train_loss:2.0303 train_time:583293ms step_avg:76.75ms +step:7800/20000 train_loss:1.9332 train_time:598571ms step_avg:76.74ms +step:7821/20000 val_loss:1.9913 val_bpb:1.1793 train_time:603198ms step_avg:77.13ms +stopping_early: wallclock_cap train_time:603198ms step:7821/20000 +peak memory allocated: 14126 MiB reserved: 14414 MiB +Serialized model: 102120077 bytes +Code size: 58621 bytes +Serialized model int6+zstd: 14880972 bytes +Total submission size int6+zstd: 14939593 bytes +final_int6_roundtrip val_loss:2.0093 val_bpb:1.1900 eval_time:38834ms +final_int6_roundtrip_exact val_loss:2.00931211 val_bpb:1.19002781 +final_int6_sliding_window val_loss:1.9505 val_bpb:1.1552 stride:64 eval_time:62326ms +final_int6_sliding_window_exact val_loss:1.95048320 val_bpb:1.15518750
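

For reference, below is a minimal, self-contained sketch of the per-row int6 quantize → `torch.save` → zstd-22 compress → decompress → dequantize roundtrip that the serialization block in `train_gpt.py` above performs. The helpers `quantize_int6_per_row` and `dequantize_int6_per_row` are illustrative stand-ins, not the repo's `mixed_quantize_int6`/`dequantize_mixed_int6`, which additionally split the state dict into quantized (MLP/attention) and fp32 (scalar/control) groups and carry per-tensor metadata.

```python
# Illustrative sketch -- NOT the repo's mixed_quantize_int6/dequantize_mixed_int6.
# Shows only the core symmetric per-row int6 roundtrip plus zstd-22 framing.
import io

import torch
import zstandard


def quantize_int6_per_row(w: torch.Tensor):
    """Symmetric per-row int6: one fp32 scale per output row, codes in [-31, 31]."""
    scale = w.float().abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 31.0
    q = torch.round(w.float() / scale).clamp_(-31, 31).to(torch.int8)
    return q, scale


def dequantize_int6_per_row(q: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
    """Invert the quantizer: codes times per-row scale, cast back to target dtype."""
    return (q.float() * scale).to(dtype)


# Example shape: one MLP up-projection at model_dim=512, mlp_mult=3.
w = torch.randn(1536, 512)

# Quantize -> torch.save -> zstd-22, mirroring the diagnostic path above.
q, s = quantize_int6_per_row(w)
buf = io.BytesIO()
torch.save({"q": q, "s": s}, buf)
blob = zstandard.ZstdCompressor(level=22).compress(buf.getvalue())

# Decompress -> torch.load -> dequantize, as in the roundtrip validation.
raw = zstandard.ZstdDecompressor().decompress(blob)
state = torch.load(io.BytesIO(raw), map_location="cpu")
w_hat = dequantize_int6_per_row(state["q"], state["s"], w.dtype)

print(f"compressed: {len(blob)} bytes  max |err|: {(w - w_hat).abs().max().item():.5f}")
```

Note that this sketch stores each 6-bit code in an int8, leaving the high bits redundant (sign-extended); zstd-22's entropy coding removes most of that redundancy, which is one reason quantize-then-compress can approach packed-bit sizes without explicit bit-packing. The rounding error per element is bounded by half a step, i.e. at most `row_max / 62`.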