diff --git a/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/README.md b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/README.md new file mode 100644 index 0000000000..bc673218f8 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/README.md @@ -0,0 +1,57 @@ +# SP8192 + BOS-Fix SmearGate + LQER Asym + Phased TTT (10L) + +**val_bpb: 1.07171** (single seed 314) | **15.37 MB** | 8xH100 SXM, 596s + +Applies the full SOTA stack from PR #1851 (BOS-fixed SmearGate + LQER Asymmetric + Phased TTT + layer looping) with the SP8192 tokenizer instead of SP1024. Uses 10 transformer layers instead of 11 to fit the larger embedding table (8192 vocab) under the 16MB artifact limit with brotli compression. + +## Results + +| Metric | Value | +|---|---| +| Pre-quant val_bpb | 1.07399 | +| Post-GPTQ val_bpb | 1.08251 | +| Post-TTT val_bpb | 1.07171 | +| Artifact size | 15,373,365 bytes | +| Training steps | 5,218 | +| Training time | 596s | +| Hardware | 8xH100 80GB SXM | +| Seed | 314 | + +## Changes vs PR #1851 + +- **SP8192 tokenizer** instead of SP1024 — better tokenization efficiency, ~0.02 BPB improvement +- **10 layers** instead of 11 — required to fit under 16MB with the larger 8192-vocab embedding table +- All other settings identical to PR #1851: BOS-fixed SmearGate, GPTQ int6 + LQER Asymmetric, Phased TTT (1 phase, 2000 prefix docs), layer looping, SparseAttnGate + +## Architecture + +| Component | Setting | +|---|---| +| Layers | 10 (512d, 8 heads, 4 KV heads) | +| Vocabulary | SP8192 | +| Layer loop | encoder [0,1,2,3,4,5,3,4] decoder [5,3,4,5,6,7,8,9] | +| SmearGate | BOS-boundary fixed (PR #1851) | +| Quantization | GPTQ int6 + LQER Asymmetric | +| TTT | Phased score-first LoRA TTT (rank 96, 1 phase) | +| Compression | brotli quality=11 | + +## Run Command + +```bash +TORCHINDUCTOR_CACHE_DIR=/workspace/inductor_cache \ +RUN_ID=sota_sp8192_10L SEED=314 VOCAB_SIZE=8192 NUM_LAYERS=10 \ +SMEAR_GATE_ENABLED=1 SPARSE_ATTN_GATE_ENABLED=1 SPARSE_ATTN_GATE_SCALE=0.5 \ +LQER_ENABLED=1 LQER_ASYM_ENABLED=1 \ +MLP_CLIP_SIGMAS=11.5 EMBED_CLIP_SIGMAS=14.0 \ +WARMDOWN_FRAC=0.85 MIN_LR=0.1 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Lineage + +``` +PR #1851 (1.0614 BPB) — BOS-fix SmearGate + LQER Asym + Phased TTT + └── This work: + ├── SP8192 tokenizer (from sproos/parameter-golf-tokenizers) + └── 10 layers (down from 11 to fit 16MB with larger vocab embedding) +``` diff --git a/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/final_model.int6.ptz b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/final_model.int6.ptz new file mode 100644 index 0000000000..7265546802 Binary files /dev/null and b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/final_model.int6.ptz differ diff --git a/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/requirements.txt b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/requirements.txt new file mode 100644 index 0000000000..51dba8bde8 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/requirements.txt @@ -0,0 +1,5 @@ +torch==2.9.1+cu128 +sentencepiece==0.2.1 +flash-attn==2.8.3 +numpy +brotli diff --git a/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/submission.json b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/submission.json new file mode 100644 index 0000000000..116ac044c2 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/submission.json @@ -0,0 +1,25 @@ +{ + "author": "am-commits", + "github_id": "am-commits", + "name": "SP8192 + BOS-Fix SmearGate + LQER Asym + Phased TTT (10L)", + "blurb": "PR #1851 SOTA stack (BOS-fixed SmearGate + LQER Asymmetric + Phased TTT + layer looping) with SP8192 tokenizer. 10 layers instead of 11 to fit under 16MB with larger 8192-vocab embedding table.", + "date": "2026-04-30", + "track": "10min_16mb", + "val_bpb": 1.07170518, + "pre_quant_val_bpb": 1.07398921, + "step_stop": 5218, + "wallclock_seconds": 596, + "bytes_total": 15373365, + "gpu": "8xH100-80GB-SXM", + "seed": 314, + "vocab_size": 8192, + "num_layers": 10, + "comparison_baseline_pr": 1851, + "metrics": { + "pre_quant_post_ema_val_bpb": 1.07398921, + "post_quant_val_bpb": 1.08251175, + "post_ttt_val_bpb": 1.07170518, + "gptq_time_s": 10.0, + "ttt_eval_time_s": 487 + } +} diff --git a/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/train_gpt.py b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/train_gpt.py new file mode 100644 index 0000000000..95c10c50dd --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/train_gpt.py @@ -0,0 +1,3575 @@ +import base64, collections, copy, fcntl, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import Tensor, nn +from flash_attn_interface import ( + flash_attn_func as flash_attn_3_func, + flash_attn_varlen_func, +) +from concurrent.futures import ThreadPoolExecutor +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + + +# ===== Fused softcapped cross-entropy (Triton) — training-only path ===== +# Replaces the eager +# logits_softcap = softcap * tanh(logits / softcap) +# F.cross_entropy(logits_softcap.float(), targets, reduction="mean") +# sequence with a single fused kernel that reads logits_proj once, applies +# softcap in-register, and computes (LSE, loss) in one streaming pass. The +# backward kernel mirrors the forward so there's no stored softcapped logits. +# Numerically identical to the eager path up to fp32 accumulation differences. +_FUSED_CE_LIBRARY = "pgsubmission1draft7fusedce" +_FUSED_CE_BLOCK_SIZE = 1024 +_FUSED_CE_NUM_WARPS = 4 + + +@triton.jit +def _softcapped_ce_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + max_val = -float("inf") + sum_exp = 0.0 + A = 2.0 * softcap + inv_C = 2.0 / softcap + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=-float("inf"), + ).to(tl.float32) + z = A * tl.sigmoid(val * inv_C) + z = tl.where(mask, z, -float("inf")) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + target_val = tl.load(logits_row_ptr + target * stride_logits_v).to(tl.float32) + target_z = A * tl.sigmoid(target_val * inv_C) + tl.store(losses_ptr + row_idx, lse - target_z) + + +@triton.jit +def _softcapped_ce_bwd_kernel( + grad_logits_ptr, grad_losses_ptr, lse_ptr, logits_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + stride_grad_n, stride_grad_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_logits_ptr + row_idx * stride_grad_n + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_losses_ptr + row_idx).to(tl.float32) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + A = 2.0 * softcap + inv_C = 2.0 / softcap + dz_dx_scale = A * inv_C + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=0.0, + ).to(tl.float32) + sigmoid_u = tl.sigmoid(val * inv_C) + z = A * sigmoid_u + probs = tl.exp(z - lse) + grad_z = grad_loss * (probs - tl.where(cols == target, 1.0, 0.0)) + grad_x = grad_z * (dz_dx_scale * sigmoid_u * (1.0 - sigmoid_u)) + tl.store(grad_row_ptr + cols * stride_grad_v, grad_x, mask=mask) + + +def _validate_softcapped_ce_inputs( + logits: Tensor, targets: Tensor, softcap: float, +) -> tuple[Tensor, Tensor]: + if logits.ndim != 2: + raise ValueError(f"Expected logits.ndim=2, got {logits.ndim}") + if targets.ndim != 1: + raise ValueError(f"Expected targets.ndim=1, got {targets.ndim}") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + if not logits.is_cuda or not targets.is_cuda: + raise ValueError("softcapped_cross_entropy requires CUDA tensors") + if softcap <= 0.0: + raise ValueError(f"softcap must be positive, got {softcap}") + if logits.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Unsupported logits dtype: {logits.dtype}") + logits = logits.contiguous() + targets = targets.contiguous() + if targets.dtype != torch.int64: + targets = targets.to(dtype=torch.int64) + return logits, targets + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce", mutates_args=()) +def softcapped_ce_op(logits: Tensor, targets: Tensor, softcap: float) -> tuple[Tensor, Tensor]: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + n_rows, n_cols = logits.shape + losses = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + lse = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + _softcapped_ce_fwd_kernel[(n_rows,)]( + logits, losses, lse, targets, + logits.stride(0), logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return losses, lse + + +@softcapped_ce_op.register_fake +def _(logits: Tensor, targets: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1: + raise ValueError("softcapped_ce fake impl expects 2D logits and 1D targets") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + n_rows = logits.shape[0] + return ( + logits.new_empty((n_rows,), dtype=torch.float32), + logits.new_empty((n_rows,), dtype=torch.float32), + ) + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=()) +def softcapped_ce_backward_op( + logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float, +) -> Tensor: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + lse = lse.contiguous() + grad_losses = grad_losses.contiguous().to(dtype=torch.float32) + if lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("Expected 1D lse and grad_losses") + if lse.shape[0] != logits.shape[0] or grad_losses.shape[0] != logits.shape[0]: + raise ValueError( + f"Expected row-aligned lse/grad_losses, got logits={tuple(logits.shape)} " + f"lse={tuple(lse.shape)} grad_losses={tuple(grad_losses.shape)}" + ) + grad_logits = torch.empty_like(logits) + n_rows, n_cols = logits.shape + _softcapped_ce_bwd_kernel[(n_rows,)]( + grad_logits, grad_losses, lse, logits, targets, + logits.stride(0), logits.stride(1), + grad_logits.stride(0), grad_logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return grad_logits + + +@softcapped_ce_backward_op.register_fake +def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1 or lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("softcapped_ce_backward fake impl expects 2D logits and 1D row tensors") + if ( + logits.shape[0] != targets.shape[0] + or logits.shape[0] != lse.shape[0] + or logits.shape[0] != grad_losses.shape[0] + ): + raise ValueError("softcapped_ce_backward fake impl expects row-aligned tensors") + return logits.new_empty(logits.shape) + + +def _softcapped_ce_setup_context( + ctx: torch.autograd.function.FunctionCtx, inputs, output, +) -> None: + logits, targets, softcap = inputs + _losses, lse = output + ctx.save_for_backward(logits, targets, lse) + ctx.softcap = float(softcap) + + +def _softcapped_ce_backward( + ctx: torch.autograd.function.FunctionCtx, grad_losses: Tensor, grad_lse: "Tensor | None", +): + del grad_lse + logits, targets, lse = ctx.saved_tensors + grad_logits = torch.ops.pgsubmission1draft7fusedce.softcapped_ce_backward( + logits, targets, lse, grad_losses, ctx.softcap + ) + return grad_logits, None, None + + +softcapped_ce_op.register_autograd( + _softcapped_ce_backward, setup_context=_softcapped_ce_setup_context, +) + + +def softcapped_cross_entropy( + logits: Tensor, targets: Tensor, softcap: float, reduction: str = "mean", +) -> Tensor: + losses, _lse = torch.ops.pgsubmission1draft7fusedce.softcapped_ce( + logits, targets, float(softcap) + ) + if reduction == "none": + return losses + if reduction == "sum": + return losses.sum() + if reduction == "mean": + return losses.mean() + raise ValueError(f"Unsupported reduction={reduction!r}") + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.75)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + # Fused softcapped CE (Triton). Training-only — forward_logits eval path still uses + # eager softcap+F.cross_entropy. Default ON since validated as at-worst neutral. + fused_ce_enabled = bool(int(os.environ.get("FUSED_CE_ENABLED", "1"))) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.026)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 96)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 1.0)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.999)) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "brotli") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 4.0)) + phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2000)) + phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 1)) + global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001)) + global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9)) + global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 1)) + global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768)) + global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32)) + global_ttt_warmup_start_lr = float(os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0)) + global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0)) + global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0)) + global_ttt_respect_doc_boundaries = bool(int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1"))) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 8)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 2e1)) + mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 10.0)) + attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0)) + # AttnOutGate (per-head multiplicative output gate, PR #1667 MarioPaerle). + # Zero-init weight: 2*sigmoid(0)=1 -> transparent at start. Source defaults to + # block input x ('proj'); 'q' uses raw Q projection output. + attn_out_gate_enabled = bool(int(os.environ.get("ATTN_OUT_GATE_ENABLED", "0"))) + attn_out_gate_src = os.environ.get("ATTN_OUT_GATE_SRC", "proj") + # SmearGate (input-dependent forward-1 token smear, modded-nanogpt @classiclarryd + # via PR #1667). x_t <- x_t + lam * sigmoid(W*x_t[:gate_window]) * x_{t-1}. + # lam=0 + W=0 -> transparent at init. + smear_gate_enabled = bool(int(os.environ.get("SMEAR_GATE_ENABLED", "0"))) + # Window: first GATE_WINDOW dims of the source feed the gate projection. + gate_window = int(os.environ.get("GATE_WINDOW", 12)) + # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708; + # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE + # out_proj. Gate input = full block input x (paper's headwise G1 variant + # driven from hidden_states). W_g shape (num_heads, dim), plain sigmoid. + # Near-zero init gives g~0.5 at step 0 (half attention output); per-block + # attn_scale (init 1.0) compensates during training. Name contains + # "attn_gate" so CONTROL_TENSOR_NAME_PATTERNS routes it to scalar AdamW. + gated_attn_enabled = bool(int(os.environ.get("GATED_ATTN_ENABLED", "0"))) + gated_attn_init_std = float(os.environ.get("GATED_ATTN_INIT_STD", 0.01)) + # Dedicated int8-per-row quantization for `attn_gate_w` tensors. These are + # small ((num_heads, dim) = (8, 512) = 4096 params) and bypass GPTQ via the + # numel<=65536 passthrough branch -> stored as fp16 (8 KB/layer, ~65 KB total + # compressed). int8-per-row cuts the raw tensor in half with negligible BPB + # impact: scales per head (8 values), symmetric quant over [-127, 127]. + # No Hessian needed (gate weights not in collect_hessians()). + gated_attn_quant_gate = bool(int(os.environ.get("GATED_ATTN_QUANT_GATE", "0"))) + # Sparse Attention Gate (modded-nanogpt-style). Keeps dense SDPA and only + # swaps the output-gate input to the first GATE_WINDOW residual dims. + # W_g: (num_heads, gate_window) = (8, 12) = 96 params/layer (~44K total), + # vs dense GatedAttn's (8, 512) = 4K/layer (~44K diff). Name "attn_gate_w" + # is shared so quant routing and int8 gate passthrough Just Work. Gate + # passthrough int8 still applies via GATED_ATTN_QUANT_GATE=1. + # Mutually exclusive with ATTN_OUT_GATE_ENABLED and GATED_ATTN_ENABLED. + sparse_attn_gate_enabled = bool(int(os.environ.get("SPARSE_ATTN_GATE_ENABLED", "0"))) + sparse_attn_gate_init_std = float(os.environ.get("SPARSE_ATTN_GATE_INIT_STD", 0.0)) + sparse_attn_gate_scale = float(os.environ.get("SPARSE_ATTN_GATE_SCALE", 1.0)) + # LQER asymmetric rank-k correction on top-K quant-error tensors (PR #1530 v2 port). + # Computes SVD of E = W_fp - W_quant, packs top-r A,B as INT2/INT4 (asym) or INTk (sym). + lqer_enabled = bool(int(os.environ.get("LQER_ENABLED", "1"))) + lqer_rank = int(os.environ.get("LQER_RANK", 4)) + lqer_top_k = int(os.environ.get("LQER_TOP_K", 3)) + lqer_factor_bits = int(os.environ.get("LQER_FACTOR_BITS", 4)) + lqer_asym_enabled = bool(int(os.environ.get("LQER_ASYM_ENABLED", "1"))) + lqer_asym_group = int(os.environ.get("LQER_ASYM_GROUP", "64")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # CaseOps integration: optional override of dataset root + tokenizer path. + # When CASEOPS_ENABLED=1, the wrapper loads a per-token byte sidecar + # (fineweb_val_bytes_*.bin, identical shard layout to val_*.bin) and uses + # it as the canonical raw-byte budget for BPB accounting. The sidecar + # REPLACES the build_sentencepiece_luts byte-counting path entirely. + caseops_enabled = bool(int(os.environ.get("CASEOPS_ENABLED", "0"))) + _default_caseops_data = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "datasets", + "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved", + ) + _default_caseops_tok = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "tokenizers", + "fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model", + ) + if caseops_enabled: + datasets_dir = os.environ.get("DATA_PATH", _default_caseops_data) + tokenizer_path = os.environ.get("TOKENIZER_PATH", _default_caseops_tok) + else: + datasets_dir = os.environ.get( + "DATA_PATH", + os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}"), + ) + tokenizer_path = os.environ.get( + "TOKENIZER_PATH", + os.path.join(data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model"), + ) + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + val_bytes_files = os.path.join(datasets_dir, "fineweb_val_bytes_*.bin") + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + # CaseOps: when enabled, load per-token byte sidecar and stash it as a + # CPU tensor aligned 1:1 with self.val_tokens. eval_val/eval_val_ttt + # branches use this as the canonical raw-byte budget per token. + self.caseops_enabled = bool(getattr(h, "caseops_enabled", False)) + self.val_bytes = None + if self.caseops_enabled: + self.val_bytes = load_validation_byte_sidecar( + h.val_bytes_files, h.eval_seq_len, self.val_tokens.numel() + ) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + # Filter out CaseOps byte sidecar shards which share the val_*.bin glob. + files = [ + Path(p) + for p in sorted(glob.glob(pattern)) + if "_bytes_" not in Path(p).name + ] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_validation_byte_sidecar(pattern, seq_len, expected_len): + """Load CaseOps per-token byte sidecar(s). Same shard layout as token shards + (256 int32 header + uint16 array). Each entry = canonical raw-text byte + budget for that token in the corresponding val shard. Returns a CPU + int16 tensor sliced to match expected_len (i.e. val_tokens length).""" + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No byte sidecar files for pattern: {pattern}") + shards = [load_data_shard(file) for file in files] + # load_data_shard returns uint16 — that's exactly what the sidecar stores. + bytes_full = torch.cat(shards).contiguous() + if bytes_full.numel() < expected_len: + raise ValueError( + f"Byte sidecar too short: {bytes_full.numel()} < val_tokens {expected_len}" + ) + return bytes_full[:expected_len].to(torch.int32) + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._next_batch = None + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = buf[:-1].to(dtype=torch.int64).pin_memory() + targets = buf[1:].to(dtype=torch.int64).pin_memory() + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + if self._next_batch is not None: + inputs, targets, cu_seqlens, max_seqlen = self._next_batch.result() + else: + inputs, targets, cu_seqlens, max_seqlen = self._prepare_batch( + num_tokens_local, self.max_seq_len + ) + self._next_batch = self._batch_pool.submit( + self._prepare_batch, num_tokens_local, self.max_seq_len + ) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.5 * pre0) + c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.5 * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, 0.5 * c0) + aux1 = tl.where(c1 > 0, c1, 0.5 * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 256, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True, + attn_out_gate=False, attn_out_gate_src="proj", gate_window=12, + gated_attn=False, gated_attn_init_std=0.01, + sparse_attn_gate=False, sparse_attn_gate_init_std=0.0, sparse_attn_gate_scale=1.0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + if int(attn_out_gate) + int(gated_attn) + int(sparse_attn_gate) > 1: + raise ValueError( + "attn_out_gate, gated_attn, and sparse_attn_gate are mutually exclusive" + ) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + # AttnOutGate (PR #1667 MarioPaerle): per-head multiplicative gate on attention + # output. CastedLinear so restore_fp32_params casts back to fp32 for GPTQ. + # _zero_init -> 2*sigmoid(0)=1 -> transparent at init. + self.attn_out_gate = attn_out_gate + self.attn_out_gate_src = attn_out_gate_src + self.gate_window = gate_window + if attn_out_gate: + self.attn_gate_proj = CastedLinear(gate_window, num_heads, bias=False) + self.attn_gate_proj._zero_init = True + # Gated Attention (arXiv:2505.06708, Qwen, NeurIPS 2025). Per-head sigmoid + # gate on SDPA output, BEFORE out_proj. Gate projection W_g: (num_heads, dim). + # Name "attn_gate_w" contains "attn_gate" substring so it matches + # CONTROL_TENSOR_NAME_PATTERNS and routes to the scalar AdamW group. + # fp32 Parameter -> restore_fp32_params path covers it via the ndim<2 OR + # name-pattern check (name matches "attn_gate"). Cast to x.dtype on use. + self.gated_attn = gated_attn + if gated_attn: + W = torch.empty(num_heads, dim, dtype=torch.float32) + nn.init.normal_(W, mean=0.0, std=gated_attn_init_std) + self.attn_gate_w = nn.Parameter(W) + # Sparse attention head-output gate (modded-nanogpt style). Keeps dense SDPA + # and only narrows the gate input to the first gate_window residual dims. + # W_g: (num_heads, gate_window). y_{t,h} <- sigmoid(scale * W_g_h @ x_t[:gate_window]) * y_{t,h}. + # Shares attn_gate_w name with dense GatedAttn so the quant routing + # (CONTROL_TENSOR_NAME_PATTERNS / attn_gate_w int8 passthrough) is unchanged. + self.sparse_attn_gate = sparse_attn_gate + self.sparse_attn_gate_scale = sparse_attn_gate_scale + if sparse_attn_gate: + W = torch.empty(num_heads, gate_window, dtype=torch.float32) + if sparse_attn_gate_init_std > 0: + nn.init.normal_(W, mean=0.0, std=sparse_attn_gate_init_std) + else: + nn.init.zeros_(W) + self.attn_gate_w = nn.Parameter(W) + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + # q_raw kept around as a tap point for attn_out_gate_src='q' (post-projection, + # pre-reshape, pre-RoPE). + q_raw = F.linear(x, q_w.to(x.dtype)) + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if cu_seqlens is not None: + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1), + )[None] + else: + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + # AttnOutGate inlined (PR #1667). Inline + .contiguous() barrier so torch.compile + # fullgraph=True is happy (this avoids the @torch.compiler.disable trap that + # crashed gates v3). Per-head gate on (B,T,H,D) tensor: g shape [B,T,H], broadcast + # over D via [..., None]. zero-init weight -> 2*sigmoid(0)=1 -> transparent. + if self.attn_out_gate: + gate_src = q_raw if self.attn_out_gate_src == "q" else x + gate_in = gate_src[..., : self.gate_window].contiguous() + g = 2.0 * torch.sigmoid(self.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (arXiv:2505.06708 G1). Inline + .contiguous() barrier so + # torch.compile fullgraph=True is happy. Per-head gate on (B,T,H,D): g shape + # [B,T,H], broadcast over D via [..., None]. Paper: g = sigmoid(x @ W_g.T) + # where W_g: (H, dim). .to(x.dtype) on fp32 param before broadcast with bf16. + if self.gated_attn: + x_c = x.contiguous() + g = torch.sigmoid(F.linear(x_c, self.attn_gate_w.to(x.dtype))) + y = y * g[..., None] + # Sparse head-output gate: narrower (gate_window) input, same shape g as GatedAttn. + if self.sparse_attn_gate: + gate_in = x[..., : self.gate_window].contiguous() + g = torch.sigmoid( + self.sparse_attn_gate_scale + * F.linear(gate_in, self.attn_gate_w.to(x.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + attn_out_gate=False, + attn_out_gate_src="proj", + gate_window=12, + gated_attn=False, + gated_attn_init_std=0.01, + sparse_attn_gate=False, + sparse_attn_gate_init_std=0.0, + sparse_attn_gate_scale=1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn, + attn_out_gate=attn_out_gate, attn_out_gate_src=attn_out_gate_src, gate_window=gate_window, + gated_attn=gated_attn, gated_attn_init_std=gated_attn_init_std, + sparse_attn_gate=sparse_attn_gate, + sparse_attn_gate_init_std=sparse_attn_gate_init_std, + sparse_attn_gate_scale=sparse_attn_gate_scale, + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.fused_ce_enabled = bool(h.fused_ce_enabled) + self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim) + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + h.qk_gain_init, + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + attn_out_gate=h.attn_out_gate_enabled, + attn_out_gate_src=h.attn_out_gate_src, + gate_window=h.gate_window, + gated_attn=h.gated_attn_enabled, + gated_attn_init_std=h.gated_attn_init_std, + sparse_attn_gate=h.sparse_attn_gate_enabled, + sparse_attn_gate_init_std=h.sparse_attn_gate_init_std, + sparse_attn_gate_scale=h.sparse_attn_gate_scale, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.model_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + # SmearGate (PR #1667 / modded-nanogpt @classiclarryd): + # x_t <- x_t + lam * sigmoid(W * x_t[:gate_window]) * x_{t-1}. + # Per-token forward-1 smear of the embedding lane. W zero-init + lam=0 -> + # transparent at init. Uses CastedLinear so restore_fp32_params handles dtype. + self.smear_gate_enabled = h.smear_gate_enabled + if self.smear_gate_enabled: + self.smear_window = h.gate_window + self.smear_gate = CastedLinear(self.smear_window, 1, bias=False) + self.smear_gate._zero_init = True + self.smear_lambda = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + for i in range(n): + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def _forward_hidden(self, input_ids, cu_seqlens=None, max_seqlen=0): + """Run the encoder/decoder stack to the final RMSNorm; returns pre-projection hidden. + Shared by eval (softcap+projection via forward_logits) and train (fused CE path).""" + x = self.tok_emb(input_ids) + # SmearGate (PR #1667). Inline gate compute with .contiguous() on the slice fed + # to the projection so torch.compile fullgraph is happy. lam=0 + W=0 -> identity + # at init. This block runs unconditionally on the smear path; the cat keeps + # position 0 untouched so causality holds. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + bos_mask = (input_ids[:, 1:] == 1).unsqueeze(-1) + g = g.masked_fill(bos_mask, 0.0) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1]], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + return x + + def _project_logits(self, hidden): + if self.tie_embeddings: + return F.linear(hidden, self.tok_emb.weight) + return self.lm_head(hidden) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + # Fused softcapped-CE kernel (training path only). Applies softcap inside the + # Triton kernel; takes pre-softcap logits_proj. Non-fused path matches stock + # PR-1736 numerics exactly (softcap in fp32, then F.cross_entropy on fp32). + if self.fused_ce_enabled: + return softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + + def forward_ttt(self, input_ids, target_ids, lora): + x = self.tok_emb(input_ids) + # SmearGate on the TTT path — same inline compute as forward_logits. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + bos_mask = (input_ids[:, 1:] == 1).unsqueeze(-1) + g = g.masked_fill(bos_mask, 0.0) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1]], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + # Keep raw Q for AttnOutGate src='q' (matches forward path semantics). + q_raw = F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT path) — inline + .contiguous() barrier, same as the eval path. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT path). Gate input is n (post-norm block input), same + # as eval path. .to(n.dtype) on fp32 param before bf16 broadcast. + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT path) — must match the eval path in + # forward() exactly, else training (which applied the gate) and TTT eval (which + # skipped it) produce mismatched representations and catastrophic BPB regression. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q_raw = F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT parallel path) — inline + .contiguous() barrier. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT parallel path). Gate input is n (post-norm block input). + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT parallel path) — must match the + # eval path in forward() to keep train/eval semantics in sync. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + # PR-1767: rank-scaled output (alpha/rank), like standard LoRA. Decouples + # effective magnitude from rank so changing rank does not change LR scale. + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + # PR-1767: optionally keep A warm across per-doc resets (only B is zeroed). + # Accumulates useful feature directions across documents within a TTT phase. + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + + +class BatchedTTTLoRA(nn.Module): + def __init__(self, bsz, model, rank, k_lora=True, mlp_lora=True, o_lora=True): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + embed_dim = model.tok_emb.embedding_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + self.v_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + + +# Polar Express per-iteration minimax Newton-Schulz coefficients (PR #1344). +# Replaces the fixed (3.4445, -4.775, 2.0315) coefficients of stock Muon. +# Applied at backend_steps=5 — taking more than 5 iterations from this list +# falls back to the final (converged) tuple via the slice guard below. +_PE_COEFFS = ( + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +) + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + coeffs = _PE_COEFFS[:steps] if steps <= len(_PE_COEFFS) else _PE_COEFFS + for a, b, c in coeffs: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad.bfloat16()) + if pg.shape[0] > m["B"]: + pg[m["B"] :].zero_() + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas,attn_gate_proj,attn_gate_w,smear_gate,smear_lambda", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + # SmearGate params live on GPT root (not in .blocks), so add them by hand. + # Both are tiny (gate_window scalars + 1 lambda). Optimized via scalar Adam. + if getattr(base_model, "smear_gate_enabled", False): + scalar_params.append(base_model.smear_gate.weight) + scalar_params.append(base_model.smear_lambda) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self.optimizer_tok.step() + self.optimizer_scalar.step() + self.optimizer_muon.step() + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank") and model.qo_bank is not None: + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + + # Hessian hooks for embedding factorization projection layers + def make_linear_input_hook(weight_name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if weight_name not in hessians: + hessians[weight_name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[weight_name].addmm_(x.T, x) + return hook_fn + + if model.tie_embeddings: + hook_module = model.final_norm + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + return hessians + + +def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s + + +def _quantize_gate_int8_row(w): + # Symmetric int8-per-row quantization for small gate tensors. w shape + # (R, C) -> (R,) scales in fp16, int8 values in [-127, 127]. Single scale + # per row keeps accuracy high while halving storage vs fp16. + W = w.float().contiguous() + row_max = W.abs().amax(dim=1).clamp_min(1e-10) + s = (row_max / 127.0).to(torch.float16) + sf = s.float().view(-1, 1) + q = torch.clamp(torch.round(W / sf), -127, 127).to(torch.int8) + return q, s + + +def _lqer_pack(A, B, bits): + rng = 2 ** (bits - 1) - 1 + sA = (A.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + sB = (B.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float().view(-1, 1)), -rng, rng).to(torch.int8) + qB = torch.clamp(torch.round(B / sB.float().view(-1, 1)), -rng, rng).to(torch.int8) + return qA, sA, qB, sB + + +def _lqer_pack_asym(A, B, g=64): + # A: INT2 per-matrix scalar (signed [-2,1], scale = |A|max/1.5). + sA = (A.abs().amax().clamp_min(1e-10) / 1.5).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float()), -2, 1).to(torch.int8) + # B: INT4 groupwise g over flattened B (signed [-8,7], per-group scale). + Bf = B.reshape(-1, g) + Bmax = Bf.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) + sB = (Bmax / 7.5).to(torch.float16).reshape(-1) + qB = torch.clamp(torch.round(Bf / sB.float().reshape(-1, 1)), -8, 7).to( + torch.int8 + ).reshape(B.shape) + return qA, sA, qB, sB + + +def gptq_mixed_quantize(state_dict, hessians, h): + result = {} + meta = {} + quant_gate = bool(getattr(h, "gated_attn_quant_gate", False)) + lqer_on = bool(getattr(h, "lqer_enabled", False)) + lqer_cands = {} + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Dedicated int8-per-row path for attn_gate_w (bypasses both GPTQ and + # fp16 passthrough). Applied BEFORE the numel<=65536 passthrough check + # so the gate tensor is routed here instead of to fp16. + if ( + quant_gate + and t.is_floating_point() + and t.ndim == 2 + and name.endswith(".attn_gate_w") + # Dense GatedAttn: (num_heads, dim) = (8, 512) = 4096. + # Sparse gate: (num_heads, gate_window) = (8, 12) = 96. + # Both need int8-per-row routing; the 1024 lower bound in stock + # PR-1736 presumed dense-only. Widen to catch both. + and 32 <= t.numel() <= 8192 + ): + gq, gs = _quantize_gate_int8_row(t) + result[name + ".gq"] = gq + result[name + ".gs"] = gs + meta[name] = "gate_int8_row" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + if "tok_emb" in name: + cs = h.embed_clip_sigmas + elif ".mlp." in name: + cs = h.mlp_clip_sigmas + elif ".attn." in name: + cs = h.attn_clip_sigmas + else: + cs = h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + clip_range = 2 ** (bits - 1) - 1 + ret = gptq_quantize_weight( + t, hessians[name], clip_sigmas=cs, clip_range=clip_range + ) + q, s = ret + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + if lqer_on: + W_q = q.float() * s.float().view(-1, 1) + E = t.float() - W_q + lqer_cands[name] = (E, float(E.norm())) + if lqer_on and lqer_cands: + top = sorted(lqer_cands.items(), key=lambda kv: -kv[1][1])[: h.lqer_top_k] + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + for (name, (E, _)) in top: + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + result[name + ".lqA_a"] = qA + result[name + ".lqAs_a"] = sA + result[name + ".lqB_a"] = qB + result[name + ".lqBs_a"] = sB + meta[name] = meta[name] + "+lqer_asym" + else: + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + result[name + ".lqA"] = qA + result[name + ".lqAs"] = sA + result[name + ".lqB"] = qB + result[name + ".lqBs"] = sB + meta[name] = meta[name] + "+lqer" + categories = collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + if info == "gate_int8_row": + gq = result[name + ".gq"] + gs = result[name + ".gs"] + out[name] = (gq.float() * gs.float().view(-1, 1)).to(orig_dtype) + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + W = q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + else: + W = q.float() * float(s.item()) + if "lqer_asym" in info: + qA_t = result[name + ".lqA_a"] + sA_t = result[name + ".lqAs_a"] + qB_t = result[name + ".lqB_a"] + sB_t = result[name + ".lqBs_a"] + qA = qA_t.float() * float(sA_t) + g_sz = qB_t.numel() // sB_t.numel() + qB = (qB_t.reshape(-1, g_sz).float() * sB_t.float().view(-1, 1)).reshape( + qB_t.shape + ) + W = W + qA @ qB + elif "lqer" in info: + qA = result[name + ".lqA"].float() * result[name + ".lqAs"].float().view(-1, 1) + qB = result[name + ".lqB"].float() * result[name + ".lqBs"].float().view(-1, 1) + W = W + qA @ qB + out[name] = W.to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data, stride=2): + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off : dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data): + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off : src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data, compressor): + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data, compressor): + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + raw = _byte_unshuffle(raw) + return raw + + +def _unbank_state_dict(state_dict, num_layers): + sd = {} + n = num_layers + for k, v in state_dict.items(): + t = v.detach().cpu() if v is not None else None + if k == "qo_bank": + for i in range(n): + sd[f"blocks.{i}.attn.c_q.weight"] = t[i] + sd[f"blocks.{i}.attn.proj.weight"] = t[n + i] + elif k == "kv_bank": + for i in range(n): + sd[f"blocks.{i}.attn.c_k.weight"] = t[i] + sd[f"blocks.{i}.attn.c_v.weight"] = t[n + i] + elif k == "mlp_up_bank": + for i in range(n): + sd[f"blocks.{i}.mlp.fc.weight"] = t[i] + elif k == "mlp_down_bank": + for i in range(n): + sd[f"blocks.{i}.mlp.proj.weight"] = t[i] + else: + if t is not None: + sd[k] = t + return sd + + +def _rebank_state_dict(flat_sd, num_layers, model_dim, kv_dim, hidden_dim): + sd = {} + n = num_layers + sd["qo_bank"] = torch.zeros(2 * n, model_dim, model_dim) + sd["kv_bank"] = torch.zeros(2 * n, kv_dim, model_dim) + for i in range(n): + sd["qo_bank"][i] = flat_sd[f"blocks.{i}.attn.c_q.weight"] + sd["qo_bank"][n + i] = flat_sd[f"blocks.{i}.attn.proj.weight"] + sd["kv_bank"][i] = flat_sd[f"blocks.{i}.attn.c_k.weight"] + sd["kv_bank"][n + i] = flat_sd[f"blocks.{i}.attn.c_v.weight"] + sd["mlp_up_bank"] = torch.zeros(n, hidden_dim, model_dim) + sd["mlp_down_bank"] = torch.zeros(n, model_dim, hidden_dim) + for i in range(n): + sd["mlp_up_bank"][i] = flat_sd[f"blocks.{i}.mlp.fc.weight"] + sd["mlp_down_bank"][i] = flat_sd[f"blocks.{i}.mlp.proj.weight"] + for k, v in flat_sd.items(): + if not ( + k.startswith("blocks.") + and any( + p in k + for p in [ + ".attn.c_q.", ".attn.c_k.", ".attn.c_v.", + ".attn.proj.", ".mlp.fc.", ".mlp.proj.", + ] + ) + ): + sd[k] = v + return sd + + + +def _compressed_code_size(code): + code_raw = code.encode("utf-8") + minified = subprocess.run( + ["pyminify", "--no-rename-locals", "--no-hoist-literals", "--remove-literal-statements", "-"], + input=code_raw, capture_output=True, check=True, + ).stdout + compressed = lzma.compress(minified) + encoded = base64.b85encode(compressed) + wrapper = b'import lzma as L,base64 as B\nexec(L.decompress(B.b85decode("' + encoded + b'")))\n' + return len(code_raw), len(wrapper) + + +def serialize(h, base_model, code): + code_bytes_uncompressed, code_bytes = _compressed_code_size(code) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size (uncompressed): {code_bytes_uncompressed} bytes") + log(f"Code size (compressed): {code_bytes} bytes") + sd_cpu = _unbank_state_dict(base_model.state_dict(), h.num_layers) + device = torch.device("cuda", h.local_rank) + t0 = time.perf_counter() + calib_loader = ShuffledSequenceLoader(h, device) + log("GPTQ:collecting Hessians from calibration data...") + hessians = collect_hessians( + base_model, + calib_loader, + h, + device, + n_calibration_batches=h.gptq_calibration_batches, + ) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter()-t0:.1f}s") + t_quant = time.perf_counter() + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + log(f"GPTQ:quantized in {time.perf_counter()-t_quant:.1f}s") + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + t_compress = time.perf_counter() + quant_blob = _compress(quant_raw, h.compressor) + log(f"GPTQ:compressed+saved in {time.perf_counter()-t_compress:.1f}s") + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes") + return bytes_total, quant_file_bytes + + +def deserialize(h, device): + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + flat_template = _unbank_state_dict(eval_model.state_dict(), h.num_layers) + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), map_location="cpu" + ) + deq_flat = dequantize_mixed(quant_state["w"], quant_state["m"], flat_template) + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + deq_state = _rebank_state_dict(deq_flat, h.num_layers, h.model_dim, kv_dim, hidden_dim) + eval_model.load_state_dict(deq_state, strict=True) + return eval_model + + +def _loss_bpb(loss_sum, token_count, byte_count): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model, forward_logits_fn=None): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + + # TODO: Don't truncate this. + seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + run_forward_logits = ( + (model.module.forward_logits if hasattr(model, "module") else model.forward_logits) + if forward_logits_fn is None + else forward_logits_fn + ) + model.eval() + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True + ) + x = local[:-1] + y = local[1:] + bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x.numel(), x.device, h.eval_seq_len, 64 + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = run_forward_logits( + x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ).detach() + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ) + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(y.numel()) + prev_ids = x + tgt_ids = y + if val_data.caseops_enabled and val_data.val_bytes is not None: + # CaseOps: read per-token byte budget from sidecar at the same + # global positions as the target tokens y. raw_start/raw_end + # span [raw_start, raw_end), x = local[:-1], y = local[1:], + # so y is at sidecar positions [raw_start + 1, raw_end). + sidecar_slice = val_data.val_bytes[raw_start + 1 : raw_end].to( + device=device, dtype=torch.int32, non_blocking=True + ) + val_byte_count += sidecar_slice.to(torch.float64).sum() + else: + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += ( + val_data.has_leading_space_lut[tgt_ids] + & ~val_data.is_boundary_token_lut[prev_ids] + ).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = ( + int(bos_positions[i + 1]) + if i + 1 < len(bos_positions) + else all_tokens.numel() + ) + if i + 1 < len(bos_positions): + end += 1 + if end - start < 2: + continue + docs.append((start, end - start)) + return docs + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=None, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + if y_bytes is not None: + tok_bytes = y_bytes.to(torch.float64) + else: + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD( + ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum + ) + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = ( + h.global_ttt_warmup_start_lr + + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + ) + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * ( + 1.0 + math.cos(math.pi * decay_ci / decay_steps) + ) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + optimizer.zero_grad(set_to_none=True) + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if h.global_ttt_respect_doc_boundaries: + bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64 + ) + loss = base_model( + x_flat[None], + y_flat[None], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + optimizer.step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log( + f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s" + ) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log( + "ttt_phased:" + f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit}" + f" num_phases:{num_phases} boundaries:{phase_boundaries}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + + def _build_opt(lora): + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=h.ttt_lora_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + # CaseOps sidecar-driven byte budget. Mirror the index pattern + # used to build y from all_tokens: y[b, j] corresponds to the + # token at global position tok_starts[b] + 1 + j (when valid). + y_bytes_arg = None + if val_data.caseops_enabled and val_data.val_bytes is not None: + y_idx = ( + tok_starts.unsqueeze(1) + + 1 + + col_idx[:context_size].unsqueeze(0) + ) + y_idx = y_idx.clamp_(max=val_data.val_bytes.numel() - 1) + y_bytes_arg = val_data.val_bytes[y_idx].to( + device=device, dtype=torch.int32, non_blocking=True + ) + # Mirror the `valid` masking used for y so out-of-range tokens + # contribute zero bytes (matches y=0 substitution above). + y_bytes_arg = torch.where( + valid, y_bytes_arg, torch.zeros_like(y_bytes_arg) + ) + with torch.no_grad(): + _accumulate_bpb( + per_tok_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=y_bytes_arg, + ) + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + for gi in range(h.ttt_grad_steps): + if gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + per_doc = per_tok_loss[ + :, chunk_offset : chunk_offset + chunk_size + ].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + else: + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}" + ) + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch) + ) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global + ] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + prefix_done = 0 + try: + with open(prefix_counter_path, "rb") as f: + prefix_done = int.from_bytes( + f.read(8), "little", signed=True + ) + except FileNotFoundError: + pass + log( + f"ttpp: phase:{current_phase + 1}/{num_phases} pd:{prefix_done} " + f"gd:{len(scored_docs_for_global)} " + f"t:{time.perf_counter() - t_start:.1f}s" + ) + train_val_ttt_global_sgd_distributed( + h, device, val_data, base_model, global_ttt_tokens + ) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.1f}s, effective={max_wallclock_ms:.0f}ms" + ) + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + frac = ( + min(step / h.muon_momentum_warmup_steps, 1.0) + if h.muon_momentum_warmup_steps > 0 + else 1.0 + ) + muon_momentum = ( + 1 - frac + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + ema_state = { + name: t.detach().float().clone() + for (name, t) in base_model.state_dict().items() + } + ema_decay = h.ema_decay + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale) + with torch.no_grad(): + for (name, t) in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_( + t.detach().float(), alpha=1.0 - ema_decay + ) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + # TTT_EVAL_ONLY: skip training + GPTQ, jump straight to TTT eval on a + # pre-existing quantized artifact. Used to test TTT-only improvements + # (e.g., PR-1767's alpha/warm-start/WD) without retraining. + ttt_eval_only = os.environ.get("TTT_EVAL_ONLY", "0") == "1" + if ttt_eval_only: + log("TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + log(f"ttt_weight_decay: {h.ttt_weight_decay}") + else: + t_artifact_start = time.perf_counter() + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + torch._dynamo.reset() + if os.environ.get("PREQUANT_ONLY", "0") == "1": + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + log("PREQUANT_ONLY=1 — skipping serialize/GPTQ/post-quant eval/TTT") + return + t_serialize_start = time.perf_counter() + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + serialize_seconds = time.perf_counter() - t_serialize_start + log(f"serialize_wallclock: {serialize_seconds:.3f}s") + if h.distributed: + dist.barrier() + artifact_production_seconds = time.perf_counter() - t_artifact_start + log(f"artifact_production_wallclock: {artifact_production_seconds:.3f}s (must be < {h.max_wallclock_seconds})") + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + if not ttt_eval_only: + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, dynamic=False, fullgraph=True + ) + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + del eval_model + if h.ttt_enabled: + if not ttt_eval_only: + del compiled_model + if ttt_eval_only: + del eval_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + for p in ttt_model.parameters(): + p.requires_grad_(False) + + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in ttt_model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + _fwd_ttt_compiled_inner = None + + def _fwd_ttt(input_ids, target_ids, lora): + nonlocal _fwd_ttt_compiled_inner + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_lora:warming up compile (random tokens, no val data)") + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, ttt_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len): + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + del wl, wo + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + "quantized_ttt_phased " + f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} " + f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + del ttt_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.cache_size_limit = 16 + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/train_seed314.log b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/train_seed314.log new file mode 100644 index 0000000000..ccaee44f7c --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP8192_SOTAStack_10L_BOS-Fix_LQER_TTT/train_seed314.log @@ -0,0 +1,506 @@ +W0501 01:52:06.090000 696950 torch/distributed/run.py:803] +W0501 01:52:06.090000 696950 torch/distributed/run.py:803] ***************************************** +W0501 01:52:06.090000 696950 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0501 01:52:06.090000 696950 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + beta1: 0.9 + beta2: 0.95 + caseops_enabled: False + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: False + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/sota_sp8192_10L.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_rank: 4 + lqer_top_k: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 10 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: sota_sp8192_10L + scalar_lr: 0.02 + seed: 314 + skip_gates_enabled: True + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 0.5 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 1.0 + val_batch_tokens: 524288 + val_bytes_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.85 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 40546304 +model_params:33059929 +gptq:reserving 4.0s, effective=596000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0062 val_bpb: 3.4870 +1/20000 train_loss: 9.0069 train_time: 0.0m tok/s: 15154124 +2/20000 train_loss: 12.2588 train_time: 0.0m tok/s: 12832497 +3/20000 train_loss: 11.2956 train_time: 0.0m tok/s: 11401801 +4/20000 train_loss: 9.6881 train_time: 0.0m tok/s: 10704107 +5/20000 train_loss: 8.2825 train_time: 0.0m tok/s: 10387456 +500/20000 train_loss: 3.3067 train_time: 0.7m tok/s: 8891930 +1000/20000 train_loss: 3.1220 train_time: 1.5m tok/s: 8823179 +1500/20000 train_loss: 3.0980 train_time: 2.2m tok/s: 8803732 +2000/20000 train_loss: 3.0578 train_time: 3.0m tok/s: 8802017 +layer_loop:enabled step:2333 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9] +2500/20000 train_loss: 3.0717 train_time: 3.8m tok/s: 8513006 +3000/20000 train_loss: 2.8613 train_time: 5.0m tok/s: 7912869 +3500/20000 train_loss: 2.9179 train_time: 6.1m tok/s: 7536435 +4000/20000 train_loss: 2.9815 train_time: 7.2m tok/s: 7277031 +4000/20000 val_loss: 2.9069 val_bpb: 1.1255 +4500/20000 train_loss: 2.8217 train_time: 8.3m tok/s: 7086698 +5000/20000 train_loss: 2.6099 train_time: 9.4m tok/s: 6939256 +5218/20000 val_loss: 2.8037 val_bpb: 1.0855 +stopping_early: wallclock_cap train_time: 596060ms step: 5218/20000 +peak memory allocated: 39502 MiB reserved: 44890 MiB +ema:applying EMA weights +Serialized model: 123872824 bytes +Code size (uncompressed): 152678 bytes +Code size (compressed): 31390 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 61 Hessians in 3.3s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int6)+lqer_asym: blocks.mlp.fc.weight + gptq (int8)+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.attn_gate_w, blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda +GPTQ:quantized in 9.2s +GPTQ:compressed+saved in 60.7s +Serialized model quantized+brotli: 15373365 bytes +Total submission size quantized+brotli: 15404755 bytes +serialize_wallclock: 74.318s +artifact_production_wallclock: 1083.465s (must be < 600.0) +diagnostic pre-quantization post-ema val_loss:2.77386058 val_bpb:1.07398921 eval_time:6654ms +diagnostic quantized val_loss:2.79587230 val_bpb:1.08251175 eval_time:59592ms +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (146.7s) + +beginning TTT eval timer +ttt_phased: total_docs:49999 prefix_docs:2000 suffix_docs:47999 num_phases:1 boundaries:[2000] +ttp: b775/782 bl:2.6949 bb:1.0663 rl:2.6949 rb:1.0663 dl:5869-6355 gd:0 +ttp: b774/782 bl:2.7171 bb:1.0760 rl:2.7056 rb:1.0710 dl:5573-5864 gd:0 +ttp: b769/782 bl:2.7708 bb:1.1009 rl:2.7233 rb:1.0791 dl:4317-4485 gd:0 +ttp: b763/782 bl:2.7696 bb:1.0875 rl:2.7317 rb:1.0806 dl:3540-3634 gd:0 +ttp: b758/782 bl:2.8486 bb:1.0867 rl:2.7477 rb:1.0815 dl:3111-3185 gd:0 +ttp: b751/782 bl:2.7795 bb:1.0803 rl:2.7511 rb:1.0813 dl:2694-2741 gd:0 +ttp: b746/782 bl:2.6715 bb:1.0549 rl:2.7441 rb:1.0790 dl:2458-2500 gd:0 +ttpp: phase:1/1 pd:2511 gd:2000 t:284.5s +tttg: c1/207 lr:0.001000 t:2.0s +tttg: c2/207 lr:0.001000 t:2.0s +tttg: c3/207 lr:0.001000 t:2.1s +tttg: c4/207 lr:0.000999 t:2.2s +tttg: c5/207 lr:0.000999 t:2.3s +tttg: c6/207 lr:0.000999 t:2.3s +tttg: c7/207 lr:0.000998 t:2.4s +tttg: c8/207 lr:0.000997 t:2.5s +tttg: c9/207 lr:0.000996 t:2.6s +tttg: c10/207 lr:0.000995 t:2.6s +tttg: c11/207 lr:0.000994 t:2.7s +tttg: c12/207 lr:0.000993 t:2.8s +tttg: c13/207 lr:0.000992 t:2.9s +tttg: c14/207 lr:0.000990 t:2.9s +tttg: c15/207 lr:0.000989 t:3.0s +tttg: c16/207 lr:0.000987 t:3.1s +tttg: c17/207 lr:0.000985 t:3.2s +tttg: c18/207 lr:0.000983 t:3.2s +tttg: c19/207 lr:0.000981 t:3.3s +tttg: c20/207 lr:0.000979 t:3.4s +tttg: c21/207 lr:0.000977 t:3.5s +tttg: c22/207 lr:0.000975 t:3.5s +tttg: c23/207 lr:0.000972 t:3.6s +tttg: c24/207 lr:0.000970 t:3.7s +tttg: c25/207 lr:0.000967 t:3.8s +tttg: c26/207 lr:0.000964 t:3.8s +tttg: c27/207 lr:0.000961 t:3.9s +tttg: c28/207 lr:0.000958 t:4.0s +tttg: c29/207 lr:0.000955 t:4.1s +tttg: c30/207 lr:0.000952 t:4.1s +tttg: c31/207 lr:0.000949 t:4.2s +tttg: c32/207 lr:0.000945 t:4.3s +tttg: c33/207 lr:0.000942 t:4.4s +tttg: c34/207 lr:0.000938 t:4.4s +tttg: c35/207 lr:0.000934 t:4.5s +tttg: c36/207 lr:0.000930 t:4.6s +tttg: c37/207 lr:0.000927 t:4.7s +tttg: c38/207 lr:0.000922 t:4.7s +tttg: c39/207 lr:0.000918 t:4.8s +tttg: c40/207 lr:0.000914 t:4.9s +tttg: c41/207 lr:0.000910 t:5.0s +tttg: c42/207 lr:0.000905 t:5.0s +tttg: c43/207 lr:0.000901 t:5.1s +tttg: c44/207 lr:0.000896 t:5.2s +tttg: c45/207 lr:0.000892 t:5.3s +tttg: c46/207 lr:0.000887 t:5.3s +tttg: c47/207 lr:0.000882 t:5.4s +tttg: c48/207 lr:0.000877 t:5.5s +tttg: c49/207 lr:0.000872 t:5.6s +tttg: c50/207 lr:0.000867 t:5.6s +tttg: c51/207 lr:0.000862 t:5.7s +tttg: c52/207 lr:0.000856 t:5.8s +tttg: c53/207 lr:0.000851 t:5.9s +tttg: c54/207 lr:0.000845 t:5.9s +tttg: c55/207 lr:0.000840 t:6.0s +tttg: c56/207 lr:0.000834 t:6.1s +tttg: c57/207 lr:0.000828 t:6.2s +tttg: c58/207 lr:0.000823 t:6.2s +tttg: c59/207 lr:0.000817 t:6.3s +tttg: c60/207 lr:0.000811 t:6.4s +tttg: c61/207 lr:0.000805 t:6.5s +tttg: c62/207 lr:0.000799 t:6.5s +tttg: c63/207 lr:0.000793 t:6.6s +tttg: c64/207 lr:0.000786 t:6.7s +tttg: c65/207 lr:0.000780 t:6.8s +tttg: c66/207 lr:0.000774 t:6.8s +tttg: c67/207 lr:0.000767 t:6.9s +tttg: c68/207 lr:0.000761 t:7.0s +tttg: c69/207 lr:0.000754 t:7.1s +tttg: c70/207 lr:0.000748 t:7.1s +tttg: c71/207 lr:0.000741 t:7.2s +tttg: c72/207 lr:0.000734 t:7.3s +tttg: c73/207 lr:0.000728 t:7.4s +tttg: c74/207 lr:0.000721 t:7.4s +tttg: c75/207 lr:0.000714 t:7.5s +tttg: c76/207 lr:0.000707 t:7.6s +tttg: c77/207 lr:0.000700 t:7.7s +tttg: c78/207 lr:0.000693 t:7.7s +tttg: c79/207 lr:0.000686 t:7.8s +tttg: c80/207 lr:0.000679 t:7.9s +tttg: c81/207 lr:0.000672 t:7.9s +tttg: c82/207 lr:0.000665 t:8.0s +tttg: c83/207 lr:0.000657 t:8.1s +tttg: c84/207 lr:0.000650 t:8.2s +tttg: c85/207 lr:0.000643 t:8.2s +tttg: c86/207 lr:0.000636 t:8.3s +tttg: c87/207 lr:0.000628 t:8.4s +tttg: c88/207 lr:0.000621 t:8.5s +tttg: c89/207 lr:0.000613 t:8.5s +tttg: c90/207 lr:0.000606 t:8.6s +tttg: c91/207 lr:0.000598 t:8.7s +tttg: c92/207 lr:0.000591 t:8.8s +tttg: c93/207 lr:0.000583 t:8.8s +tttg: c94/207 lr:0.000576 t:8.9s +tttg: c95/207 lr:0.000568 t:9.0s +tttg: c96/207 lr:0.000561 t:9.1s +tttg: c97/207 lr:0.000553 t:9.1s +tttg: c98/207 lr:0.000546 t:9.2s +tttg: c99/207 lr:0.000538 t:9.3s +tttg: c100/207 lr:0.000530 t:9.3s +tttg: c101/207 lr:0.000523 t:9.4s +tttg: c102/207 lr:0.000515 t:9.5s +tttg: c103/207 lr:0.000508 t:9.6s +tttg: c104/207 lr:0.000500 t:9.6s +tttg: c105/207 lr:0.000492 t:9.7s +tttg: c106/207 lr:0.000485 t:9.8s +tttg: c107/207 lr:0.000477 t:9.9s +tttg: c108/207 lr:0.000470 t:9.9s +tttg: c109/207 lr:0.000462 t:10.0s +tttg: c110/207 lr:0.000454 t:10.1s +tttg: c111/207 lr:0.000447 t:10.2s +tttg: c112/207 lr:0.000439 t:10.2s +tttg: c113/207 lr:0.000432 t:10.3s +tttg: c114/207 lr:0.000424 t:10.4s +tttg: c115/207 lr:0.000417 t:10.5s +tttg: c116/207 lr:0.000409 t:10.5s +tttg: c117/207 lr:0.000402 t:10.6s +tttg: c118/207 lr:0.000394 t:10.7s +tttg: c119/207 lr:0.000387 t:10.8s +tttg: c120/207 lr:0.000379 t:10.8s +tttg: c121/207 lr:0.000372 t:10.9s +tttg: c122/207 lr:0.000364 t:11.0s +tttg: c123/207 lr:0.000357 t:11.1s +tttg: c124/207 lr:0.000350 t:11.2s +tttg: c125/207 lr:0.000343 t:11.2s +tttg: c126/207 lr:0.000335 t:11.3s +tttg: c127/207 lr:0.000328 t:11.4s +tttg: c128/207 lr:0.000321 t:11.5s +tttg: c129/207 lr:0.000314 t:11.5s +tttg: c130/207 lr:0.000307 t:11.6s +tttg: c131/207 lr:0.000300 t:11.7s +tttg: c132/207 lr:0.000293 t:11.8s +tttg: c133/207 lr:0.000286 t:11.8s +tttg: c134/207 lr:0.000279 t:11.9s +tttg: c135/207 lr:0.000272 t:12.0s +tttg: c136/207 lr:0.000266 t:12.1s +tttg: c137/207 lr:0.000259 t:12.1s +tttg: c138/207 lr:0.000252 t:12.2s +tttg: c139/207 lr:0.000246 t:12.3s +tttg: c140/207 lr:0.000239 t:12.3s +tttg: c141/207 lr:0.000233 t:12.4s +tttg: c142/207 lr:0.000226 t:12.5s +tttg: c143/207 lr:0.000220 t:12.6s +tttg: c144/207 lr:0.000214 t:12.6s +tttg: c145/207 lr:0.000207 t:12.7s +tttg: c146/207 lr:0.000201 t:12.8s +tttg: c147/207 lr:0.000195 t:12.9s +tttg: c148/207 lr:0.000189 t:12.9s +tttg: c149/207 lr:0.000183 t:13.0s +tttg: c150/207 lr:0.000177 t:13.1s +tttg: c151/207 lr:0.000172 t:13.2s +tttg: c152/207 lr:0.000166 t:13.2s +tttg: c153/207 lr:0.000160 t:13.3s +tttg: c154/207 lr:0.000155 t:13.4s +tttg: c155/207 lr:0.000149 t:13.5s +tttg: c156/207 lr:0.000144 t:13.5s +tttg: c157/207 lr:0.000138 t:13.6s +tttg: c158/207 lr:0.000133 t:13.7s +tttg: c159/207 lr:0.000128 t:13.8s +tttg: c160/207 lr:0.000123 t:13.8s +tttg: c161/207 lr:0.000118 t:13.9s +tttg: c162/207 lr:0.000113 t:14.0s +tttg: c163/207 lr:0.000108 t:14.1s +tttg: c164/207 lr:0.000104 t:14.1s +tttg: c165/207 lr:0.000099 t:14.2s +tttg: c166/207 lr:0.000095 t:14.3s +tttg: c167/207 lr:0.000090 t:14.4s +tttg: c168/207 lr:0.000086 t:14.4s +tttg: c169/207 lr:0.000082 t:14.5s +tttg: c170/207 lr:0.000078 t:14.6s +tttg: c171/207 lr:0.000073 t:14.7s +tttg: c172/207 lr:0.000070 t:14.7s +tttg: c173/207 lr:0.000066 t:14.8s +tttg: c174/207 lr:0.000062 t:14.9s +tttg: c175/207 lr:0.000058 t:15.0s +tttg: c176/207 lr:0.000055 t:15.0s +tttg: c177/207 lr:0.000051 t:15.1s +tttg: c178/207 lr:0.000048 t:15.2s +tttg: c179/207 lr:0.000045 t:15.3s +tttg: c180/207 lr:0.000042 t:15.3s +tttg: c181/207 lr:0.000039 t:15.4s +tttg: c182/207 lr:0.000036 t:15.5s +tttg: c183/207 lr:0.000033 t:15.6s +tttg: c184/207 lr:0.000030 t:15.6s +tttg: c185/207 lr:0.000028 t:15.7s +tttg: c186/207 lr:0.000025 t:15.8s +tttg: c187/207 lr:0.000023 t:15.9s +tttg: c188/207 lr:0.000021 t:15.9s +tttg: c189/207 lr:0.000019 t:16.0s +tttg: c190/207 lr:0.000017 t:16.1s +tttg: c191/207 lr:0.000015 t:16.2s +tttg: c192/207 lr:0.000013 t:16.2s +tttg: c193/207 lr:0.000011 t:16.3s +tttg: c194/207 lr:0.000010 t:16.4s +tttg: c195/207 lr:0.000008 t:16.5s +tttg: c196/207 lr:0.000007 t:16.5s +tttg: c197/207 lr:0.000006 t:16.6s +tttg: c198/207 lr:0.000005 t:16.7s +tttg: c199/207 lr:0.000004 t:16.8s +tttg: c200/207 lr:0.000003 t:16.8s +tttg: c201/207 lr:0.000002 t:16.9s +tttg: c202/207 lr:0.000001 t:17.0s +tttg: c203/207 lr:0.000001 t:17.1s +tttg: c204/207 lr:0.000001 t:17.1s +tttg: c205/207 lr:0.000000 t:17.2s +tttg: c206/207 lr:0.000000 t:17.3s +ttpr: phase:1/1 t:303.8s +ttp: b736/782 bl:2.7069 bb:1.0378 rl:2.7414 rb:1.0760 dl:2143-2167 gd:1 +ttp: b729/782 bl:2.7191 bb:1.0371 rl:2.7400 rb:1.0736 dl:1977-1998 gd:1 +ttp: b722/782 bl:2.7722 bb:1.0577 rl:2.7418 rb:1.0727 dl:1847-1863 gd:1 +ttp: b715/782 bl:2.7121 bb:1.0653 rl:2.7404 rb:1.0723 dl:1726-1742 gd:1 +ttp: b706/782 bl:2.7730 bb:1.0400 rl:2.7418 rb:1.0709 dl:1618-1627 gd:1 +ttp: b699/782 bl:2.7815 bb:1.0424 rl:2.7433 rb:1.0697 dl:1543-1553 gd:1 +ttp: b691/782 bl:2.7207 bb:1.0440 rl:2.7425 rb:1.0688 dl:1468-1476 gd:1 +ttp: b683/782 bl:2.7970 bb:1.0610 rl:2.7443 rb:1.0685 dl:1401-1408 gd:1 +ttp: b675/782 bl:2.8705 bb:1.0926 rl:2.7483 rb:1.0693 dl:1342-1348 gd:1 +ttp: b666/782 bl:2.7188 bb:1.0311 rl:2.7474 rb:1.0681 dl:1283-1289 gd:1 +ttp: b658/782 bl:2.7977 bb:1.0611 rl:2.7488 rb:1.0679 dl:1233-1240 gd:1 +ttp: b650/782 bl:2.8301 bb:1.0812 rl:2.7508 rb:1.0683 dl:1187-1193 gd:1 +ttp: b642/782 bl:2.7675 bb:1.0802 rl:2.7512 rb:1.0686 dl:1145-1150 gd:1 +ttp: b636/782 bl:2.7292 bb:1.0490 rl:2.7507 rb:1.0681 dl:1116-1119 gd:1 +ttp: b629/782 bl:2.7241 bb:1.0375 rl:2.7502 rb:1.0674 dl:1082-1087 gd:1 +ttp: b623/782 bl:2.8105 bb:1.0584 rl:2.7514 rb:1.0672 dl:1055-1059 gd:1 +ttp: b615/782 bl:2.7864 bb:1.0549 rl:2.7521 rb:1.0670 dl:1020-1024 gd:1 +ttp: b609/782 bl:2.8036 bb:1.0634 rl:2.7530 rb:1.0669 dl:994-998 gd:1 +ttp: b602/782 bl:2.7546 bb:1.0606 rl:2.7531 rb:1.0668 dl:967-970 gd:1 +ttp: b595/782 bl:2.8175 bb:1.0763 rl:2.7542 rb:1.0670 dl:940-943 gd:1 +ttp: b588/782 bl:2.7092 bb:1.0301 rl:2.7534 rb:1.0664 dl:917-921 gd:1 +ttp: b581/782 bl:2.7219 bb:1.0108 rl:2.7529 rb:1.0655 dl:894-896 gd:1 +ttp: b573/782 bl:2.8034 bb:1.0548 rl:2.7537 rb:1.0653 dl:868-871 gd:1 +ttp: b565/782 bl:2.7743 bb:1.0862 rl:2.7540 rb:1.0656 dl:842-845 gd:1 +ttp: b557/782 bl:2.7762 bb:1.0442 rl:2.7543 rb:1.0653 dl:818-820 gd:1 +ttp: b548/782 bl:2.7355 bb:1.0327 rl:2.7540 rb:1.0649 dl:793-795 gd:1 +ttp: b540/782 bl:2.7488 bb:1.0374 rl:2.7540 rb:1.0645 dl:772-773 gd:1 +ttp: b532/782 bl:2.8452 bb:1.0718 rl:2.7551 rb:1.0646 dl:751-754 gd:1 +ttp: b524/782 bl:2.7768 bb:1.0597 rl:2.7553 rb:1.0645 dl:732-735 gd:1 +ttp: b516/782 bl:2.7879 bb:1.0545 rl:2.7557 rb:1.0644 dl:711-714 gd:1 +ttp: b508/782 bl:2.7162 bb:1.0269 rl:2.7553 rb:1.0640 dl:693-695 gd:1 +ttp: b500/782 bl:2.8499 bb:1.0752 rl:2.7563 rb:1.0641 dl:675-677 gd:1 +ttp: b492/782 bl:2.7967 bb:1.0459 rl:2.7567 rb:1.0639 dl:657-659 gd:1 +ttp: b484/782 bl:2.7563 bb:1.0610 rl:2.7567 rb:1.0639 dl:640-643 gd:1 +ttp: b476/782 bl:2.7662 bb:1.0563 rl:2.7567 rb:1.0638 dl:623-625 gd:1 +ttp: b470/782 bl:2.8207 bb:1.0937 rl:2.7573 rb:1.0641 dl:612-614 gd:1 +ttp: b462/782 bl:2.7395 bb:1.0421 rl:2.7572 rb:1.0639 dl:597-598 gd:1 +ttp: b454/782 bl:2.8392 bb:1.0929 rl:2.7579 rb:1.0642 dl:582-584 gd:1 +ttp: b445/782 bl:2.8225 bb:1.0551 rl:2.7584 rb:1.0641 dl:566-567 gd:1 +ttp: b437/782 bl:2.7418 bb:1.0181 rl:2.7583 rb:1.0637 dl:551-553 gd:1 +ttp: b429/782 bl:2.7497 bb:1.0450 rl:2.7582 rb:1.0636 dl:537-538 gd:1 +ttp: b421/782 bl:2.8040 bb:1.0446 rl:2.7585 rb:1.0634 dl:523-524 gd:1 +ttp: b413/782 bl:2.8041 bb:1.0888 rl:2.7589 rb:1.0636 dl:510-511 gd:1 +ttp: b405/782 bl:2.8128 bb:1.0804 rl:2.7592 rb:1.0637 dl:497-498 gd:1 +ttp: b398/782 bl:2.8485 bb:1.1070 rl:2.7598 rb:1.0640 dl:486-487 gd:1 +ttp: b388/782 bl:2.9086 bb:1.1064 rl:2.7608 rb:1.0643 dl:470-471 gd:1 +ttp: b380/782 bl:2.7910 bb:1.0758 rl:2.7610 rb:1.0643 dl:459-460 gd:1 +ttp: b372/782 bl:2.8062 bb:1.0592 rl:2.7613 rb:1.0643 dl:447-448 gd:1 +ttp: b364/782 bl:2.7843 bb:1.0728 rl:2.7614 rb:1.0644 dl:435-437 gd:1 +ttp: b355/782 bl:2.7779 bb:1.0875 rl:2.7615 rb:1.0645 dl:423-424 gd:1 +ttp: b347/782 bl:2.9104 bb:1.1100 rl:2.7623 rb:1.0647 dl:413-414 gd:1 +ttp: b338/782 bl:2.7859 bb:1.0845 rl:2.7624 rb:1.0649 dl:401-402 gd:1 +ttp: b331/782 bl:2.8718 bb:1.1139 rl:2.7630 rb:1.0651 dl:391-393 gd:1 +ttp: b323/782 bl:2.8021 bb:1.0917 rl:2.7632 rb:1.0652 dl:381-382 gd:1 +ttp: b317/782 bl:2.7683 bb:1.0889 rl:2.7632 rb:1.0653 dl:373-374 gd:1 +ttp: b308/782 bl:2.8453 bb:1.0929 rl:2.7636 rb:1.0655 dl:362-363 gd:1 +ttp: b299/782 bl:2.8479 bb:1.0881 rl:2.7640 rb:1.0656 dl:351-352 gd:1 +ttp: b291/782 bl:2.8572 bb:1.1159 rl:2.7644 rb:1.0658 dl:341-342 gd:1 +ttp: b283/782 bl:2.8205 bb:1.0748 rl:2.7646 rb:1.0658 dl:332-333 gd:1 +ttp: b275/782 bl:2.7953 bb:1.0866 rl:2.7648 rb:1.0659 dl:322-324 gd:1 +ttp: b267/782 bl:2.8140 bb:1.0584 rl:2.7650 rb:1.0659 dl:314-315 gd:1 +ttp: b259/782 bl:2.8238 bb:1.0973 rl:2.7652 rb:1.0660 dl:305-306 gd:1 +ttp: b251/782 bl:2.7431 bb:1.1014 rl:2.7651 rb:1.0661 dl:296-297 gd:1 +ttp: b243/782 bl:2.7580 bb:1.0737 rl:2.7651 rb:1.0662 dl:288-289 gd:1 +ttp: b237/782 bl:2.7647 bb:1.1091 rl:2.7651 rb:1.0663 dl:282-283 gd:1 +ttp: b229/782 bl:2.8535 bb:1.0955 rl:2.7654 rb:1.0664 dl:274-275 gd:1 +ttp: b221/782 bl:2.8606 bb:1.1214 rl:2.7657 rb:1.0666 dl:266-267 gd:1 +ttp: b213/782 bl:2.8797 bb:1.1156 rl:2.7661 rb:1.0667 dl:258-259 gd:1 +ttp: b205/782 bl:2.8678 bb:1.1154 rl:2.7664 rb:1.0669 dl:251-252 gd:1 +ttp: b197/782 bl:2.9860 bb:1.2063 rl:2.7670 rb:1.0673 dl:244-245 gd:1 +ttp: b189/782 bl:2.8754 bb:1.1287 rl:2.7674 rb:1.0675 dl:237-238 gd:1 +ttp: b181/782 bl:2.8447 bb:1.1110 rl:2.7676 rb:1.0676 dl:230-230 gd:1 +ttp: b172/782 bl:2.9873 bb:1.1656 rl:2.7682 rb:1.0679 dl:222-223 gd:1 +ttp: b164/782 bl:2.9503 bb:1.1489 rl:2.7687 rb:1.0681 dl:215-216 gd:1 +ttp: b157/782 bl:2.9264 bb:1.1015 rl:2.7691 rb:1.0682 dl:209-210 gd:1 +ttp: b148/782 bl:2.9373 bb:1.1190 rl:2.7695 rb:1.0683 dl:202-203 gd:1 +ttp: b139/782 bl:2.9875 bb:1.1636 rl:2.7700 rb:1.0685 dl:194-195 gd:1 +ttp: b131/782 bl:3.0113 bb:1.1813 rl:2.7705 rb:1.0688 dl:188-189 gd:1 +ttp: b123/782 bl:2.8461 bb:1.1438 rl:2.7707 rb:1.0689 dl:182-183 gd:1 +ttp: b115/782 bl:2.8788 bb:1.1321 rl:2.7709 rb:1.0690 dl:176-177 gd:1 +ttp: b106/782 bl:3.0807 bb:1.2089 rl:2.7716 rb:1.0693 dl:170-171 gd:1 +ttp: b102/782 bl:2.9150 bb:1.1478 rl:2.7718 rb:1.0695 dl:167-167 gd:1 +ttp: b92/782 bl:3.0025 bb:1.1996 rl:2.7723 rb:1.0697 dl:159-160 gd:1 +ttp: b86/782 bl:2.9798 bb:1.2165 rl:2.7727 rb:1.0700 dl:155-155 gd:1 +ttp: b77/782 bl:3.0076 bb:1.1636 rl:2.7731 rb:1.0701 dl:148-148 gd:1 +ttp: b66/782 bl:3.0913 bb:1.2254 rl:2.7736 rb:1.0704 dl:139-140 gd:1 +ttp: b59/782 bl:2.9484 bb:1.1882 rl:2.7739 rb:1.0706 dl:134-134 gd:1 +ttp: b51/782 bl:3.0255 bb:1.2107 rl:2.7743 rb:1.0708 dl:127-128 gd:1 +ttp: b44/782 bl:3.1351 bb:1.2219 rl:2.7748 rb:1.0710 dl:122-122 gd:1 +ttp: b34/782 bl:3.0125 bb:1.2381 rl:2.7751 rb:1.0712 dl:114-115 gd:1 +ttp: b27/782 bl:3.0690 bb:1.2362 rl:2.7755 rb:1.0714 dl:107-108 gd:1 +ttp: b17/782 bl:3.1693 bb:1.2866 rl:2.7759 rb:1.0716 dl:98-99 gd:1 +ttp: b7/782 bl:3.2725 bb:1.2719 rl:2.7764 rb:1.0718 dl:84-86 gd:1 +quantized_ttt_phased val_loss:2.76784056 val_bpb:1.07170518 eval_time:486863ms +total_eval_time:486.9s diff --git a/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/README.md b/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/README.md new file mode 100644 index 0000000000..28594bb8c4 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/README.md @@ -0,0 +1,150 @@ +# QAT Dead-Code Bug + 7 Untried Techniques: What I Found + +**Non-Record Submission (Research Contribution)** +**Author:** [@wfproc](https://github.com/wfproc) +**Base:** PR #549 by @abaybektursun (1.1194 bpb, current SOTA) +**Hardware:** 1xH100 80GB SXM, 600s wallclock (~940 steps per run) +**Best pre-quant:** 1.3646 val_bpb | **Best post-quant (int6 sliding):** 2.3411 val_bpb + +Not competitive with the 8xH100 leaderboard (7x fewer steps), but the findings transfer to any hardware. + +--- + +## The Short Version + +I spent two days running 20+ experiments on the SOTA #549 stack, testing techniques nobody else had submitted results for. The main finding: Late QAT is dead code in #315-derived submissions due to a torch.compile bug. Fixing it actually made the int6 gap *worse*, not better. I also tested 7 novel techniques from recent papers (Muon-VS, Deep Delta Learning, stable depth recurrence, anti-layer removal, NS step reduction, spectral SVD compression, wider models). All negative. The throughput tax at this scale kills everything that adds overhead. + +--- + +## The QAT Dead-Code Bug + +PR #315's README flagged this already, but I wanted to confirm it and try to fix it. + +The issue: `CastedLinear._qat_enabled` is a class attribute set to `False`. When `torch.compile(fullgraph=True)` traces the model, it sees `_qat_enabled=False` and eliminates the STE branch from the compiled graph. Setting it to `True` later (when Late QAT is supposed to activate) does nothing. The compiled code is frozen. + +This is present in the SOTA (#549). I verified by checking step times: with the original code, step time stays flat at ~625ms throughout training. If QAT was actually running, you'd see a jump when it activates (STE adds overhead). + +I tried three ways to fix it: + +| Fix | What happened | +|---|---| +| Mutable list `[False]` | `FailOnRecompileLimitHit` error, `fullgraph=True` won't recompile | +| Instance attribute per module | Same recompilation error | +| Tensor multiplier (0.0 or 1.0) | Works. Graph shape stays the same, no recompile needed | + +The tensor approach: `w = w + qat_scale * (w_q - w).detach()`. When scale is 0.0, it's a no-op. When 1.0, full STE. The compiled graph handles both without recompilation. + +With the fix, step time jumps from ~625ms to ~670ms when QAT activates. So QAT is actually running now. But the results got worse: + +| Metric | Dead QAT (original) | Working QAT (tensor fix) | +|---|---|---| +| val_bpb pre-quant | 1.3631 | 1.3646 | +| val_bpb int6 sliding | 2.3107 | 2.3411 | + +Why? A few theories: + +1. At only ~940 steps (1xH100), the model hasn't converged enough for STE to help. The weights are still moving too fast. +2. WD=0.04 + EMA 0.997 already push weights toward int6-friendly distributions without needing explicit fake quantization. +3. The tensor multiply `qat_scale * (w_q - w).detach()` always executes (even at scale=0.0), adding a tiny overhead that costs a few training steps. + +I can't rule out that QAT would help at 7000 steps (8xH100). But it's also possible that the dead code is genuinely the right call for this stack. + +--- + +## Novel Technique Sweep + +All tested on the SOTA #549 stack on 1xH100. Baseline: 1.3631 val_bpb (pre-quant, slope=0.75, 947 steps). + +| # | Technique | Paper | val_bpb | Delta | Why it failed | +|---|---|---|---|---|---| +| 1 | Muon-VS (variance-adaptive) | 2601.14603 | 1.3884 | +0.019 | Variance buffer needs 100s of steps to warm up, +2.2% overhead | +| 2 | Deep Delta Learning | 2601.00417 | 1.3720 | +0.003 | Beta params stay near zero at init, within noise | +| 3 | Thinking Deeper recurrence | 2603.21676 | 1.4163 | +0.053 | 21% step overhead, 163 fewer steps | +| 4 | Anti-layer removal | 2603.19348 | n/a | n/a | No anti-layers found in 11L stack | +| 5 | Newton-Schulz steps=3 | - | 1.3798 | +0.017 | Worse orthogonalization quality, +9 steps doesn't compensate | +| 6 | Spectral SVD (640d 12L) | Novel | 1.4646 | +0.102 | 46% slower, only 649 steps, SVD can't separate signal from noise | +| 7 | Wider model (576d 11L) | - | 1.4504 | +0.087 | 33% slower, only 711 steps | + +Also confirmed LeakyReLU(0.75)^2 > LeakyReLU(0.5)^2 by 0.006 bpb, matching PR #977. + +### The throughput tax + +This keeps coming up: at ~83ms/step on H100, each millisecond of added overhead costs roughly 0.007 bpb. Techniques 1, 3, 6, and 7 all added overhead. None of them delivered enough quality improvement per step to compensate. The only technique that worked (LeakyReLU slope change) adds zero overhead. + +### Anti-layer results (full table) + +Trained 943 steps, then zeroed each layer's attn_scale + mlp_scale and re-evaluated: + +| Layer | Delta bpb when removed | Role | +|---|---|---| +| 0 | +1.127 | Critical (embedding interaction) | +| 1 | +0.346 | Critical | +| 2 | +0.237 | Important | +| 3 | +0.087 | Moderate | +| 4 | +0.081 | Moderate | +| 5-6 | +0.036 | Low | +| 7-8 | +0.031-0.034 | Low (layer 8 least important) | +| 9-10 | +0.040-0.044 | Low | + +No layer hurts performance when present. The LN Scale (1/sqrt(L+1)) in the stack may prevent anti-layer formation by dampening deep layers. + +--- + +## Other Findings + +**SWA during QAT:** PR #989 found SWA sabotages QAT. I added a one-line fix that stops SWA accumulation when QAT activates. Worth adopting if you use both. + +**Prune-then-quantize:** Implemented but only partially tested. Prunes smallest-magnitude weights before int6 quantization (arXiv:2603.18426). One partial run showed neutral results. The pruned zeros don't compress better than quantized values under lzma, so the artifact size didn't shrink. Needs full evaluation. + +--- + +## Reproducing These Results + +From the repo root: +```bash +# Download data (if not already present) +python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 10 + +# Main run (slope=0.75, includes QAT tensor-scale fix) +RUN_ID=main \ +LEAKY_SLOPE=0.75 \ +torchrun --standalone --nproc_per_node=1 \ +records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/train_gpt.py + +# With prune-then-quantize (5%) +RUN_ID=prune5 \ +LEAKY_SLOPE=0.75 \ +PRUNE_FRACTION=0.05 \ +torchrun --standalone --nproc_per_node=1 \ +records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/train_gpt.py + +# With anti-layer diagnostic (runs layer ablation after training) +RUN_ID=antilayer \ +LEAKY_SLOPE=0.75 \ +ANTILAYER_DIAGNOSTIC=1 \ +torchrun --standalone --nproc_per_node=1 \ +records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/train_gpt.py +``` + +Other toggles: `MUON_VS=1`, `DEEP_DELTA=1`, `RECURRENCE_LAYERS=2 RECURRENCE_STEPS=2`, `SVD_RANK=128 MODEL_DIM=640 NUM_LAYERS=12`. + +--- + +## What I'd try next with more compute + +1. QAT fix at 7000 steps (8xH100) to see if the STE actually helps when the model is better converged +2. Spectral SVD compression at full scale (the concept needs 7000+ steps for weight matrices to develop spectral structure) +3. Prune-then-quantize sweep (5%, 10%, 15%) +4. LeakyReLU slope sweep at 0.80, 0.85, 0.90 + +--- + +## Included Files + +- `train_gpt.py` - modified SOTA #549 script with all techniques as env var toggles +- `submission.json` - metadata +- `train.log` - QAT fix run output (slope=0.75 with tensor-scale QAT, 1xH100, 943 steps) + +## Acknowledgments + +Built on PR #549 (@abaybektursun), PR #315 (@jfprincz, first to flag the QAT torch.compile issue), PR #977 (@michaelwinczuk, LeakyReLU 0.75), PR #989 (@alexanderaperry-arch, SWA-QAT finding). Issue #140 commentary by @notapplica was really helpful for prioritizing what to try. diff --git a/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/requirements.txt b/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/requirements.txt new file mode 100644 index 0000000000..a653ec4978 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/requirements.txt @@ -0,0 +1,5 @@ +torch>=2.6.0 +numpy +sentencepiece +flash-attn +zstandard diff --git a/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/submission.json b/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/submission.json new file mode 100644 index 0000000000..ba0aa4dfa7 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/submission.json @@ -0,0 +1,18 @@ +{ + "author": "wfproc", + "github_id": "wfproc", + "name": "QAT Dead-Code Analysis + Novel Technique Sweep (1xH100)", + "blurb": "Non-record research submission. Found that torch.compile constant-folds Late QAT in #315-derived code (confirmed in SOTA #549). Tested a fix via tensor-scale STE, but working QAT actually made the int6 gap worse. Also swept 7 untried techniques from recent papers (Muon-VS, Deep Delta Learning, Thinking Deeper recurrence, anti-layer removal, NS step reduction, spectral SVD compression, wider models). All negative on 1xH100. Key finding: at 83ms/step, each 1ms of overhead costs about 0.007 bpb.", + "date": "2026-03-28T00:00:00Z", + "track": "non-record-16mb", + "val_loss": 2.3041, + "val_bpb": 1.3646, + "pre_quant_val_loss": null, + "pre_quant_val_bpb": 1.3646, + "step_stop": 943, + "wallclock_seconds": 600, + "bytes_total": 7303141, + "bytes_model_int6_lzma": 7200188, + "bytes_code": 102953, + "gpu": "1xH100-80GB-SXM" +} diff --git a/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/train.log b/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/train.log new file mode 100644 index 0000000000..1653c0a3a0 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/train.log @@ -0,0 +1,16 @@ +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +step:0/20000 val_loss:6.9309 val_bpb:4.1049 train_time:0ms step_avg:0.02ms +swa:start step:250 +late_qat:enabled step:416 scale:0.1499 +step:500/20000 train_loss:2.4300 train_time:318487ms step_avg:636.97ms +step:943/20000 val_loss:2.3041 val_bpb:1.3646 train_time:600509ms step_avg:636.81ms +stopping_early: wallclock_cap train_time:600509ms step:943/20000 +DIAGNOSTIC post_ema val_loss:2.4761 val_bpb:1.4665 eval_time:15592ms +Serialized model: 106158518 bytes +Serialized model int6+lzma: 7200188 bytes +Total submission size int6+lzma: 7303141 bytes +final_int6_roundtrip val_loss:3.9539 val_bpb:2.3417 eval_time:28383ms +final_int6_roundtrip_exact val_loss:3.95388609 val_bpb:2.34171405 +final_int6_sliding_window val_loss:3.9528 val_bpb:2.3411 stride:64 eval_time:565334ms +final_int6_sliding_window_exact val_loss:3.95279298 val_bpb:2.34107285 +final_int8_zlib_roundtrip_exact val_loss:3.95279298 val_bpb:2.34107285 diff --git a/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/train_gpt.py b/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/train_gpt.py new file mode 100644 index 0000000000..0ab96c73c6 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-28_QAT_DeadCode_Analysis_NovelTechniques_1xH100/train_gpt.py @@ -0,0 +1,2144 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + HAS_FA3 = True +except ImportError: + HAS_FA3 = False +is_rocm = hasattr(torch.version, "hip") and torch.version.hip is not None +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_vs = bool(int(os.environ.get("MUON_VS", "0"))) # Variance-Adaptive Muon (arXiv:2601.14603) + deep_delta = bool(int(os.environ.get("DEEP_DELTA", "0"))) # Deep Delta Learning (arXiv:2601.00417) + leaky_slope = float(os.environ.get("LEAKY_SLOPE", "0.5")) # LeakyReLU negative slope (0.5=SOTA, 0.75=PR#977) + # Thinking Deeper: stable depth recurrence (arXiv:2603.21676) + recurrence_layers = int(os.environ.get("RECURRENCE_LAYERS", "0")) + recurrence_steps = int(os.environ.get("RECURRENCE_STEPS", "1")) + # Spectral Weight Compression: train wide, SVD-compress into artifact + svd_rank = int(os.environ.get("SVD_RANK", "0")) # 0 = disabled, >0 = compress each weight to this rank at export + spectral_reg = float(os.environ.get("SPECTRAL_REG", "0.0")) # nuclear norm regularizer weight + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0, variance_scaling: bool = False): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay, variance_scaling=variance_scaling), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + meta = { + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + } + # Muon-VS: variance buffer for sharded path + vs = self.param_groups[0].get("variance_scaling", False) + if vs: + meta['shard_var'] = torch.zeros(shard_B, *tail, device=dev, dtype=torch.float32) + self._bank_meta.append(meta) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + vs = group.get("variance_scaling", False) + + # Muon-VS: update variance surrogate before momentum update + if vs: + if sharded: + var_buf = m.get('shard_var') + else: + state = self.state[p] + if "variance_buffer" not in state: + state["variance_buffer"] = torch.zeros_like(g, dtype=torch.float32) + var_buf = state["variance_buffer"] + # Gamma_t = beta * Gamma_{t-1} + beta*(1-beta)*(M_{t-1} - G_t)^2 + diff = (buf.float() - g.float()) + var_buf.mul_(momentum).add_(diff.square(), alpha=momentum * (1 - momentum)) + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + # Muon-VS: divide by sqrt(variance) before orthogonalization + if vs: + update = update.float() / (var_buf.sqrt() + 1e-8) + update = update.bfloat16() + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +_QAT_SCALE = torch.tensor(0.0) # 0.0 = QAT off, >0.0 = QAT on. Tensor so dynamo doesn't guard on value. + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + # STE int6 fake quantization - controlled by _QAT_SCALE tensor + # Using torch.where on a tensor avoids torch.compile recompilation + if self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + # Blend: when _QAT_SCALE=0, w is unchanged. When _QAT_SCALE=1, full STE. + qat_s = _QAT_SCALE.to(device=w.device, dtype=w.dtype) + w = w + qat_s * (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + +def _set_qat(enabled: bool) -> None: + global _QAT_SCALE + _QAT_SCALE = torch.tensor(1.0 if enabled else 0.0) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if HAS_FA3 and not is_rocm: + y = flash_attn_3_func(q, k, v, causal=True) + else: + # Fallback: FA3 uses (B,T,H,D), SDPA uses (B,H,T,D) + y = F.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), + attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ).transpose(1, 2) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + # No CastedLinear -- weights come from banks + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=Hyperparameters.leaky_slope) + return F.linear(x.square(), down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + deep_delta: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + # Deep Delta Learning: rank-1 erasure gate on residual (arXiv:2601.00417) + self.deep_delta = deep_delta + if deep_delta: + self.dd_beta = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) # per-dim gate + # Identity-biased recurrence gate (arXiv:2603.21676) + self.recurrence_gate = None + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + # Deep Delta Learning: erase stale residual dimensions before adding new info + if self.deep_delta: + x_in = x_in * (1.0 + self.dd_beta.to(dtype=x_in.dtype)[None, None, :]) + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + deep_delta: bool = False, + recurrence_layers: int = 0, + recurrence_steps: int = 1, + ): + super().__init__() + self.recurrence_layers = recurrence_layers + self.recurrence_steps = recurrence_steps + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, + deep_delta=deep_delta, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + # Thinking Deeper: add identity-biased recurrence gates + depth embeddings + if recurrence_layers > 0 and recurrence_steps > 1: + for i in range(max(0, num_layers - recurrence_layers), num_layers): + # Gate: concat(new, old) -> sigmoid -> blend + self.blocks[i].recurrence_gate = nn.Linear(model_dim * 2, model_dim, bias=True) + nn.init.zeros_(self.blocks[i].recurrence_gate.weight) + nn.init.constant_(self.blocks[i].recurrence_gate.bias, -2.0) # 88% retain old + # Per-step depth embeddings (small, learnable) + total_recurrence = recurrence_layers * recurrence_steps + self.depth_embeddings = nn.Parameter(torch.zeros(total_recurrence, model_dim) * 1e-4) + else: + self.depth_embeddings = None + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + # Thinking Deeper: repeat last recurrence_layers with identity-biased gates + if self.recurrence_layers > 0 and self.recurrence_steps > 1 and self.depth_embeddings is not None: + recur_start = n - self.recurrence_layers + depth_idx = 0 + for step in range(1, self.recurrence_steps): + for ri in range(self.recurrence_layers): + bi = recur_start + ri + x_prev = x + de = self.depth_embeddings[depth_idx].to(dtype=x.dtype)[None, None, :] + x_input = x + de + ve = self._get_ve(bi, input_ids, ve_cache) + x_new, _ = self.blocks[bi](x_input, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + gate = self.blocks[bi].recurrence_gate + if gate is not None: + z = torch.sigmoid(gate(torch.cat([x_new, x_prev], dim=-1))) + x = z * x_new + (1 - z) * x_prev + else: + x = x_new + depth_idx += 1 + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + # Thinking Deeper recurrence (same as forward) + if self.recurrence_layers > 0 and self.recurrence_steps > 1 and self.depth_embeddings is not None: + recur_start = n - self.recurrence_layers + depth_idx = 0 + for step in range(1, self.recurrence_steps): + for ri in range(self.recurrence_layers): + bi = recur_start + ri + x_prev = x + de = self.depth_embeddings[depth_idx].to(dtype=x.dtype)[None, None, :] + x_input = x + de + ve = self._get_ve(bi, input_ids, ve_cache) + x_new, _ = self.blocks[bi](x_input, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + gate = self.blocks[bi].recurrence_gate + if gate is not None: + z = torch.sigmoid(gate(torch.cat([x_new, x_prev], dim=-1))) + x = z * x_new + (1 - z) * x_prev + else: + x = x_new + depth_idx += 1 + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = base_model.forward_logits if is_rocm else torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + if is_rocm: + enable_cudnn_sdp(False) + enable_flash_sdp(False) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + else: + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + try: + gpu_info = subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout + except FileNotFoundError: + gpu_info = f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}" + log0(gpu_info, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.qat_enabled: + _set_qat(True) + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + deep_delta=args.deep_delta, + recurrence_layers=args.recurrence_layers, + recurrence_steps=args.recurrence_steps, + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. + torch._dynamo.config.cache_size_limit = 4 # Allow recompilation when QAT activates + compiled_model = base_model if is_rocm else torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + + # Optimizer split: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + variance_scaling=args.muon_vs, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + qat_activated = args.qat_enabled # track whether QAT is active + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not qat_activated: + _set_qat(True) + qat_activated = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0 and not qat_activated: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + # === ANTI-LAYER DIAGNOSTIC (arXiv:2603.19348) - runs before quantization === + if bool(int(os.environ.get("ANTILAYER_DIAGNOSTIC", "0"))) and master_process: + log0("\n=== Anti-Layer Diagnostic ===") + log0("Testing removal of each layer one at a time...") + base_model.eval() + # Create fresh uncompiled copy for ablation - torch.compile traces break when scales change + torch._dynamo.reset() + import copy as copy_module + eager_model = copy_module.deepcopy(base_model) + eager_model._orig_mod = None # strip dynamo wrapper if present + # Remove any dynamo state + for attr in list(vars(eager_model)): + if attr.startswith('_dynamo') or attr.startswith('_torchdynamo'): + delattr(eager_model, attr) + with torch.inference_mode(): + bl_loss, bl_bpb = eval_val( + args, eager_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0(f"Baseline (all layers): val_loss={bl_loss:.4f} val_bpb={bl_bpb:.4f}") + ablation_results = [] + for li in range(args.num_layers): + block = eager_model.blocks[li] + orig_as = block.attn_scale.data.clone() + orig_ms = block.mlp_scale.data.clone() + block.attn_scale.data.zero_() + block.mlp_scale.data.zero_() + with torch.inference_mode(): + ab_loss, ab_bpb = eval_val( + args, eager_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + delta = ab_bpb - bl_bpb + marker = " <<< ANTI-LAYER" if delta < 0 else "" + log0(f" Layer {li:2d}: val_bpb={ab_bpb:.4f} delta={delta:+.4f}{marker}") + ablation_results.append((li, ab_bpb, delta)) + eager_model.blocks[li].attn_scale.data.copy_(orig_as) + eager_model.blocks[li].mlp_scale.data.copy_(orig_ms) + log0("\n--- Sorted by delta (most harmful → most helpful) ---") + for idx, bpb, delta in sorted(ablation_results, key=lambda x: x[2]): + status = "ANTI-LAYER" if delta < 0 else ("critical" if delta > 0.05 else "neutral") + log0(f" Layer {idx:2d}: delta={delta:+.4f} ({status})") + anti = [i for i, _, d in ablation_results if d < 0] + if anti: + log0(f"\nAnti-layers found: {anti} - removing these could IMPROVE val_bpb!") + else: + log0("\nNo anti-layers found. All layers contribute positively.") + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + + # === Spectral Weight Compression: SVD-compress 2D weights to rank-r === + if args.svd_rank > 0: + svd_r = args.svd_rank + log0(f"svd_compress: rank={svd_r}") + total_orig = 0 + total_compressed = 0 + compressed_sd = {} + for name, t in unbanked_sd.items(): + if t.ndim == 2 and min(t.shape) > svd_r: + U, S, Vh = torch.linalg.svd(t.float(), full_matrices=False) + # Keep top-r components + U_r = U[:, :svd_r] # (m, r) + S_r = S[:svd_r] # (r,) + Vh_r = Vh[:svd_r, :] # (r, n) + # Reconstruct: absorb S into U for single matrix multiply at inference + W_approx = (U_r * S_r[None, :]) @ Vh_r + orig_params = t.numel() + new_params = svd_r * (t.shape[0] + t.shape[1]) + total_orig += orig_params + total_compressed += new_params + # Store as reconstructed full matrix (quantization handles the rest) + compressed_sd[name] = W_approx.to(t.dtype) + recon_err = (t.float() - W_approx).norm() / t.float().norm() + if "layer_0" in name or "layer_10" in name: + log0(f" {name}: {t.shape} rank {svd_r}, recon_err={recon_err:.4f}") + else: + compressed_sd[name] = t + log0(f"svd_compress: {total_orig/1e6:.1f}M -> effective {total_compressed/1e6:.1f}M params " + f"({total_orig/total_compressed:.1f}x spectral compression)") + unbanked_sd = compressed_sd + + # === Prune-Then-Quantize (arXiv:2603.18426): prune small weights BEFORE quantizing === + prune_fraction = float(os.environ.get("PRUNE_FRACTION", "0.0")) # 0.0 = disabled, 0.05 = prune 5% + if prune_fraction > 0: + log0(f"prune_then_quantize: pruning {prune_fraction*100:.1f}% of weights before int6") + pruned_sd = {} + total_pruned = 0 + total_weights = 0 + for name, t in unbanked_sd.items(): + if t.ndim == 2: + # Prune smallest-magnitude weights to zero + threshold = torch.quantile(t.float().abs().flatten(), prune_fraction) + mask = t.float().abs() > threshold + pruned_sd[name] = t * mask.to(t.dtype) + total_pruned += (~mask).sum().item() + total_weights += t.numel() + else: + pruned_sd[name] = t + log0(f"prune_then_quantize: pruned {total_pruned}/{total_weights} weights " + f"({total_pruned/max(total_weights,1)*100:.1f}%)") + unbanked_sd = pruned_sd + + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + _set_qat(False) # Disable QAT for eval + torch._dynamo.reset() # Clear compilation cache before compiling eval model + compiled_eval = eval_model if is_rocm else torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main()