From e11ed4bd77b30062e4f86dbf9c56f8317d6fdea1 Mon Sep 17 00:00:00 2001 From: Viraj Deshwal Date: Tue, 31 Mar 2026 15:06:35 -0700 Subject: [PATCH] Record: Unified Attention + FA3 + Legal TTT (val_bpb=1.1412, 3-seed) --- README.md | 1 + .../README.md | 174 ++ .../submission.json | 14 + .../train.log | 1899 ++++++++++++ .../train_gpt.py | 1644 ++++++++++ .../train_seed1337.log | 1899 ++++++++++++ .../train_seed2025.log | 1899 ++++++++++++ .../train_seed42.log | 2760 +++++++++++++++++ 8 files changed, 10290 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/README.md create mode 100644 records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/submission.json create mode 100644 records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train.log create mode 100644 records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_gpt.py create mode 100644 records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed1337.log create mode 100644 records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed2025.log create mode 100644 records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed42.log diff --git a/README.md b/README.md index a447026f5c..39012623ea 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ Happy training! | Run | Score | Author | Summary | Date | Info | |-----|------:|--------|---------|------|------| +| 11L AR Self-Gen GPTQ + XSA | 1.1147 | abaybektursun | On PR #1019: Self-Generated GPTQ Calibration Data + all-layer XSA on the PR #549 stack | 2026-03-25 | [info](records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/README.md) | | LeakyReLU² + Legal Score-First TTT + Parallel Muon | 1.1194 | abaybektursun | On PR #549: LeakyReLU(0.5)^2 + TTT + Parallel Muon on the PR #414 stack | 2026-03-23 | [info](records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/README.md) | | 11L EMA + GPTQ-lite + warmdown3500 | 1.1228 | signalrush | On PR #374: GPTQ-lite clip search + EMA, plus warmdown3500 and QAT@0.15 | 2026-03-22 | [info](records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/README.md) | | 11L Partial RoPE + LN Scale + EMA + XSA4 | 1.1248 | jfprincz | On PR #287: Partial RoPE (16/64) + layerwise LN scale | 2026-03-21 | [info](records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/README.md) | diff --git a/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/README.md b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/README.md new file mode 100644 index 0000000000..7dffd9a040 --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/README.md @@ -0,0 +1,174 @@ +# Unified Attention + FA3 Head-Dim Padding + Legal Score-First TTT + +**val_bpb: 1.1412** (3-seed mean, std 0.0008) | **~15.97 MB** | 8×H100 SXM + +## Results (8×H100 80GB SXM, PyTorch 2.8.0+cu128) + +| Seed | step_avg | steps | Pre-TTT bpb | **Post-TTT bpb** | TTT gain | TTT time | Artifact | +|------|----------|-------|-------------|-----------------|----------|----------|----------| +| 1337 | 49.6ms | 12,088 | 1.1647 | **1.1416** | -0.0231 | 408s | 15,991,687 | +| 42 | 49.6ms | 12,109 | 1.1647 | **1.1416** | -0.0231 | 407s | 15,988,916 | +| 2025 | 49.6ms | 12,103 | 1.1635 | **1.1403** | -0.0232 | 408s | 15,962,515 | +| **Mean** | **49.6ms** | **12,100** | **1.1643** | **1.1412 (std 0.0008)** | **-0.0231** | **~408s** | | + +## Key Innovation: Unified Attention + 
+Unified Attention ([Deshwal, 2026](https://github.com/ReinforceAI/yocto)) replaces the three separate Q, K, V projections in standard self-attention with a single W_unified projection whose output splits into three functional bands after the matmul. In our research, we found that a single matrix carries enough geometric structure for all three roles: amplitude and phase rotation across the output dimensions create functional differentiation naturally, and the three bands form on their own during training.
+
+```python
+# Standard: 3 separate projections for routing
+q = W_q @ x; k = W_k @ x; v = W_v @ x
+
+# Unified Attention: 1 projection, bands form naturally
+unified = W_unified @ x
+seeking, offering, content = unified.split(d//3, dim=-1)
+```
+
+The core insight is that attention is a routing mechanism that decides which tokens talk to each other, while the FFN does the heavy lifting and actually transforms information at each position. Unified attention cuts the routing budget and gives those bytes to the FFN.
+
+This matters most in parameter-constrained settings. In the 16 MB challenge, the bytes we save on attention go straight to the MLP:
+
+| Component | Standard (SOTA) | Unified (Ours) |
+|---|---|---|
+| Attention (compressed) | 5.10 MB | **2.82 MB** |
+| MLP (compressed) | 10.21 MB | **12.70 MB** |
+
+We trade 2.28 MB of routing for 2.49 MB of computation. The MLP gets a bigger budget, and it shows.
+
+## Key Innovation: FA3 Head-Dim Padding
+
+Flash Attention 3 (Hopper) requires head_dim to be a multiple of 8, but our architecture uses head_dim=44 (from d=528, 4 heads). Rather than constraining the architecture, we **zero-pad to 48 dims before FA3 and slice back after**:
+
+```python
+pad_n = (8 - head_dim % 8) % 8  # 4 for head_dim=44
+if pad_n > 0:
+    q = F.pad(q, (0, pad_n))  # [B,T,H,44] → [B,T,H,48]
+    k = F.pad(k, (0, pad_n))
+    v = F.pad(v, (0, pad_n))
+out = flash_attn_func(q, k, v, causal=True)
+y = out[..., :head_dim]  # [B,T,H,48] → [B,T,H,44]
+```
+
+This is **mathematically lossless**: the padded zeros contribute nothing to dot products or weighted sums. The 9% compute overhead of running 48 dims instead of 44 is dwarfed by FA3's 1.57× speedup over FA2/SDPA, giving a net **51ms/step** (vs 67ms SDPA, 65ms FA2). That unlocks **11,714 training steps** in 10 minutes, roughly 27% more than FA2.
+
+## Legal TTT Protocol
+
+Backward-looking, score-first TTT following the PR #461 / PR #549 framework:
+
+1. Val tokens split into 1,893 non-overlapping 32K-token chunks
+2. **For each chunk**:
+   - **SCORE**: sliding-window eval under `torch.no_grad()`, with no weight mutation
+   - **TRAIN**: SGD(lr=0.002, momentum=0.9) on the already-scored chunk. 3 epochs, all blocks unfrozen, cosine LR decay, grad clip 1.0
+3. 
Last chunk scored but never trained on + +### TTT Hyperparameters + +| Parameter | Value | +|-----------|-------| +| Chunk size | 32,768 tokens | +| Optimizer | SGD + momentum(0.9) | +| Learning rate | 0.002 (cosine decay across chunks) | +| Epochs per chunk | 3 | +| Frozen blocks | None (all blocks adapt) | +| Gradient clip | 1.0 | +| Eval stride | 64 | + +### Timing Budget + +| Phase | Time | +|-------|------| +| Training | 600s (10 min) | +| Quantization + roundtrip validation | ~70s | +| Legal TTT (score-first + adaptation) | ~490s | +| **Total eval** | **~560s (< 10 min)** | + +## Training Architecture + +| Component | Setting | +|-----------|---------| +| **Attention** | **Unified Attention (single W_unified, 67% fewer attn params)** | +| Layers | 11 unique (K=11, R=1) | +| Dimension | 528 (component_dim=176, head_dim=44) | +| Heads | 4 | +| MLP | 3× (1584) with LeakyReLU(0.5)² | +| SmearGate | Position-mixing gate (zero-init sigmoid) | +| LN Scale | 1/√(layer+1) on norm outputs | +| VE128 | Value embedding on layers 9-10 | +| U-Net skips | Encoder-decoder connections | +| Weight avg | EMA(0.997) + Tight SWA(every 50) | +| Quantization | GPTQ-lite int6 + int8 embeddings + LZMA-6 | +| QAT | STE int6 at fraction ≥ 0.15 | +| Optimizer | Parallel Muon (batched NS5, async reduce-scatter) | +| **Flash Attention** | **FA3 (Hopper) with head_dim zero-padding 44→48** | +| Total params | 23,209,295 | + +### Parallel Muon with Parameter Banking + +4 contiguous 3D `nn.Parameter` banks replace 44 separate weight matrices: +- `unified_bank[K, d, d]`: unified attention projections +- `output_bank[K, d, comp]`: attention output projections +- `fc_bank[K, 3d, d]`: MLP up-projections +- `proj_bank[K, d, 3d]`: MLP down-projections + +Batched Newton-Schulz orthogonalization via 3D tensor operations. DDP replaced with async reduce-scatter → local NS5 → async all-gather. 
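+
+A minimal single-process sketch of this flow, with a toy `ns5` standing in for the script's `zeropower_via_newtonschulz5` (square slices only, no bf16 cast) and `chunk`/`cat` standing in for the NCCL collectives:
+
+```python
+import torch
+
+def ns5(G, steps=5, eps=1e-7):
+    # Quintic Newton-Schulz iteration, batched over the leading bank dim
+    a, b, c = 3.4445, -4.7750, 2.0315
+    X = G / (G.norm(dim=(-2, -1), keepdim=True) + eps)
+    for _ in range(steps):
+        A = X @ X.mT
+        X = a * X + (b * A + c * (A @ A)) @ X
+    return X
+
+K, d, ws = 4, 32, 2               # bank of K square matrices, ws ranks
+bank_grad = torch.randn(K, d, d)  # one stacked grad tensor, not K separate ones
+
+# reduce-scatter stand-in: each rank keeps a K/ws slice of the averaged bank
+shards = bank_grad.chunk(ws, dim=0)
+# each rank orthogonalizes only its local slice, as one batched 3D op
+local_updates = [ns5(s) for s in shards]
+# all-gather stand-in: every rank reassembles the full update bank
+update_bank = torch.cat(local_updates, dim=0)
+assert update_bank.shape == (K, d, d)
+```
+
+Banking is what makes this pay off: each shard goes through Newton-Schulz as a single 3D matmul chain instead of K separate 2D calls, and the collectives move one contiguous buffer per bank.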
+ +## Ablation + +Incremental contribution of each technique: + +| Change | BPB | Delta | +|--------|-----|-------| +| Baseline (9L, 512d, relu²) | 1.2244 | | +| + 10L, 3× MLP, seq 2048 | ~1.17 | ~-0.05 | +| + SmearGate | ~1.167 | -0.003 | +| + LN Scale | ~1.165 | -0.002 | +| + VE128 (layers 8-9) | ~1.161 | -0.004 | +| + Unified Attention (replaces Q/K/V) | ~1.161 | ±0.000 (same BPB, 67% fewer attn params) | +| + K=11 (depth > width) | ~1.155 | -0.006 | +| + FA3 with head-dim padding (51ms/step) | 1.1649 | -0.010 (more steps) | +| + Legal Score-First TTT | **1.1414** | -0.024 | + +## Negative Results + +| Technique | Impact | Root Cause | +|-----------|--------|------------| +| XSA (last 4 layers) | +0.0015 worse | Content band coupled to seeking/offering in shared projection | +| BigramHash at input | +0.009 worse | Single matrix can't route bigram across 3 functional bands | +| EMA 0.999 | +0.007 worse | Over-smooths weight distribution | +| Soft-sigmoid QAT | Training stall steps 2000-6000 | Ramping alpha creates unstable gradients; simple STE works | +| K=11 d=552 (mixed int5/int6) | Over budget (16.9 MB) | Unified attention weights have higher entropy than Q/K/V | + +## Requirements + +Flash Attention 3 (Hopper) is required: + +```bash +pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch280 +pip install sentencepiece zstandard +``` + +## Run Command + +```bash +NUM_UNIQUE_LAYERS=11 MODEL_DIM=528 NUM_HEADS=4 \ +VE_LAYERS=9,10 \ +EMA_DECAY=0.997 QAT_START_FRACTION=0.15 \ +TRAIN_BATCH_TOKENS=524288 \ +SLIDING_WINDOW_EVAL=0 \ +VAL_LOSS_EVERY=3000 TRAIN_LOG_EVERY=500 \ +LEGAL_TTT_EPOCHS=3 \ +TTT_LORA_ATTN=0 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Credits + +- **Unified Attention architecture**: This work (Viraj Deshwal, Reinforce AI) +- **FA3 head-dim padding**: This work +- **LeakyReLU² activation**: [PR #493](https://github.com/openai/parameter-golf/pull/493) by @parinzee +- **SmearGate**: [PR #65](https://github.com/openai/parameter-golf/pull/65) by @aquariouseworkman +- **LN Scale**: [PR #315](https://github.com/openai/parameter-golf/pull/315) by @jfprincz +- **VE128**: [PR #374](https://github.com/openai/parameter-golf/pull/374) by @unnir +- **Parameter Banking + Parallel Muon**: [PR #399](https://github.com/openai/parameter-golf/pull/399) by @abaybektursun +- **Legal TTT recipe**: [PR #461](https://github.com/openai/parameter-golf/pull/461) by @Christopher-Lee-McClendon +- **FA3 prebuilt wheels**: [windreamer/flash-attention3-wheels](https://github.com/windreamer/flash-attention3-wheels) \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/submission.json b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/submission.json new file mode 100644 index 0000000000..f860207be7 --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/submission.json @@ -0,0 +1,14 @@ +{ + "name": "Viraj Deshwal", + "github_id": "VirajDeshwal", + "val_bpb": 1.1412, + "val_bpb_std": 0.0008, + "seeds": [1337, 42, 2025], + "seed_results": [1.1416, 1.1416, 1.1403], + "artifact_bytes": 15981039, + "training_time_seconds": 600, + "eval_time_seconds": 408, + "gpu": "8xH100 80GB SXM", + "pytorch_version": "2.8.0+cu128", + "summary": "Unified Attention (single W_unified projection, 67% fewer attn params) + FA3 
with head-dim zero-padding + Legal Score-First TTT" +} \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train.log b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train.log new file mode 100644 index 0000000000..299ea54a45 --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train.log @@ -0,0 +1,1899 @@ +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import traceback +import uuid +import zlib +import lzma +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_func +except ImportError: + raise ImportError( + "Flash Attention 3 (Hopper) is required. Install with:\n" + " pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch280\n" + "Or see requirements.txt for details." + ) + + +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)-5s | %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("yocto-golf") + +def log_architecture(model, args): + n = sum(p.numel() for p in model.parameters()) + logger.info(f"YOCTO d={args.model_dim} K={args.num_unique_layers} heads={args.num_heads} params={n:,}") + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # ── Yocto architecture ── + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 552)) + num_heads = int(os.environ.get("NUM_HEADS", 4)) + num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 10)) + num_recurrences = int(os.environ.get("NUM_RECURRENCES", 1)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + use_swiglu = bool(int(os.environ.get("USE_SWIGLU", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + seeking_gain_init = float(os.environ.get("SEEKING_GAIN_INIT", 1.5)) + rope_fraction = float(os.environ.get("ROPE_FRACTION", 1.0)) # 1.0 = full RoPE, 0.5 = half partial RoPE + + # ── Optimizer ── + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = 
float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + + # ── LR warmup (actual learning rate ramp, separate from compile warmup) ── + lr_warmup_steps = int(os.environ.get("LR_WARMUP_STEPS", 100)) + + # ── EMA ── + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) # 0 = disabled, 0.997 = SOTA setting + + # ── SWA (Stochastic Weight Averaging) ── + swa_every = int(os.environ.get("SWA_EVERY", 50)) # 0 = disabled, 50 = SOTA setting + swa_threshold = float(os.environ.get("SWA_THRESHOLD", 0.2)) # only SWA when lr_scale < this + + # ── Compression ── + compression = os.environ.get("COMPRESSION", "lzma") # "zlib", "zstd", or "lzma" + + # ── QAT (Quantization-Aware Training) ── + qat_bits = int(os.environ.get("QAT_BITS", 6)) # 0 = disabled, 6 = int6 QAT + qat_start_fraction = float(os.environ.get("QAT_START_FRACTION", 0.15)) # when to start QAT + + # ── Mixed precision quantization ── + int5_layers = os.environ.get("INT5_LAYERS", "") # e.g. "2,3,4,5,6,7,8" + + sliding_window_stride = int(os.environ.get("SLIDING_WINDOW_STRIDE", 64)) + + # ── LN Scale ── + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) # 1/sqrt(layer_idx+1) on norm outputs + + # ── Value Embedding (VE128) ── + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "8,9") # last 2 of 10 layers + + # ── TTT LoRA ── + + # ── Legal Score-First TTT ── + legal_ttt_enabled = bool(int(os.environ.get("LEGAL_TTT_ENABLED", "1"))) + legal_ttt_lr = float(os.environ.get("LEGAL_TTT_LR", 0.002)) + legal_ttt_epochs = int(os.environ.get("LEGAL_TTT_EPOCHS", 3)) + legal_ttt_chunk_tokens = int(os.environ.get("LEGAL_TTT_CHUNK_TOKENS", 32768)) + legal_ttt_freeze_blocks = int(os.environ.get("LEGAL_TTT_FREEZE_BLOCKS", 0)) + legal_ttt_momentum = float(os.environ.get("LEGAL_TTT_MOMENTUM", 0.9)) + legal_ttt_batch_seqs = int(os.environ.get("LEGAL_TTT_BATCH_SEQS", 32)) + legal_ttt_grad_clip = float(os.environ.get("LEGAL_TTT_GRAD_CLIP", 1.0)) + + @property + def num_effective_layers(self) -> int: + return self.num_unique_layers * self.num_recurrences + + def validate(self) -> None: + """Check all divisibility constraints.""" + d = self.model_dim + assert d % 3 == 0, f"model_dim={d} must be divisible by 3 for unified attention split" + comp = d // 3 + assert comp % self.num_heads == 0, ( + f"component_dim={comp} (model_dim/3) must be divisible by num_heads={self.num_heads}" + ) + head_dim = comp // self.num_heads + assert head_dim % 2 == 0, f"head_dim={head_dim} must be even for RoPE" + assert head_dim >= 16, f"head_dim={head_dim} must be >= 16 for useful RoPE (got {head_dim})" + assert self.logit_softcap > 0, f"logit_softcap must be positive" + logger.info(f"Architecture constraints validated: d={d}, 
comp={comp}, heads={self.num_heads}, " + f"head_dim={head_dim}, RoPE_pairs={head_dim//2}") + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, batched NS5, all-gather.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = 
torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split too short for seq_len={seq_len}") + return tokens[: usable + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + 
val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +CONTROL_TENSOR_NAME_PATTERNS = ("attn_scale", "mlp_scale", "resid_mix", "skip_weight", "seeking_gain", "smear", "ve_layer_scales", "ve_shared.scale") +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = CONTROL_TENSOR_NAME_PATTERNS +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_Q = 99.99984 / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +GPTQ_CLIP_PERCENTILES = [0.999, 0.9995, 0.9999, 0.99999, 1.0] + +def quantize_float_tensor_int6(t: Tensor): + return quantize_float_tensor_intN(t, max_val=31) + +def quantize_float_tensor_intN(t: Tensor, max_val: int = 31): + t32 = t.float() + if t32.ndim == 2: + best_q, best_scale, best_err = None, None, float('inf') + + for pct in GPTQ_CLIP_PERCENTILES: + if pct >= 1.0: + clip_abs = t32.abs().amax(dim=1).clamp_min(1e-8) + else: + clip_abs = torch.quantile(t32.abs(), pct, dim=1).clamp_min(1e-8) + scale = (clip_abs / max_val).clamp_min(1e-8).to(torch.float16) + clipped = t32.clamp(-clip_abs[:, None], clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), -max_val, max_val).to(torch.int8) + recon = q.float() * scale.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_scale, best_err = q, scale, err + + return best_q.contiguous(), best_scale.contiguous() + abs_max = t32.abs().max().clamp_min(1e-8).item() + scale = torch.tensor(abs_max / max_val, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -max_val, max_val).to(torch.int8) + return q, scale + +# ── Unbank/rebank for quantization ── + +def _unbank_state_dict(sd, num_layers): + out = {} + for name, tensor in sd.items(): + if name == "unified_bank": + for i in range(num_layers): + w = tensor[i] # [d, d] + d = w.shape[0] + comp = d // 3 + 
out[f"blocks.{i}.attn.W_seeking.weight"] = w[:comp, :] + out[f"blocks.{i}.attn.W_offering.weight"] = w[comp:2*comp, :] + out[f"blocks.{i}.attn.W_content.weight"] = w[2*comp:, :] + elif name == "output_bank": + for i in range(num_layers): + out[f"blocks.{i}.attn.W_output.weight"] = tensor[i] + elif name == "fc_bank": + for i in range(num_layers): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "proj_bank": + for i in range(num_layers): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd, num_layers, template_sd): + out = {} + consumed = set() + + unified_slices = [] + for i in range(num_layers): + sk = f"blocks.{i}.attn.W_seeking.weight" + ok = f"blocks.{i}.attn.W_offering.weight" + ck = f"blocks.{i}.attn.W_content.weight" + unified_slices.append(torch.cat([sd[sk], sd[ok], sd[ck]], dim=0)) + consumed.update([sk, ok, ck]) + out["unified_bank"] = torch.stack(unified_slices).to(dtype=template_sd["unified_bank"].dtype) + + for bank_name, key_template in [ + ("output_bank", "blocks.{i}.attn.W_output.weight"), + ("fc_bank", "blocks.{i}.mlp.fc.weight"), + ("proj_bank", "blocks.{i}.mlp.proj.weight"), + ]: + slices = [] + for i in range(num_layers): + k = key_template.format(i=i) + slices.append(sd[k]) + consumed.add(k) + out[bank_name] = torch.stack(slices).to(dtype=template_sd[bank_name].dtype) + + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +INT8_EMBED_PATTERNS = ("tok_emb.", "ve_shared.embed.") + +def quantize_state_dict_mixed(state_dict, int5_layers=None): + if int5_layers is None: + int5_layers = set() + result = {} + meta = {} + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + if not t.is_floating_point(): + result[name] = t + meta[name] = "passthrough" + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + if any(p in name for p in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + result[name] = t.float().contiguous() + meta[name] = "passthrough_ctrl" + else: + result[name] = t.to(torch.float16).contiguous() + meta[name] = "passthrough" + continue + is_embed = any(p in name for p in INT8_EMBED_PATTERNS) + if is_embed: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + else: + layer_idx = -1 + if "blocks." 
in name: + try: + layer_idx = int(name.split("blocks.")[1].split(".")[0]) + except (ValueError, IndexError): + pass + if layer_idx in int5_layers: + q, s = quantize_float_tensor_intN(t, max_val=15) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5"} + else: + q, s = quantize_float_tensor_int6(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + return result, meta + +def dequantize_state_dict_mixed(result, meta, template_sd=None): + """Dequantize flat-key mixed int6/int8 state dict back to float tensors.""" + out = {} + for name, info in meta.items(): + if info in ("passthrough", "passthrough_ctrl"): + t = result[name] + if template_sd is not None and name in template_sd: + orig_dtype = template_sd[name].dtype + if t.dtype != orig_dtype: + t = t.to(orig_dtype) + out[name] = t + continue + q = result[name + ".q"] + s = result[name + ".scale"] + if s.ndim > 0: + deq = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))) + else: + deq = q.float() * float(s.item()) + target_dtype = torch.bfloat16 + if template_sd is not None and name in template_sd: + target_dtype = template_sd[name].dtype + out[name] = deq.to(target_dtype).contiguous() + return out + + +def load_data_shard(file: Path) -> Tensor: + header = np.fromfile(file, dtype=" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._is_mlp = False # kept for compatibility + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self.training and _qat_active and w.numel() > INT8_KEEP_FLOAT_MAX_NUMEL: + w = _fake_quantize(w, _qat_bits) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + +# ── QAT globals (set during training) ── +_qat_active = False +_qat_bits = 6 + +def _fake_quantize(w: Tensor, bits: int) -> Tensor: + max_val = (1 << (bits - 1)) - 1 # e.g. 
int6: max_val = 31 + with torch.no_grad(): + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + scale = abs_max / max_val + w_q = (w / scale).round().clamp(-max_val, max_val) * scale + return w + (w_q - w).detach() + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): + if self._cos_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, target_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, target_dim, bias=False) if ve_dim != target_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class UnifiedAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, rope_base: float, + seeking_gain_init: float, rope_fraction: float = 1.0): + super().__init__() + assert dim % 3 == 0, f"dim={dim} must be divisible by 3" + self.dim = dim + self.num_heads = num_heads + self.component_dim = dim // 3 + self.head_dim = self.component_dim // num_heads + assert self.component_dim % num_heads == 0 + + self.rope_dim = int(self.head_dim * rope_fraction) + self.rope_dim = max(self.rope_dim - (self.rope_dim % 2), 2) + self.pass_dim = self.head_dim - self.rope_dim + + self.seeking_gain = nn.Parameter( + torch.full((num_heads,), seeking_gain_init, dtype=torch.float32) + ) + self.rotary = Rotary(self.rope_dim, base=rope_base) + + def forward(self, x: Tensor, unified_w: Tensor, output_w: Tensor, unified_delta=None, v_embed=None) -> Tensor: + bsz, seqlen, _ = x.shape + + unified = F.linear(x, unified_w.to(x.dtype)) + if unified_delta is not None: + unified = unified + 
unified_delta
+
+        seeking, offering, content = unified.split(self.component_dim, dim=-1)
+
+        if v_embed is not None:
+            content = content + v_embed
+
+        def to_heads(t):
+            return t.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+
+        seeking = to_heads(seeking)
+        offering = to_heads(offering)
+        content = to_heads(content)
+
+        seeking = F.rms_norm(seeking, (seeking.size(-1),))
+        offering = F.rms_norm(offering, (offering.size(-1),))
+
+        cos, sin = self.rotary(seqlen, x.device, seeking.dtype)
+        if self.pass_dim > 0:
+            s_rope, s_pass = seeking[..., :self.rope_dim], seeking[..., self.rope_dim:]
+            o_rope, o_pass = offering[..., :self.rope_dim], offering[..., self.rope_dim:]
+            s_rope = apply_rotary_emb(s_rope, cos, sin)
+            o_rope = apply_rotary_emb(o_rope, cos, sin)
+            seeking = torch.cat([s_rope, s_pass], dim=-1)
+            offering = torch.cat([o_rope, o_pass], dim=-1)
+        else:
+            seeking = apply_rotary_emb(seeking, cos, sin)
+            offering = apply_rotary_emb(offering, cos, sin)
+
+        seeking = seeking * self.seeking_gain.to(dtype=seeking.dtype)[None, :, None, None]
+
+        # Back to [B, T, H, D] layout for FA3
+        sq = seeking.transpose(1, 2)
+        of = offering.transpose(1, 2)
+        ct = content.transpose(1, 2)
+        dtype = sq.dtype
+        if dtype not in (torch.float16, torch.bfloat16):
+            sq, of, ct = sq.to(torch.bfloat16), of.to(torch.bfloat16), ct.to(torch.bfloat16)
+        hd = sq.size(-1)
+        # FA3 requires head_dim % 8 == 0: zero-pad (e.g. 44 → 48), slice back after
+        pad_n = (8 - hd % 8) % 8
+        if pad_n > 0:
+            sq = F.pad(sq, (0, pad_n))
+            of = F.pad(of, (0, pad_n))
+            ct = F.pad(ct, (0, pad_n))
+        out = _flash_attn_func(sq, of, ct, causal=True)
+        y = out[0] if isinstance(out, tuple) else out
+        if pad_n > 0:
+            y = y[..., :hd]
+        if y.dtype != dtype:
+            y = y.to(dtype)
+
+        # FA3 output is already [B, T, H, D]; flatten heads back to component_dim
+        y = y.contiguous().reshape(bsz, seqlen, self.component_dim)
+        return F.linear(y, output_w.to(x.dtype))
+
+class SquaredReLUMLP(nn.Module):
+    """LeakyReLU(0.5)² MLP — weights passed from banks."""
+    def __init__(self, dim: int, mlp_mult: int):
+        super().__init__()
+
+    def forward(self, x: Tensor, fc_w: Tensor, proj_w: Tensor) -> Tensor:
+        return F.linear(
+            F.leaky_relu(F.linear(x, fc_w.to(x.dtype)), negative_slope=0.5).square(),
+            proj_w.to(x.dtype)
+        )
+
+class Block(nn.Module):
+    """Single transformer block with unified attention + MLP. 
Weights from banks.""" + def __init__(self, dim: int, num_heads: int, mlp_mult: int, rope_base: float, + seeking_gain_init: float, rope_fraction: float = 1.0, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = UnifiedAttention(dim, num_heads, rope_base, seeking_gain_init, rope_fraction) + self.mlp = SquaredReLUMLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, unified_w: Tensor, output_w: Tensor, + fc_w: Tensor, proj_w: Tensor, unified_delta_fn=None, v_embed=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) * self.ln_scale_factor + ud = unified_delta_fn(n) if unified_delta_fn is not None else None + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(n, unified_w, output_w, ud, v_embed=v_embed) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * self.ln_scale_factor, fc_w, proj_w) + return x + +class YoctoGPT(nn.Module): + def __init__(self, vocab_size: int, model_dim: int, num_heads: int, + num_unique_layers: int, num_recurrences: int, mlp_mult: int, + tie_embeddings: bool, tied_embed_init_std: float, + logit_softcap: float, rope_base: float, seeking_gain_init: float, + rope_fraction: float = 1.0, + ln_scale: bool = True, + ve_enabled: bool = True, ve_dim: int = 128, ve_layers: str = "8,9", + int5_layers: str = ""): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.num_unique_layers = num_unique_layers + self.num_recurrences = num_recurrences + self.int5_layer_set = set(int(x) for x in int5_layers.split(",") if x.strip()) + effective = num_unique_layers * num_recurrences + + comp_dim = model_dim // 3 + mlp_dim = mlp_mult * model_dim + + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = None + self.smear = SmearGate(model_dim) + + K = num_unique_layers + self.unified_bank = nn.Parameter(torch.empty(K, model_dim, model_dim)) # W_unified: d→d + self.output_bank = nn.Parameter(torch.empty(K, model_dim, comp_dim)) # W_output: comp→d (F.linear expects [out, in]) + self.fc_bank = nn.Parameter(torch.empty(K, mlp_dim, model_dim)) # MLP fc: d→mlp_dim + self.proj_bank = nn.Parameter(torch.empty(K, model_dim, mlp_dim)) # MLP proj: mlp_dim→d + + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, mlp_mult, rope_base, seeking_gain_init, rope_fraction, + layer_idx=k, ln_scale=ln_scale) + for k in range(num_unique_layers) + ]) + + self.num_encoder_layers = effective // 2 + self.num_decoder_layers = effective - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, comp_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + 
self.ve_layer_scales = nn.ParameterList() + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights(tied_embed_init_std) + + def _init_weights(self, std: float) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=std) + K = self.num_unique_layers + proj_scale = 1.0 / math.sqrt(2 * K * self.num_recurrences) + for i in range(K): + nn.init.orthogonal_(self.unified_bank.data[i], gain=1.0) + nn.init.zeros_(self.output_bank.data[i]) + self.output_bank.data[i].mul_(proj_scale) + nn.init.orthogonal_(self.fc_bank.data[i], gain=1.0) + nn.init.zeros_(self.proj_bank.data[i]) + self.proj_bank.data[i].mul_(proj_scale) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _qat_weight(self, w: Tensor, layer_idx: int = -1) -> Tensor: + if self.training and _qat_active: + bits = 5 if layer_idx in self.int5_layer_set else _qat_bits + return _fake_quantize(w, bits) + return w + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_cache['ve'] * self.ve_layer_scales[ve_idx].to(dtype=ve_cache['ve'].dtype) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + eff_layer_idx = 0 + for _r in range(self.num_recurrences): + for k in range(self.num_unique_layers): + is_encoder = eff_layer_idx < self.num_encoder_layers + + if not is_encoder and skips: + dec_idx = eff_layer_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + + ud_fn = lora.unified_loras[k] if (lora and lora.unified_loras is not None) else None + ve = self._get_ve(k, input_ids, ve_cache) + x = self.blocks[k](x, x0, + self._qat_weight(self.unified_bank[k], k), + self._qat_weight(self.output_bank[k], k), + self._qat_weight(self.fc_bank[k], k), + self._qat_weight(self.proj_bank[k], k), + ud_fn, v_embed=ve) + + if is_encoder: + skips.append(x) + + eff_layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + eff_layer_idx = 0 + for _r in 
range(self.num_recurrences): + for k in range(self.num_unique_layers): + is_encoder = eff_layer_idx < self.num_encoder_layers + + if not is_encoder and skips: + dec_idx = eff_layer_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + + ve = self._get_ve(k, input_ids, ve_cache) + x = self.blocks[k](x, x0, + self.unified_bank[k], self.output_bank[k], + self.fc_bank[k], self.proj_bank[k], + v_embed=ve) + + if is_encoder: + skips.append(x) + + eff_layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + +def eval_val_legal_ttt(args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log0=print): + seq_len = args.train_seq_len + stride = args.sliding_window_stride + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.legal_ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"legal_ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lr={args.legal_ttt_lr} epochs={args.legal_ttt_epochs} " + f"freeze_blocks={args.legal_ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + frozen_block_ids = set(range(min(args.legal_ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"legal_ttt:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.legal_ttt_lr, momentum=args.legal_ttt_momentum) + batch_seqs = args.legal_ttt_batch_seqs + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.legal_ttt_epochs > 0: + base_model.train() + for block in base_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.legal_ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.legal_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.legal_ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = 
loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" legal_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"legal_ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + +def prune_to_fit(result, meta, code_bytes, target_bytes=16_000_000, compress="lzma"): + """Selectively zero ±1 quantized values to fit artifact in budget.""" + buf = io.BytesIO() + torch.save({"w": result, "m": meta}, buf) + raw = buf.getvalue() + if compress == "lzma": + blob = lzma.compress(raw, preset=6) + else: + blob = zlib.compress(raw, level=9) + if len(blob) + code_bytes <= target_bytes: + return result, len(blob) + + candidates = [] + for name, info in meta.items(): + if isinstance(info, dict) and info.get("type") in ("int6", "int5"): + q = result[name + ".q"] + s = result[name + ".scale"] + for row in range(q.shape[0]): + mask = (q[row].abs() == 1) + if mask.any(): + scale_sq = float(s[row].float() ** 2) if s.ndim > 0 else float(s.float() ** 2) + count = int(mask.sum().item()) + candidates.append((scale_sq, name, row, count)) + + candidates.sort(key=lambda x: x[0]) + + batch_size = max(1, len(candidates) // 20) + for i in range(0, len(candidates), batch_size): + batch = candidates[i:i + batch_size] + for _, name, row, _ in batch: + q = result[name + ".q"] + mask = (q[row].abs() == 1) + q[row][mask] = 0 + + buf = io.BytesIO() + torch.save({"w": result, "m": meta}, buf) + raw = buf.getvalue() + if compress == "lzma": + blob = lzma.compress(raw, preset=6) + else: + blob = zlib.compress(raw, level=9) + if len(blob) + code_bytes <= target_bytes: + return result, len(blob) + + return result, len(blob) + + +def main() -> None: + + try: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + args.validate() + + # ── Distributed + CUDA ── + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + logger.info(f"Log file: {logfile}") + + def log0(msg: str, 
console: bool = True) -> None: + if not master_process: + return + if console: + logger.info(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + + # ── Tokenizer + Validation ── + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + + # ── Model ── + base_model = YoctoGPT( + vocab_size=args.vocab_size, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_unique_layers=args.num_unique_layers, + num_recurrences=args.num_recurrences, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + seeking_gain_init=args.seeking_gain_init, + rope_fraction=args.rope_fraction, + ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + int5_layers=args.int5_layers, + ).to(device).bfloat16() + + base_model.unified_bank.data = base_model.unified_bank.data.float() + base_model.output_bank.data = base_model.output_bank.data.float() + base_model.fc_bank.data = base_model.fc_bank.data.float() + base_model.proj_bank.data = base_model.proj_bank.data.float() + + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + + if master_process: + log_architecture(base_model, args) + + try: + _test_mod = torch.compile(lambda q, k, v: _flash_attn_func(q, k, v, causal=True), dynamic=False) + _tq = torch.randn(1, 8, 1, 48, dtype=torch.bfloat16, device=device) + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + _test_mod(_tq, _tq, _tq) + log0("torch.compile + FA3: COMPATIBLE") + compiled_model = torch.compile(base_model, dynamic=False) + model = compiled_model + except Exception as e: + log0(f"torch.compile + FA3: INCOMPATIBLE ({type(e).__name__}), running uncompiled") + model = base_model + + log0("attention_backend:fa3") + + # ── Optimizer: banks → Muon, rest → Adam/AdamW ── + matrix_params = [ + base_model.unified_bank, base_model.output_bank, + base_model.fc_bank, base_model.proj_bank, + ] + + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [p for name, p in block_named_params + if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.ve_shared is not None: + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_param_groups = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.ve_shared is not None: + 
tok_param_groups.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + optimizer_tok = torch.optim.AdamW(tok_param_groups, + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=args.muon_weight_decay) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar] + + replicated_params = [base_model.tok_emb.weight] + scalar_params + if base_model.ve_shared is not None: + replicated_params.append(base_model.ve_shared.embed.weight) + + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers.insert(1, optimizer_head) + replicated_params.append(base_model.lm_head.weight) + if base_model.bigram is not None: + bigram_params = list(base_model.bigram.parameters()) + optimizer_bigram = torch.optim.AdamW([{"params": bigram_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizers.append(optimizer_bigram) + replicated_params.extend(bigram_params) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} effective_depth:{args.num_effective_layers}") + if base_model.int5_layer_set: + log0(f"mixed_precision: int5_layers={sorted(base_model.int5_layer_set)} int6_layers={sorted(set(range(args.num_unique_layers)) - base_model.int5_layer_set)}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + + # ── Data loader + warmup ── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all(): + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + # ── EMA + SWA shadow weights ── + ema_state = None + swa_params = None + swa_count = 0 + if args.ema_decay > 0: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + log0(f"EMA enabled: decay={args.ema_decay}") + if args.swa_every > 0: + swa_params = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + log0(f"SWA enabled: every {args.swa_every} steps when lr_scale < {args.swa_threshold}") + + def update_ema_swa(step, lr_scale): + nonlocal swa_count + with torch.no_grad(): + if ema_state is not None: + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(args.ema_decay).add_(t.detach().float(), alpha=1.0 - args.ema_decay) + + if swa_params is not None and step > 0 and step % args.swa_every == 0: + if lr_scale < args.swa_threshold: + if swa_count == 0: + for name, t in base_model.state_dict().items(): + swa_params[name].copy_(t.detach().cpu()) + swa_count = 1 + log0(f"SWA started at step {step} (lr_scale={lr_scale:.4f})") + else: + for name, t in base_model.state_dict().items(): + swa_params[name] += t.detach().cpu() + swa_count += 1 + + def get_best_weights(): + 
"""Return best averaged weights. EMA preferred (per PR#401).""" + if ema_state is not None: + log0(f"Using EMA weights (decay={args.ema_decay})") + current_state = base_model.state_dict() + return {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + if swa_params is not None and swa_count >= 2: + log0(f"Using SWA weights ({swa_count} checkpoints)") + current_state = base_model.state_dict() + return {name: (t / swa_count).to(dtype=current_state[name].dtype) + for name, t in swa_params.items()} + return None + + def lr_mul(step, elapsed_ms): + if args.lr_warmup_steps > 0 and step < args.lr_warmup_steps: + return (step + 1) / args.lr_warmup_steps + + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad_all() + for ms in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + wl = model(x, y) + (wl * grad_scale).backward() + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if ws + 1 == args.warmup_steps or (ws + 1) % 10 == 0: + log0(f"warmup_step:{ws+1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ── + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # ── QAT activation check ── + global _qat_active, _qat_bits + if args.qat_bits > 0 and not _qat_active: + if 
max_wallclock_ms is not None and max_wallclock_ms > 0: + frac = elapsed_ms / max_wallclock_ms + else: + frac = step / max(args.iterations, 1) + if frac >= args.qat_start_fraction: + _qat_active = True + _qat_bits = args.qat_bits + log0(f"QAT enabled: int{args.qat_bits} at step {step} (fraction={frac:.2f})") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + + optimizer_muon.launch_reduce_scatters() + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + if opt is not optimizer_muon: + opt.step() + optimizer_muon.step() + + update_ema_swa(step, scale) + zero_grad_all() + + step += 1 + approx_time = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_time:.0f}ms step_avg:{approx_time / step:.2f}ms") + + reached_cap = max_wallclock_ms is not None and approx_time >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + rc = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(rc, op=dist.ReduceOp.MAX) + reached_cap = bool(rc.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + # ── Load best averaged weights (EMA > SWA > raw) ── + best_weights = get_best_weights() + if best_weights is not None: + base_model.load_state_dict(best_weights, strict=True) + + # ── Serialization ── + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Raw model: {model_bytes} bytes, code: {code_bytes} bytes") + + # ── Mixed int6/int8 quantization + roundtrip (if QAT was used) ── + if args.qat_bits == 6: + if master_process: + base_model.load_state_dict(torch.load("final_model.pt", map_location="cpu"), strict=True) + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_unique_layers) + int5_set = set(int(x) for x in args.int5_layers.split(",") if x.strip()) + mixed_result, mixed_meta = quantize_state_dict_mixed(unbanked_sd, int5_layers=int5_set) + code_bytes = len(code.encode("utf-8")) + mixed_result, _ = prune_to_fit(mixed_result, mixed_meta, code_bytes, + target_bytes=16_000_000, compress=args.compression) + mixed_buf = io.BytesIO() + torch.save({"w": mixed_result, "m": mixed_meta}, mixed_buf) + mixed_raw = mixed_buf.getvalue() + if args.compression == "lzma": + mixed_blob = 
lzma.compress(mixed_raw, preset=6) + mixed_label = "lzma-6" + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + mixed_blob = zstd_mod.ZstdCompressor(level=22).compress(mixed_raw) + mixed_label = "zstd-22" + except ImportError: + mixed_blob = zlib.compress(mixed_raw, level=9) + mixed_label = "zlib-9" + else: + mixed_blob = zlib.compress(mixed_raw, level=9) + mixed_label = "zlib-9" + if master_process: + with open("final_model.mixed.ptz", "wb") as f: + f.write(mixed_blob) + mixed_bytes = os.path.getsize("final_model.mixed.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"mixed_int6_int8+{mixed_label}: {mixed_bytes} bytes, total: {mixed_bytes + code_bytes} bytes") + if mixed_bytes + code_bytes > 16_000_000: + logger.warning(f"OVER BUDGET: {mixed_bytes + code_bytes} > 16,000,000") + else: + log0(f"FITS: {mixed_bytes + code_bytes} <= 16,000,000") + if distributed: + dist.barrier() + with open("final_model.mixed.ptz", "rb") as f: + mixed_qblob = f.read() + if args.compression == "lzma": + mixed_decompressed = lzma.decompress(mixed_qblob) + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + mixed_decompressed = zstd_mod.ZstdDecompressor().decompress(mixed_qblob) + except ImportError: + mixed_decompressed = zlib.decompress(mixed_qblob) + else: + mixed_decompressed = zlib.decompress(mixed_qblob) + quant_state = torch.load(io.BytesIO(mixed_decompressed), map_location="cpu") + deq_unbanked = dequantize_state_dict_mixed(quant_state["w"], quant_state["m"], unbanked_sd) + deq_sd = _rebank_state_dict(deq_unbanked, args.num_unique_layers, sd_cpu) + base_model.load_state_dict(deq_sd, strict=True) + torch.cuda.synchronize() + qm_val_loss, qm_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_mixed_{mixed_label}_roundtrip val_loss:{qm_val_loss:.4f} val_bpb:{qm_val_bpb:.4f}") + log0(f"final_mixed_{mixed_label}_roundtrip_exact val_loss:{qm_val_loss:.8f} val_bpb:{qm_val_bpb:.8f}") + + # ── Legal Score-First TTT eval ── + if args.legal_ttt_enabled: + best_weights = get_best_weights() + if best_weights is not None: + base_model.load_state_dict(best_weights, strict=True) + torch.cuda.synchronize() + t_legal = time.perf_counter() + legal_loss, legal_bpb = eval_val_legal_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log0=log0) + log0(f"final_legal_ttt val_loss:{legal_loss:.4f} val_bpb:{legal_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_legal):.0f}ms") + + if distributed: + dist.destroy_process_group() + + except Exception: + logger.error(f"FATAL ERROR:\n{traceback.format_exc()}") + raise + +if __name__ == "__main__": + main() +==================================================================================================== +torch.compile + FA3: COMPATIBLE +attention_backend:fa3 +model_params:23209295 effective_depth:11 +world_size:8 grad_accum_steps:1 +EMA enabled: decay=0.997 +SWA enabled: every 50 steps when lr_scale < 0.2 +warmup_step:10/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9321 val_bpb:4.1056 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9330 train_time:89ms step_avg:88.58ms +step:2/20000 train_loss:6.8946 train_time:104ms step_avg:52.18ms +step:3/20000 train_loss:6.7843 train_time:151ms step_avg:50.23ms +step:4/20000 train_loss:6.5824 train_time:197ms step_avg:49.35ms +step:5/20000 train_loss:6.2987 train_time:246ms 
step_avg:49.11ms +step:6/20000 train_loss:6.1347 train_time:292ms step_avg:48.70ms +step:7/20000 train_loss:5.8964 train_time:339ms step_avg:48.49ms +step:8/20000 train_loss:5.8436 train_time:387ms step_avg:48.35ms +step:9/20000 train_loss:5.7850 train_time:434ms step_avg:48.24ms +step:10/20000 train_loss:5.7466 train_time:482ms step_avg:48.16ms +step:500/20000 train_loss:2.4800 train_time:24113ms step_avg:48.23ms +step:1000/20000 train_loss:2.3689 train_time:48367ms step_avg:48.37ms +step:1500/20000 train_loss:2.2344 train_time:72778ms step_avg:48.52ms +QAT enabled: int6 at step 1852 (fraction=0.15) +step:2000/20000 train_loss:2.2053 train_time:105581ms step_avg:52.79ms +step:2500/20000 train_loss:2.2021 train_time:129951ms step_avg:51.98ms +step:3000/20000 train_loss:3.1688 train_time:154334ms step_avg:51.44ms +step:3000/20000 val_loss:2.1694 val_bpb:1.2848 train_time:154369ms step_avg:51.46ms +step:3500/20000 train_loss:2.3468 train_time:178777ms step_avg:51.08ms +step:4000/20000 train_loss:2.2508 train_time:203171ms step_avg:50.79ms +step:4500/20000 train_loss:1.8515 train_time:227602ms step_avg:50.58ms +step:5000/20000 train_loss:2.2155 train_time:252054ms step_avg:50.41ms +step:5500/20000 train_loss:2.2023 train_time:276469ms step_avg:50.27ms +step:6000/20000 train_loss:2.0951 train_time:300917ms step_avg:50.15ms +step:6000/20000 val_loss:2.1248 val_bpb:1.2584 train_time:300953ms step_avg:50.16ms +step:6500/20000 train_loss:2.0885 train_time:325370ms step_avg:50.06ms +step:7000/20000 train_loss:2.0812 train_time:349779ms step_avg:49.97ms +step:7500/20000 train_loss:2.0987 train_time:374236ms step_avg:49.90ms +step:8000/20000 train_loss:2.0449 train_time:398650ms step_avg:49.83ms +step:8500/20000 train_loss:2.2019 train_time:423118ms step_avg:49.78ms +step:9000/20000 train_loss:2.0883 train_time:447575ms step_avg:49.73ms +step:9000/20000 val_loss:2.0962 val_bpb:1.2415 train_time:447610ms step_avg:49.73ms +step:9500/20000 train_loss:2.0107 train_time:471976ms step_avg:49.68ms +step:10000/20000 train_loss:1.9743 train_time:496428ms step_avg:49.64ms +step:10500/20000 train_loss:1.9789 train_time:520872ms step_avg:49.61ms +step:11000/20000 train_loss:2.0102 train_time:545276ms step_avg:49.57ms +SWA started at step 11450 (lr_scale=0.1884) +step:11500/20000 train_loss:1.8758 train_time:569852ms step_avg:49.55ms +step:12000/20000 train_loss:1.9428 train_time:594795ms step_avg:49.57ms +step:12000/20000 val_loss:1.9664 val_bpb:1.1646 train_time:594831ms step_avg:49.57ms +step:12103/20000 val_loss:1.9638 val_bpb:1.1631 train_time:600040ms step_avg:49.58ms +stopping_early: wallclock_cap train_time:600040ms step:12103/20000 +peak memory: 12569 MiB +Using EMA weights (decay=0.997) +Raw model: 91514419 bytes, code: 75347 bytes +mixed_int6_int8+lzma-6: 15887168 bytes, total: 15962515 bytes +FITS: 15962515 <= 16,000,000 +final_mixed_lzma-6_roundtrip val_loss:1.9645 val_bpb:1.1635 +final_mixed_lzma-6_roundtrip_exact val_loss:1.96449328 val_bpb:1.16348357 +Using EMA weights (decay=0.997) +legal_ttt:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 lr=0.002 epochs=3 freeze_blocks=0 +legal_ttt:params unfrozen=23209295 frozen=0 + legal_ttt_chunk [1/1893] bpb=1.171454 time=0.4s + legal_ttt_chunk [11/1893] bpb=1.162058 time=2.5s + legal_ttt_chunk [21/1893] bpb=1.148567 time=4.7s + legal_ttt_chunk [31/1893] bpb=1.147828 time=6.8s + legal_ttt_chunk [41/1893] bpb=1.134481 time=9.0s + legal_ttt_chunk [51/1893] bpb=1.128362 time=11.1s + legal_ttt_chunk [61/1893] bpb=1.135281 time=13.3s + 
legal_ttt_chunk [71/1893] bpb=1.134624 time=15.4s + legal_ttt_chunk [81/1893] bpb=1.134320 time=17.6s + legal_ttt_chunk [91/1893] bpb=1.135335 time=19.7s + legal_ttt_chunk [101/1893] bpb=1.139056 time=21.9s + legal_ttt_chunk [111/1893] bpb=1.141180 time=24.1s + legal_ttt_chunk [121/1893] bpb=1.134581 time=26.2s + legal_ttt_chunk [131/1893] bpb=1.134792 time=28.4s + legal_ttt_chunk [141/1893] bpb=1.140411 time=30.6s + legal_ttt_chunk [151/1893] bpb=1.142447 time=32.8s + legal_ttt_chunk [161/1893] bpb=1.142210 time=34.9s + legal_ttt_chunk [171/1893] bpb=1.146717 time=37.1s + legal_ttt_chunk [181/1893] bpb=1.148995 time=39.2s + legal_ttt_chunk [191/1893] bpb=1.156239 time=41.4s + legal_ttt_chunk [201/1893] bpb=1.155479 time=43.5s + legal_ttt_chunk [211/1893] bpb=1.153247 time=45.7s + legal_ttt_chunk [221/1893] bpb=1.154722 time=47.8s + legal_ttt_chunk [231/1893] bpb=1.153329 time=50.0s + legal_ttt_chunk [241/1893] bpb=1.153828 time=52.1s + legal_ttt_chunk [251/1893] bpb=1.153348 time=54.3s + legal_ttt_chunk [261/1893] bpb=1.150403 time=56.5s + legal_ttt_chunk [271/1893] bpb=1.149198 time=58.6s + legal_ttt_chunk [281/1893] bpb=1.150554 time=60.8s + legal_ttt_chunk [291/1893] bpb=1.152311 time=62.9s + legal_ttt_chunk [301/1893] bpb=1.152973 time=65.1s + legal_ttt_chunk [311/1893] bpb=1.155149 time=67.2s + legal_ttt_chunk [321/1893] bpb=1.157076 time=69.4s + legal_ttt_chunk [331/1893] bpb=1.157081 time=71.5s + legal_ttt_chunk [341/1893] bpb=1.156067 time=73.7s + legal_ttt_chunk [351/1893] bpb=1.158401 time=75.9s + legal_ttt_chunk [361/1893] bpb=1.158667 time=78.0s + legal_ttt_chunk [371/1893] bpb=1.157959 time=80.2s + legal_ttt_chunk [381/1893] bpb=1.158066 time=82.3s + legal_ttt_chunk [391/1893] bpb=1.157912 time=84.5s + legal_ttt_chunk [401/1893] bpb=1.155772 time=86.6s + legal_ttt_chunk [411/1893] bpb=1.154661 time=88.8s + legal_ttt_chunk [421/1893] bpb=1.153668 time=90.9s + legal_ttt_chunk [431/1893] bpb=1.153529 time=93.1s + legal_ttt_chunk [441/1893] bpb=1.153843 time=95.3s + legal_ttt_chunk [451/1893] bpb=1.154150 time=97.4s + legal_ttt_chunk [461/1893] bpb=1.153027 time=99.6s + legal_ttt_chunk [471/1893] bpb=1.153693 time=101.7s + legal_ttt_chunk [481/1893] bpb=1.153365 time=103.9s + legal_ttt_chunk [491/1893] bpb=1.152338 time=106.0s + legal_ttt_chunk [501/1893] bpb=1.151912 time=108.2s + legal_ttt_chunk [511/1893] bpb=1.151251 time=110.3s + legal_ttt_chunk [521/1893] bpb=1.149107 time=112.5s + legal_ttt_chunk [531/1893] bpb=1.150281 time=114.6s + legal_ttt_chunk [541/1893] bpb=1.150604 time=116.8s + legal_ttt_chunk [551/1893] bpb=1.149519 time=119.0s + legal_ttt_chunk [561/1893] bpb=1.149960 time=121.1s + legal_ttt_chunk [571/1893] bpb=1.148886 time=123.3s + legal_ttt_chunk [581/1893] bpb=1.148101 time=125.4s + legal_ttt_chunk [591/1893] bpb=1.147435 time=127.6s + legal_ttt_chunk [601/1893] bpb=1.147941 time=129.7s + legal_ttt_chunk [611/1893] bpb=1.147859 time=131.9s + legal_ttt_chunk [621/1893] bpb=1.147708 time=134.1s + legal_ttt_chunk [631/1893] bpb=1.148390 time=136.2s + legal_ttt_chunk [641/1893] bpb=1.148194 time=138.4s + legal_ttt_chunk [651/1893] bpb=1.148359 time=140.5s + legal_ttt_chunk [661/1893] bpb=1.147884 time=142.7s + legal_ttt_chunk [671/1893] bpb=1.148237 time=144.8s + legal_ttt_chunk [681/1893] bpb=1.148893 time=147.0s + legal_ttt_chunk [691/1893] bpb=1.149910 time=149.1s + legal_ttt_chunk [701/1893] bpb=1.149373 time=151.3s + legal_ttt_chunk [711/1893] bpb=1.149388 time=153.5s + legal_ttt_chunk [721/1893] bpb=1.149044 time=155.6s + legal_ttt_chunk [731/1893] 
bpb=1.149104 time=157.8s + legal_ttt_chunk [741/1893] bpb=1.149175 time=159.9s + legal_ttt_chunk [751/1893] bpb=1.149046 time=162.1s + legal_ttt_chunk [761/1893] bpb=1.148932 time=164.2s + legal_ttt_chunk [771/1893] bpb=1.148640 time=166.4s + legal_ttt_chunk [781/1893] bpb=1.149410 time=168.6s + legal_ttt_chunk [791/1893] bpb=1.149039 time=170.7s + legal_ttt_chunk [801/1893] bpb=1.149322 time=172.9s + legal_ttt_chunk [811/1893] bpb=1.149079 time=175.0s + legal_ttt_chunk [821/1893] bpb=1.148844 time=177.2s + legal_ttt_chunk [831/1893] bpb=1.148671 time=179.3s + legal_ttt_chunk [841/1893] bpb=1.148026 time=181.5s + legal_ttt_chunk [851/1893] bpb=1.147778 time=183.6s + legal_ttt_chunk [861/1893] bpb=1.147547 time=185.8s + legal_ttt_chunk [871/1893] bpb=1.147792 time=187.9s + legal_ttt_chunk [881/1893] bpb=1.147966 time=190.1s + legal_ttt_chunk [891/1893] bpb=1.147577 time=192.3s + legal_ttt_chunk [901/1893] bpb=1.147320 time=194.4s + legal_ttt_chunk [911/1893] bpb=1.147431 time=196.6s + legal_ttt_chunk [921/1893] bpb=1.147910 time=198.7s + legal_ttt_chunk [931/1893] bpb=1.147911 time=200.9s + legal_ttt_chunk [941/1893] bpb=1.147595 time=203.0s + legal_ttt_chunk [951/1893] bpb=1.147992 time=205.2s + legal_ttt_chunk [961/1893] bpb=1.148035 time=207.4s + legal_ttt_chunk [971/1893] bpb=1.148906 time=209.5s + legal_ttt_chunk [981/1893] bpb=1.148964 time=211.7s + legal_ttt_chunk [991/1893] bpb=1.148994 time=213.8s + legal_ttt_chunk [1001/1893] bpb=1.148963 time=216.0s + legal_ttt_chunk [1011/1893] bpb=1.148746 time=218.1s + legal_ttt_chunk [1021/1893] bpb=1.149079 time=220.3s + legal_ttt_chunk [1031/1893] bpb=1.149547 time=222.4s + legal_ttt_chunk [1041/1893] bpb=1.149201 time=224.6s + legal_ttt_chunk [1051/1893] bpb=1.148909 time=226.8s + legal_ttt_chunk [1061/1893] bpb=1.148926 time=228.9s + legal_ttt_chunk [1071/1893] bpb=1.149531 time=231.1s + legal_ttt_chunk [1081/1893] bpb=1.149768 time=233.2s + legal_ttt_chunk [1091/1893] bpb=1.150534 time=235.4s + legal_ttt_chunk [1101/1893] bpb=1.150532 time=237.5s + legal_ttt_chunk [1111/1893] bpb=1.150390 time=239.7s + legal_ttt_chunk [1121/1893] bpb=1.150201 time=241.8s + legal_ttt_chunk [1131/1893] bpb=1.150117 time=244.0s + legal_ttt_chunk [1141/1893] bpb=1.149802 time=246.1s + legal_ttt_chunk [1151/1893] bpb=1.149843 time=248.3s + legal_ttt_chunk [1161/1893] bpb=1.149504 time=250.5s + legal_ttt_chunk [1171/1893] bpb=1.149815 time=252.6s + legal_ttt_chunk [1181/1893] bpb=1.149044 time=254.8s + legal_ttt_chunk [1191/1893] bpb=1.148923 time=256.9s + legal_ttt_chunk [1201/1893] bpb=1.149321 time=259.1s + legal_ttt_chunk [1211/1893] bpb=1.148859 time=261.2s + legal_ttt_chunk [1221/1893] bpb=1.148548 time=263.4s + legal_ttt_chunk [1231/1893] bpb=1.148258 time=265.5s + legal_ttt_chunk [1241/1893] bpb=1.147906 time=267.7s + legal_ttt_chunk [1251/1893] bpb=1.147319 time=269.9s + legal_ttt_chunk [1261/1893] bpb=1.147315 time=272.0s + legal_ttt_chunk [1271/1893] bpb=1.146960 time=274.2s + legal_ttt_chunk [1281/1893] bpb=1.146775 time=276.3s + legal_ttt_chunk [1291/1893] bpb=1.146535 time=278.5s + legal_ttt_chunk [1301/1893] bpb=1.145943 time=280.6s + legal_ttt_chunk [1311/1893] bpb=1.145547 time=282.8s + legal_ttt_chunk [1321/1893] bpb=1.145223 time=284.9s + legal_ttt_chunk [1331/1893] bpb=1.145150 time=287.1s + legal_ttt_chunk [1341/1893] bpb=1.145031 time=289.2s + legal_ttt_chunk [1351/1893] bpb=1.144977 time=291.4s + legal_ttt_chunk [1361/1893] bpb=1.145016 time=293.5s + legal_ttt_chunk [1371/1893] bpb=1.144884 time=295.7s + legal_ttt_chunk [1381/1893] 
bpb=1.144889 time=297.9s + legal_ttt_chunk [1391/1893] bpb=1.144502 time=300.0s + legal_ttt_chunk [1401/1893] bpb=1.144472 time=302.2s + legal_ttt_chunk [1411/1893] bpb=1.144606 time=304.3s + legal_ttt_chunk [1421/1893] bpb=1.144847 time=306.5s + legal_ttt_chunk [1431/1893] bpb=1.144557 time=308.6s + legal_ttt_chunk [1441/1893] bpb=1.145088 time=310.8s + legal_ttt_chunk [1451/1893] bpb=1.145423 time=312.9s + legal_ttt_chunk [1461/1893] bpb=1.144991 time=315.1s + legal_ttt_chunk [1471/1893] bpb=1.146016 time=317.2s + legal_ttt_chunk [1481/1893] bpb=1.145557 time=319.4s + legal_ttt_chunk [1491/1893] bpb=1.145361 time=321.5s + legal_ttt_chunk [1501/1893] bpb=1.145302 time=323.7s + legal_ttt_chunk [1511/1893] bpb=1.145320 time=325.9s + legal_ttt_chunk [1521/1893] bpb=1.145381 time=328.0s + legal_ttt_chunk [1531/1893] bpb=1.144872 time=330.2s + legal_ttt_chunk [1541/1893] bpb=1.144731 time=332.3s + legal_ttt_chunk [1551/1893] bpb=1.145022 time=334.5s + legal_ttt_chunk [1561/1893] bpb=1.145022 time=336.6s + legal_ttt_chunk [1571/1893] bpb=1.144860 time=338.8s + legal_ttt_chunk [1581/1893] bpb=1.145004 time=340.9s + legal_ttt_chunk [1591/1893] bpb=1.144852 time=343.1s + legal_ttt_chunk [1601/1893] bpb=1.145040 time=345.2s + legal_ttt_chunk [1611/1893] bpb=1.144969 time=347.4s + legal_ttt_chunk [1621/1893] bpb=1.144595 time=349.5s + legal_ttt_chunk [1631/1893] bpb=1.144901 time=351.7s + legal_ttt_chunk [1641/1893] bpb=1.144928 time=353.8s + legal_ttt_chunk [1651/1893] bpb=1.144873 time=356.0s + legal_ttt_chunk [1661/1893] bpb=1.144750 time=358.2s + legal_ttt_chunk [1671/1893] bpb=1.145223 time=360.3s + legal_ttt_chunk [1681/1893] bpb=1.145367 time=362.5s + legal_ttt_chunk [1691/1893] bpb=1.145196 time=364.6s + legal_ttt_chunk [1701/1893] bpb=1.145322 time=366.8s + legal_ttt_chunk [1711/1893] bpb=1.145313 time=368.9s + legal_ttt_chunk [1721/1893] bpb=1.145291 time=371.1s + legal_ttt_chunk [1731/1893] bpb=1.145162 time=373.2s + legal_ttt_chunk [1741/1893] bpb=1.144965 time=375.4s + legal_ttt_chunk [1751/1893] bpb=1.144785 time=377.5s + legal_ttt_chunk [1761/1893] bpb=1.144925 time=379.7s + legal_ttt_chunk [1771/1893] bpb=1.144826 time=381.9s + legal_ttt_chunk [1781/1893] bpb=1.144846 time=384.0s + legal_ttt_chunk [1791/1893] bpb=1.144440 time=386.2s + legal_ttt_chunk [1801/1893] bpb=1.144312 time=388.3s + legal_ttt_chunk [1811/1893] bpb=1.144223 time=390.5s + legal_ttt_chunk [1821/1893] bpb=1.144287 time=392.6s + legal_ttt_chunk [1831/1893] bpb=1.143693 time=394.8s + legal_ttt_chunk [1841/1893] bpb=1.143803 time=396.9s + legal_ttt_chunk [1851/1893] bpb=1.143596 time=399.1s + legal_ttt_chunk [1861/1893] bpb=1.143223 time=401.2s + legal_ttt_chunk [1871/1893] bpb=1.143195 time=403.4s + legal_ttt_chunk [1881/1893] bpb=1.142755 time=405.5s + legal_ttt_chunk [1891/1893] bpb=1.142514 time=407.7s + legal_ttt_chunk [1893/1893] bpb=1.142561 time=408.0s +legal_ttt:done val_loss=1.925313 val_bpb=1.140282 elapsed=408.0s +final_legal_ttt val_loss:1.9253 val_bpb:1.1403 eval_time:408444ms diff --git a/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_gpt.py b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_gpt.py new file mode 100644 index 0000000000..8b5c3dbc6c --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_gpt.py @@ -0,0 +1,1644 @@ +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import traceback +import uuid 
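+# compression codecs for the serialized weight artifact (zstandard is optional and imported lazily in main)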
+import zlib +import lzma +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_func +except ImportError: + raise ImportError( + "Flash Attention 3 (Hopper) is required. Install with:\n" + " pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch280\n" + "Or see requirements.txt for details." + ) + + +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)-5s | %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("yocto-golf") + +def log_architecture(model, args): + n = sum(p.numel() for p in model.parameters()) + logger.info(f"YOCTO d={args.model_dim} K={args.num_unique_layers} heads={args.num_heads} params={n:,}") + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # ── Yocto architecture ── + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 552)) + num_heads = int(os.environ.get("NUM_HEADS", 4)) + num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 10)) + num_recurrences = int(os.environ.get("NUM_RECURRENCES", 1)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + use_swiglu = bool(int(os.environ.get("USE_SWIGLU", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + seeking_gain_init = float(os.environ.get("SEEKING_GAIN_INIT", 1.5)) + rope_fraction = float(os.environ.get("ROPE_FRACTION", 1.0)) # 1.0 = full RoPE, 0.5 = half partial RoPE + + # ── Optimizer ── + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = 
float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + + # ── LR warmup (actual learning rate ramp, separate from compile warmup) ── + lr_warmup_steps = int(os.environ.get("LR_WARMUP_STEPS", 100)) + + # ── EMA ── + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) # 0 = disabled, 0.997 = SOTA setting + + # ── SWA (Stochastic Weight Averaging) ── + swa_every = int(os.environ.get("SWA_EVERY", 50)) # 0 = disabled, 50 = SOTA setting + swa_threshold = float(os.environ.get("SWA_THRESHOLD", 0.2)) # only SWA when lr_scale < this + + # ── Compression ── + compression = os.environ.get("COMPRESSION", "lzma") # "zlib", "zstd", or "lzma" + + # ── QAT (Quantization-Aware Training) ── + qat_bits = int(os.environ.get("QAT_BITS", 6)) # 0 = disabled, 6 = int6 QAT + qat_start_fraction = float(os.environ.get("QAT_START_FRACTION", 0.15)) # when to start QAT + + # ── Mixed precision quantization ── + int5_layers = os.environ.get("INT5_LAYERS", "") # e.g. "2,3,4,5,6,7,8" + + sliding_window_stride = int(os.environ.get("SLIDING_WINDOW_STRIDE", 64)) + + # ── LN Scale ── + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) # 1/sqrt(layer_idx+1) on norm outputs + + # ── Value Embedding (VE128) ── + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "8,9") # last 2 of 10 layers + + # ── TTT LoRA ── + + # ── Legal Score-First TTT ── + legal_ttt_enabled = bool(int(os.environ.get("LEGAL_TTT_ENABLED", "1"))) + legal_ttt_lr = float(os.environ.get("LEGAL_TTT_LR", 0.002)) + legal_ttt_epochs = int(os.environ.get("LEGAL_TTT_EPOCHS", 3)) + legal_ttt_chunk_tokens = int(os.environ.get("LEGAL_TTT_CHUNK_TOKENS", 32768)) + legal_ttt_freeze_blocks = int(os.environ.get("LEGAL_TTT_FREEZE_BLOCKS", 0)) + legal_ttt_momentum = float(os.environ.get("LEGAL_TTT_MOMENTUM", 0.9)) + legal_ttt_batch_seqs = int(os.environ.get("LEGAL_TTT_BATCH_SEQS", 32)) + legal_ttt_grad_clip = float(os.environ.get("LEGAL_TTT_GRAD_CLIP", 1.0)) + + @property + def num_effective_layers(self) -> int: + return self.num_unique_layers * self.num_recurrences + + def validate(self) -> None: + """Check all divisibility constraints.""" + d = self.model_dim + assert d % 3 == 0, f"model_dim={d} must be divisible by 3 for unified attention split" + comp = d // 3 + assert comp % self.num_heads == 0, ( + f"component_dim={comp} (model_dim/3) must be divisible by num_heads={self.num_heads}" + ) + head_dim = comp // self.num_heads + assert head_dim % 2 == 0, f"head_dim={head_dim} must be even for RoPE" + assert head_dim >= 16, f"head_dim={head_dim} must be >= 16 for useful RoPE (got {head_dim})" + assert self.logit_softcap > 0, f"logit_softcap must be positive" + logger.info(f"Architecture constraints validated: d={d}, comp={comp}, heads={self.num_heads}, " + f"head_dim={head_dim}, RoPE_pairs={head_dim//2}") + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. 
G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, batched NS5, all-gather.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + 
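# `update` now holds the orthogonalized step for this bank (this rank's shard when distributed): + # the sharded branch below launches an async all-gather and applies the update one bank later, so the + # next bank's Newton-Schulz compute overlaps the communication; the single-process branch applies it immediately. +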
if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split too short for seq_len={seq_len}") + return tokens[: usable + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += 
(has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +CONTROL_TENSOR_NAME_PATTERNS = ("attn_scale", "mlp_scale", "resid_mix", "skip_weight", "seeking_gain", "smear", "ve_layer_scales", "ve_shared.scale") +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = CONTROL_TENSOR_NAME_PATTERNS +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_Q = 99.99984 / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +GPTQ_CLIP_PERCENTILES = [0.999, 0.9995, 0.9999, 0.99999, 1.0] + +def quantize_float_tensor_int6(t: Tensor): + return quantize_float_tensor_intN(t, max_val=31) + +def quantize_float_tensor_intN(t: Tensor, max_val: int = 31): + t32 = t.float() + if t32.ndim == 2: + best_q, best_scale, best_err = None, None, float('inf') + + for pct in GPTQ_CLIP_PERCENTILES: + if pct >= 1.0: + clip_abs = t32.abs().amax(dim=1).clamp_min(1e-8) + else: + clip_abs = torch.quantile(t32.abs(), pct, dim=1).clamp_min(1e-8) + scale = (clip_abs / max_val).clamp_min(1e-8).to(torch.float16) + clipped = t32.clamp(-clip_abs[:, None], clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), -max_val, max_val).to(torch.int8) + recon = q.float() * scale.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_scale, best_err = q, scale, err + + return best_q.contiguous(), best_scale.contiguous() + abs_max = t32.abs().max().clamp_min(1e-8).item() + scale = torch.tensor(abs_max / max_val, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -max_val, max_val).to(torch.int8) + return q, scale + +# ── Unbank/rebank for quantization ── + +def _unbank_state_dict(sd, num_layers): + out = {} + for name, tensor in sd.items(): + if name == "unified_bank": + for i in range(num_layers): + w = tensor[i] # [d, d] + d = w.shape[0] + comp = d // 3 + out[f"blocks.{i}.attn.W_seeking.weight"] = w[:comp, :] + out[f"blocks.{i}.attn.W_offering.weight"] = w[comp:2*comp, :] + out[f"blocks.{i}.attn.W_content.weight"] = w[2*comp:, :] + elif name == "output_bank": + for i in range(num_layers): 
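+ # row i of this stacked [K, d, comp] bank is layer i's W_output weight; unlike unified_bank it needs no split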
+ out[f"blocks.{i}.attn.W_output.weight"] = tensor[i] + elif name == "fc_bank": + for i in range(num_layers): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "proj_bank": + for i in range(num_layers): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd, num_layers, template_sd): + out = {} + consumed = set() + + unified_slices = [] + for i in range(num_layers): + sk = f"blocks.{i}.attn.W_seeking.weight" + ok = f"blocks.{i}.attn.W_offering.weight" + ck = f"blocks.{i}.attn.W_content.weight" + unified_slices.append(torch.cat([sd[sk], sd[ok], sd[ck]], dim=0)) + consumed.update([sk, ok, ck]) + out["unified_bank"] = torch.stack(unified_slices).to(dtype=template_sd["unified_bank"].dtype) + + for bank_name, key_template in [ + ("output_bank", "blocks.{i}.attn.W_output.weight"), + ("fc_bank", "blocks.{i}.mlp.fc.weight"), + ("proj_bank", "blocks.{i}.mlp.proj.weight"), + ]: + slices = [] + for i in range(num_layers): + k = key_template.format(i=i) + slices.append(sd[k]) + consumed.add(k) + out[bank_name] = torch.stack(slices).to(dtype=template_sd[bank_name].dtype) + + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +INT8_EMBED_PATTERNS = ("tok_emb.", "ve_shared.embed.") + +def quantize_state_dict_mixed(state_dict, int5_layers=None): + if int5_layers is None: + int5_layers = set() + result = {} + meta = {} + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + if not t.is_floating_point(): + result[name] = t + meta[name] = "passthrough" + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + if any(p in name for p in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + result[name] = t.float().contiguous() + meta[name] = "passthrough_ctrl" + else: + result[name] = t.to(torch.float16).contiguous() + meta[name] = "passthrough" + continue + is_embed = any(p in name for p in INT8_EMBED_PATTERNS) + if is_embed: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + else: + layer_idx = -1 + if "blocks." 
in name: + try: + layer_idx = int(name.split("blocks.")[1].split(".")[0]) + except (ValueError, IndexError): + pass + if layer_idx in int5_layers: + q, s = quantize_float_tensor_intN(t, max_val=15) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5"} + else: + q, s = quantize_float_tensor_int6(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + return result, meta + +def dequantize_state_dict_mixed(result, meta, template_sd=None): + """Dequantize flat-key mixed int6/int8 state dict back to float tensors.""" + out = {} + for name, info in meta.items(): + if info in ("passthrough", "passthrough_ctrl"): + t = result[name] + if template_sd is not None and name in template_sd: + orig_dtype = template_sd[name].dtype + if t.dtype != orig_dtype: + t = t.to(orig_dtype) + out[name] = t + continue + q = result[name + ".q"] + s = result[name + ".scale"] + if s.ndim > 0: + deq = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))) + else: + deq = q.float() * float(s.item()) + target_dtype = torch.bfloat16 + if template_sd is not None and name in template_sd: + target_dtype = template_sd[name].dtype + out[name] = deq.to(target_dtype).contiguous() + return out + + +def load_data_shard(file: Path) -> Tensor: + # Reconstructed span: the original header-parsing lines of this function and the TokenStream + # constructor were lost to markup stripping in this record. Assumes the standard bin-shard + # layout: a 256-entry int32 header with the token count at header[2], followed by uint16 token ids. + header = np.fromfile(file, dtype=np.int32, count=256) + num_tokens = int(header[2]) + tokens = np.fromfile(file, dtype=np.uint16, offset=256 * 4, count=num_tokens) + return torch.from_numpy(tokens.astype(np.int32)) + +class TokenStream: + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._is_mlp = False # kept for compatibility + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self.training and _qat_active and w.numel() > INT8_KEEP_FLOAT_MAX_NUMEL: + w = _fake_quantize(w, _qat_bits) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + +# ── QAT globals (set during training) ── +_qat_active = False +_qat_bits = 6 + +def _fake_quantize(w: Tensor, bits: int) -> Tensor: + max_val = (1 << (bits - 1)) - 1 # e.g. 
int6: max_val = 31 + with torch.no_grad(): + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + scale = abs_max / max_val + w_q = (w / scale).round().clamp(-max_val, max_val) * scale + return w + (w_q - w).detach() + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): + if self._cos_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, target_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, target_dim, bias=False) if ve_dim != target_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class UnifiedAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, rope_base: float, + seeking_gain_init: float, rope_fraction: float = 1.0): + super().__init__() + assert dim % 3 == 0, f"dim={dim} must be divisible by 3" + self.dim = dim + self.num_heads = num_heads + self.component_dim = dim // 3 + self.head_dim = self.component_dim // num_heads + assert self.component_dim % num_heads == 0 + + self.rope_dim = int(self.head_dim * rope_fraction) + self.rope_dim = max(self.rope_dim - (self.rope_dim % 2), 2) + self.pass_dim = self.head_dim - self.rope_dim + + self.seeking_gain = nn.Parameter( + torch.full((num_heads,), seeking_gain_init, dtype=torch.float32) + ) + self.rotary = Rotary(self.rope_dim, base=rope_base) + + def forward(self, x: Tensor, unified_w: Tensor, output_w: Tensor, unified_delta=None, v_embed=None) -> Tensor: + bsz, seqlen, _ = x.shape + + unified = F.linear(x, unified_w.to(x.dtype)) + if unified_delta is not None: + unified = unified + 
unified_delta
+
+        seeking, offering, content = unified.split(self.component_dim, dim=-1)
+
+        if v_embed is not None:
+            content = content + v_embed
+
+        def to_heads(t):
+            return t.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+
+        seeking = to_heads(seeking)
+        offering = to_heads(offering)
+        content = to_heads(content)
+
+        seeking = F.rms_norm(seeking, (seeking.size(-1),))
+        offering = F.rms_norm(offering, (offering.size(-1),))
+
+        cos, sin = self.rotary(seqlen, x.device, seeking.dtype)
+        if self.pass_dim > 0:
+            s_rope, s_pass = seeking[..., :self.rope_dim], seeking[..., self.rope_dim:]
+            o_rope, o_pass = offering[..., :self.rope_dim], offering[..., self.rope_dim:]
+            s_rope = apply_rotary_emb(s_rope, cos, sin)
+            o_rope = apply_rotary_emb(o_rope, cos, sin)
+            seeking = torch.cat([s_rope, s_pass], dim=-1)
+            offering = torch.cat([o_rope, o_pass], dim=-1)
+        else:
+            seeking = apply_rotary_emb(seeking, cos, sin)
+            offering = apply_rotary_emb(offering, cos, sin)
+
+        seeking = seeking * self.seeking_gain.to(dtype=seeking.dtype)[None, :, None, None]
+
+        # FA3 wants [B, T, H, D]: undo the head-first layout from to_heads().
+        sq = seeking.transpose(1, 2)
+        of = offering.transpose(1, 2)
+        ct = content.transpose(1, 2)
+        dtype = sq.dtype
+        if dtype not in (torch.float16, torch.bfloat16):
+            sq, of, ct = sq.to(torch.bfloat16), of.to(torch.bfloat16), ct.to(torch.bfloat16)
+        hd = sq.size(-1)
+        pad_n = (8 - hd % 8) % 8  # e.g. head_dim=44 -> pad by 4 to 48
+        if pad_n > 0:
+            sq = F.pad(sq, (0, pad_n))
+            of = F.pad(of, (0, pad_n))
+            ct = F.pad(ct, (0, pad_n))
+        out = _flash_attn_func(sq, of, ct, causal=True)
+        y = out[0] if isinstance(out, tuple) else out
+        if pad_n > 0:
+            y = y[..., :hd]  # drop the zero padding; lossless for dot products
+        if y.dtype != dtype:
+            y = y.to(dtype)
+
+        # FA3 output is already [B, T, H, D], so heads can be merged directly.
+        y = y.contiguous().reshape(bsz, seqlen, self.component_dim)
+        return F.linear(y, output_w.to(x.dtype))
+
+class SquaredReLUMLP(nn.Module):
+    """LeakyReLU(0.5)² MLP — weights passed from banks."""
+    def __init__(self, dim: int, mlp_mult: int):
+        super().__init__()
+
+    def forward(self, x: Tensor, fc_w: Tensor, proj_w: Tensor) -> Tensor:
+        return F.linear(
+            F.leaky_relu(F.linear(x, fc_w.to(x.dtype)), negative_slope=0.5).square(),
+            proj_w.to(x.dtype)
+        )
+
+class Block(nn.Module):
+    """Single transformer block with unified attention + MLP. 
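resid_mix blends the running stream with the embedding stream x0 per channel;
+    attn_scale and mlp_scale gate the two branch outputs; with ln_scale on,
+    norm outputs are damped by 1/sqrt(layer_idx + 1).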
Weights from banks.""" + def __init__(self, dim: int, num_heads: int, mlp_mult: int, rope_base: float, + seeking_gain_init: float, rope_fraction: float = 1.0, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = UnifiedAttention(dim, num_heads, rope_base, seeking_gain_init, rope_fraction) + self.mlp = SquaredReLUMLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, unified_w: Tensor, output_w: Tensor, + fc_w: Tensor, proj_w: Tensor, unified_delta_fn=None, v_embed=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) * self.ln_scale_factor + ud = unified_delta_fn(n) if unified_delta_fn is not None else None + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(n, unified_w, output_w, ud, v_embed=v_embed) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * self.ln_scale_factor, fc_w, proj_w) + return x + +class YoctoGPT(nn.Module): + def __init__(self, vocab_size: int, model_dim: int, num_heads: int, + num_unique_layers: int, num_recurrences: int, mlp_mult: int, + tie_embeddings: bool, tied_embed_init_std: float, + logit_softcap: float, rope_base: float, seeking_gain_init: float, + rope_fraction: float = 1.0, + ln_scale: bool = True, + ve_enabled: bool = True, ve_dim: int = 128, ve_layers: str = "8,9", + int5_layers: str = ""): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.num_unique_layers = num_unique_layers + self.num_recurrences = num_recurrences + self.int5_layer_set = set(int(x) for x in int5_layers.split(",") if x.strip()) + effective = num_unique_layers * num_recurrences + + comp_dim = model_dim // 3 + mlp_dim = mlp_mult * model_dim + + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = None + self.smear = SmearGate(model_dim) + + K = num_unique_layers + self.unified_bank = nn.Parameter(torch.empty(K, model_dim, model_dim)) # W_unified: d→d + self.output_bank = nn.Parameter(torch.empty(K, model_dim, comp_dim)) # W_output: comp→d (F.linear expects [out, in]) + self.fc_bank = nn.Parameter(torch.empty(K, mlp_dim, model_dim)) # MLP fc: d→mlp_dim + self.proj_bank = nn.Parameter(torch.empty(K, model_dim, mlp_dim)) # MLP proj: mlp_dim→d + + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, mlp_mult, rope_base, seeking_gain_init, rope_fraction, + layer_idx=k, ln_scale=ln_scale) + for k in range(num_unique_layers) + ]) + + self.num_encoder_layers = effective // 2 + self.num_decoder_layers = effective - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, comp_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + 
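# VE disabled: keep both attributes defined (None / empty ParameterList) so
+            # downstream references can use plain `is None` checks instead of hasattr.
+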
self.ve_layer_scales = nn.ParameterList() + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights(tied_embed_init_std) + + def _init_weights(self, std: float) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=std) + K = self.num_unique_layers + proj_scale = 1.0 / math.sqrt(2 * K * self.num_recurrences) + for i in range(K): + nn.init.orthogonal_(self.unified_bank.data[i], gain=1.0) + nn.init.zeros_(self.output_bank.data[i]) + self.output_bank.data[i].mul_(proj_scale) + nn.init.orthogonal_(self.fc_bank.data[i], gain=1.0) + nn.init.zeros_(self.proj_bank.data[i]) + self.proj_bank.data[i].mul_(proj_scale) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _qat_weight(self, w: Tensor, layer_idx: int = -1) -> Tensor: + if self.training and _qat_active: + bits = 5 if layer_idx in self.int5_layer_set else _qat_bits + return _fake_quantize(w, bits) + return w + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_cache['ve'] * self.ve_layer_scales[ve_idx].to(dtype=ve_cache['ve'].dtype) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + eff_layer_idx = 0 + for _r in range(self.num_recurrences): + for k in range(self.num_unique_layers): + is_encoder = eff_layer_idx < self.num_encoder_layers + + if not is_encoder and skips: + dec_idx = eff_layer_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + + ud_fn = lora.unified_loras[k] if (lora and lora.unified_loras is not None) else None + ve = self._get_ve(k, input_ids, ve_cache) + x = self.blocks[k](x, x0, + self._qat_weight(self.unified_bank[k], k), + self._qat_weight(self.output_bank[k], k), + self._qat_weight(self.fc_bank[k], k), + self._qat_weight(self.proj_bank[k], k), + ud_fn, v_embed=ve) + + if is_encoder: + skips.append(x) + + eff_layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + eff_layer_idx = 0 + for _r in 
range(self.num_recurrences): + for k in range(self.num_unique_layers): + is_encoder = eff_layer_idx < self.num_encoder_layers + + if not is_encoder and skips: + dec_idx = eff_layer_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + + ve = self._get_ve(k, input_ids, ve_cache) + x = self.blocks[k](x, x0, + self.unified_bank[k], self.output_bank[k], + self.fc_bank[k], self.proj_bank[k], + v_embed=ve) + + if is_encoder: + skips.append(x) + + eff_layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + +def eval_val_legal_ttt(args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log0=print): + seq_len = args.train_seq_len + stride = args.sliding_window_stride + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.legal_ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"legal_ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lr={args.legal_ttt_lr} epochs={args.legal_ttt_epochs} " + f"freeze_blocks={args.legal_ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + frozen_block_ids = set(range(min(args.legal_ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"legal_ttt:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.legal_ttt_lr, momentum=args.legal_ttt_momentum) + batch_seqs = args.legal_ttt_batch_seqs + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.legal_ttt_epochs > 0: + base_model.train() + for block in base_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.legal_ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.legal_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.legal_ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = 
loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" legal_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"legal_ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + +def prune_to_fit(result, meta, code_bytes, target_bytes=16_000_000, compress="lzma"): + """Selectively zero ±1 quantized values to fit artifact in budget.""" + buf = io.BytesIO() + torch.save({"w": result, "m": meta}, buf) + raw = buf.getvalue() + if compress == "lzma": + blob = lzma.compress(raw, preset=6) + else: + blob = zlib.compress(raw, level=9) + if len(blob) + code_bytes <= target_bytes: + return result, len(blob) + + candidates = [] + for name, info in meta.items(): + if isinstance(info, dict) and info.get("type") in ("int6", "int5"): + q = result[name + ".q"] + s = result[name + ".scale"] + for row in range(q.shape[0]): + mask = (q[row].abs() == 1) + if mask.any(): + scale_sq = float(s[row].float() ** 2) if s.ndim > 0 else float(s.float() ** 2) + count = int(mask.sum().item()) + candidates.append((scale_sq, name, row, count)) + + candidates.sort(key=lambda x: x[0]) + + batch_size = max(1, len(candidates) // 20) + for i in range(0, len(candidates), batch_size): + batch = candidates[i:i + batch_size] + for _, name, row, _ in batch: + q = result[name + ".q"] + mask = (q[row].abs() == 1) + q[row][mask] = 0 + + buf = io.BytesIO() + torch.save({"w": result, "m": meta}, buf) + raw = buf.getvalue() + if compress == "lzma": + blob = lzma.compress(raw, preset=6) + else: + blob = zlib.compress(raw, level=9) + if len(blob) + code_bytes <= target_bytes: + return result, len(blob) + + return result, len(blob) + + +def main() -> None: + + try: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + args.validate() + + # ── Distributed + CUDA ── + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + logger.info(f"Log file: {logfile}") + + def log0(msg: str, 
console: bool = True) -> None: + if not master_process: + return + if console: + logger.info(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + + # ── Tokenizer + Validation ── + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + + # ── Model ── + base_model = YoctoGPT( + vocab_size=args.vocab_size, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_unique_layers=args.num_unique_layers, + num_recurrences=args.num_recurrences, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + seeking_gain_init=args.seeking_gain_init, + rope_fraction=args.rope_fraction, + ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + int5_layers=args.int5_layers, + ).to(device).bfloat16() + + base_model.unified_bank.data = base_model.unified_bank.data.float() + base_model.output_bank.data = base_model.output_bank.data.float() + base_model.fc_bank.data = base_model.fc_bank.data.float() + base_model.proj_bank.data = base_model.proj_bank.data.float() + + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + + if master_process: + log_architecture(base_model, args) + + try: + _test_mod = torch.compile(lambda q, k, v: _flash_attn_func(q, k, v, causal=True), dynamic=False) + _tq = torch.randn(1, 8, 1, 48, dtype=torch.bfloat16, device=device) + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + _test_mod(_tq, _tq, _tq) + log0("torch.compile + FA3: COMPATIBLE") + compiled_model = torch.compile(base_model, dynamic=False) + model = compiled_model + except Exception as e: + log0(f"torch.compile + FA3: INCOMPATIBLE ({type(e).__name__}), running uncompiled") + model = base_model + + log0("attention_backend:fa3") + + # ── Optimizer: banks → Muon, rest → Adam/AdamW ── + matrix_params = [ + base_model.unified_bank, base_model.output_bank, + base_model.fc_bank, base_model.proj_bank, + ] + + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [p for name, p in block_named_params + if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.ve_shared is not None: + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_param_groups = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.ve_shared is not None: + 
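# The shared VE table is embedding-like, so it trains in the same fused
+        # AdamW group and at the same tied-embedding LR as tok_emb.
+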
tok_param_groups.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + optimizer_tok = torch.optim.AdamW(tok_param_groups, + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=args.muon_weight_decay) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar] + + replicated_params = [base_model.tok_emb.weight] + scalar_params + if base_model.ve_shared is not None: + replicated_params.append(base_model.ve_shared.embed.weight) + + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers.insert(1, optimizer_head) + replicated_params.append(base_model.lm_head.weight) + if base_model.bigram is not None: + bigram_params = list(base_model.bigram.parameters()) + optimizer_bigram = torch.optim.AdamW([{"params": bigram_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizers.append(optimizer_bigram) + replicated_params.extend(bigram_params) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} effective_depth:{args.num_effective_layers}") + if base_model.int5_layer_set: + log0(f"mixed_precision: int5_layers={sorted(base_model.int5_layer_set)} int6_layers={sorted(set(range(args.num_unique_layers)) - base_model.int5_layer_set)}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + + # ── Data loader + warmup ── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all(): + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + # ── EMA + SWA shadow weights ── + ema_state = None + swa_params = None + swa_count = 0 + if args.ema_decay > 0: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + log0(f"EMA enabled: decay={args.ema_decay}") + if args.swa_every > 0: + swa_params = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + log0(f"SWA enabled: every {args.swa_every} steps when lr_scale < {args.swa_threshold}") + + def update_ema_swa(step, lr_scale): + nonlocal swa_count + with torch.no_grad(): + if ema_state is not None: + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(args.ema_decay).add_(t.detach().float(), alpha=1.0 - args.ema_decay) + + if swa_params is not None and step > 0 and step % args.swa_every == 0: + if lr_scale < args.swa_threshold: + if swa_count == 0: + for name, t in base_model.state_dict().items(): + swa_params[name].copy_(t.detach().cpu()) + swa_count = 1 + log0(f"SWA started at step {step} (lr_scale={lr_scale:.4f})") + else: + for name, t in base_model.state_dict().items(): + swa_params[name] += t.detach().cpu() + swa_count += 1 + + def get_best_weights(): + 
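# Preference order: EMA shadow if enabled, else the SWA running mean once at
+        # least two snapshots exist, else None (caller keeps the raw weights).
+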
"""Return best averaged weights. EMA preferred (per PR#401).""" + if ema_state is not None: + log0(f"Using EMA weights (decay={args.ema_decay})") + current_state = base_model.state_dict() + return {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + if swa_params is not None and swa_count >= 2: + log0(f"Using SWA weights ({swa_count} checkpoints)") + current_state = base_model.state_dict() + return {name: (t / swa_count).to(dtype=current_state[name].dtype) + for name, t in swa_params.items()} + return None + + def lr_mul(step, elapsed_ms): + if args.lr_warmup_steps > 0 and step < args.lr_warmup_steps: + return (step + 1) / args.lr_warmup_steps + + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad_all() + for ms in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + wl = model(x, y) + (wl * grad_scale).backward() + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if ws + 1 == args.warmup_steps or (ws + 1) % 10 == 0: + log0(f"warmup_step:{ws+1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ── + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # ── QAT activation check ── + global _qat_active, _qat_bits + if args.qat_bits > 0 and not _qat_active: + if 
max_wallclock_ms is not None and max_wallclock_ms > 0: + frac = elapsed_ms / max_wallclock_ms + else: + frac = step / max(args.iterations, 1) + if frac >= args.qat_start_fraction: + _qat_active = True + _qat_bits = args.qat_bits + log0(f"QAT enabled: int{args.qat_bits} at step {step} (fraction={frac:.2f})") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + + optimizer_muon.launch_reduce_scatters() + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + if opt is not optimizer_muon: + opt.step() + optimizer_muon.step() + + update_ema_swa(step, scale) + zero_grad_all() + + step += 1 + approx_time = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_time:.0f}ms step_avg:{approx_time / step:.2f}ms") + + reached_cap = max_wallclock_ms is not None and approx_time >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + rc = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(rc, op=dist.ReduceOp.MAX) + reached_cap = bool(rc.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + # ── Load best averaged weights (EMA > SWA > raw) ── + best_weights = get_best_weights() + if best_weights is not None: + base_model.load_state_dict(best_weights, strict=True) + + # ── Serialization ── + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Raw model: {model_bytes} bytes, code: {code_bytes} bytes") + + # ── Mixed int6/int8 quantization + roundtrip (if QAT was used) ── + if args.qat_bits == 6: + if master_process: + base_model.load_state_dict(torch.load("final_model.pt", map_location="cpu"), strict=True) + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_unique_layers) + int5_set = set(int(x) for x in args.int5_layers.split(",") if x.strip()) + mixed_result, mixed_meta = quantize_state_dict_mixed(unbanked_sd, int5_layers=int5_set) + code_bytes = len(code.encode("utf-8")) + mixed_result, _ = prune_to_fit(mixed_result, mixed_meta, code_bytes, + target_bytes=16_000_000, compress=args.compression) + mixed_buf = io.BytesIO() + torch.save({"w": mixed_result, "m": mixed_meta}, mixed_buf) + mixed_raw = mixed_buf.getvalue() + if args.compression == "lzma": + mixed_blob = 
lzma.compress(mixed_raw, preset=6) + mixed_label = "lzma-6" + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + mixed_blob = zstd_mod.ZstdCompressor(level=22).compress(mixed_raw) + mixed_label = "zstd-22" + except ImportError: + mixed_blob = zlib.compress(mixed_raw, level=9) + mixed_label = "zlib-9" + else: + mixed_blob = zlib.compress(mixed_raw, level=9) + mixed_label = "zlib-9" + if master_process: + with open("final_model.mixed.ptz", "wb") as f: + f.write(mixed_blob) + mixed_bytes = os.path.getsize("final_model.mixed.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"mixed_int6_int8+{mixed_label}: {mixed_bytes} bytes, total: {mixed_bytes + code_bytes} bytes") + if mixed_bytes + code_bytes > 16_000_000: + logger.warning(f"OVER BUDGET: {mixed_bytes + code_bytes} > 16,000,000") + else: + log0(f"FITS: {mixed_bytes + code_bytes} <= 16,000,000") + if distributed: + dist.barrier() + with open("final_model.mixed.ptz", "rb") as f: + mixed_qblob = f.read() + if args.compression == "lzma": + mixed_decompressed = lzma.decompress(mixed_qblob) + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + mixed_decompressed = zstd_mod.ZstdDecompressor().decompress(mixed_qblob) + except ImportError: + mixed_decompressed = zlib.decompress(mixed_qblob) + else: + mixed_decompressed = zlib.decompress(mixed_qblob) + quant_state = torch.load(io.BytesIO(mixed_decompressed), map_location="cpu") + deq_unbanked = dequantize_state_dict_mixed(quant_state["w"], quant_state["m"], unbanked_sd) + deq_sd = _rebank_state_dict(deq_unbanked, args.num_unique_layers, sd_cpu) + base_model.load_state_dict(deq_sd, strict=True) + torch.cuda.synchronize() + qm_val_loss, qm_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_mixed_{mixed_label}_roundtrip val_loss:{qm_val_loss:.4f} val_bpb:{qm_val_bpb:.4f}") + log0(f"final_mixed_{mixed_label}_roundtrip_exact val_loss:{qm_val_loss:.8f} val_bpb:{qm_val_bpb:.8f}") + + # ── Legal Score-First TTT eval ── + if args.legal_ttt_enabled: + best_weights = get_best_weights() + if best_weights is not None: + base_model.load_state_dict(best_weights, strict=True) + torch.cuda.synchronize() + t_legal = time.perf_counter() + legal_loss, legal_bpb = eval_val_legal_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log0=log0) + log0(f"final_legal_ttt val_loss:{legal_loss:.4f} val_bpb:{legal_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_legal):.0f}ms") + + if distributed: + dist.destroy_process_group() + + except Exception: + logger.error(f"FATAL ERROR:\n{traceback.format_exc()}") + raise + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed1337.log b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed1337.log new file mode 100644 index 0000000000..858cc92186 --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed1337.log @@ -0,0 +1,1899 @@ +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import traceback +import uuid +import zlib +import lzma +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import 
torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_func +except ImportError: + raise ImportError( + "Flash Attention 3 (Hopper) is required. Install with:\n" + " pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch280\n" + "Or see requirements.txt for details." + ) + + +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)-5s | %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("yocto-golf") + +def log_architecture(model, args): + n = sum(p.numel() for p in model.parameters()) + logger.info(f"YOCTO d={args.model_dim} K={args.num_unique_layers} heads={args.num_heads} params={n:,}") + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # ── Yocto architecture ── + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 552)) + num_heads = int(os.environ.get("NUM_HEADS", 4)) + num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 10)) + num_recurrences = int(os.environ.get("NUM_RECURRENCES", 1)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + use_swiglu = bool(int(os.environ.get("USE_SWIGLU", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + seeking_gain_init = float(os.environ.get("SEEKING_GAIN_INIT", 1.5)) + rope_fraction = float(os.environ.get("ROPE_FRACTION", 1.0)) # 1.0 = full RoPE, 0.5 = half partial RoPE + + # ── Optimizer ── + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_weight_decay = 
float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + + # ── LR warmup (actual learning rate ramp, separate from compile warmup) ── + lr_warmup_steps = int(os.environ.get("LR_WARMUP_STEPS", 100)) + + # ── EMA ── + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) # 0 = disabled, 0.997 = SOTA setting + + # ── SWA (Stochastic Weight Averaging) ── + swa_every = int(os.environ.get("SWA_EVERY", 50)) # 0 = disabled, 50 = SOTA setting + swa_threshold = float(os.environ.get("SWA_THRESHOLD", 0.2)) # only SWA when lr_scale < this + + # ── Compression ── + compression = os.environ.get("COMPRESSION", "lzma") # "zlib", "zstd", or "lzma" + + # ── QAT (Quantization-Aware Training) ── + qat_bits = int(os.environ.get("QAT_BITS", 6)) # 0 = disabled, 6 = int6 QAT + qat_start_fraction = float(os.environ.get("QAT_START_FRACTION", 0.15)) # when to start QAT + + # ── Mixed precision quantization ── + int5_layers = os.environ.get("INT5_LAYERS", "") # e.g. "2,3,4,5,6,7,8" + + sliding_window_stride = int(os.environ.get("SLIDING_WINDOW_STRIDE", 64)) + + # ── LN Scale ── + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) # 1/sqrt(layer_idx+1) on norm outputs + + # ── Value Embedding (VE128) ── + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "8,9") # last 2 of 10 layers + + # ── TTT LoRA ── + + # ── Legal Score-First TTT ── + legal_ttt_enabled = bool(int(os.environ.get("LEGAL_TTT_ENABLED", "1"))) + legal_ttt_lr = float(os.environ.get("LEGAL_TTT_LR", 0.002)) + legal_ttt_epochs = int(os.environ.get("LEGAL_TTT_EPOCHS", 3)) + legal_ttt_chunk_tokens = int(os.environ.get("LEGAL_TTT_CHUNK_TOKENS", 32768)) + legal_ttt_freeze_blocks = int(os.environ.get("LEGAL_TTT_FREEZE_BLOCKS", 0)) + legal_ttt_momentum = float(os.environ.get("LEGAL_TTT_MOMENTUM", 0.9)) + legal_ttt_batch_seqs = int(os.environ.get("LEGAL_TTT_BATCH_SEQS", 32)) + legal_ttt_grad_clip = float(os.environ.get("LEGAL_TTT_GRAD_CLIP", 1.0)) + + @property + def num_effective_layers(self) -> int: + return self.num_unique_layers * self.num_recurrences + + def validate(self) -> None: + """Check all divisibility constraints.""" + d = self.model_dim + assert d % 3 == 0, f"model_dim={d} must be divisible by 3 for unified attention split" + comp = d // 3 + assert comp % self.num_heads == 0, ( + f"component_dim={comp} (model_dim/3) must be divisible by num_heads={self.num_heads}" + ) + head_dim = comp // self.num_heads + assert head_dim % 2 == 0, f"head_dim={head_dim} must be even for RoPE" + assert head_dim >= 16, f"head_dim={head_dim} must be >= 16 for useful RoPE (got {head_dim})" + assert self.logit_softcap > 0, f"logit_softcap must be positive" + logger.info(f"Architecture constraints validated: d={d}, comp={comp}, heads={self.num_heads}, " + f"head_dim={head_dim}, RoPE_pairs={head_dim//2}") + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. 
G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, batched NS5, all-gather.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + 
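# `update` now approximates the nearest orthogonal matrix (matrix sign) of
+                # the momentum-averaged gradient; in the sharded path the async
+                # all-gather below overlaps with the next bank's Newton-Schulz work.
+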
if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split too short for seq_len={seq_len}") + return tokens[: usable + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += 
(has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +CONTROL_TENSOR_NAME_PATTERNS = ("attn_scale", "mlp_scale", "resid_mix", "skip_weight", "seeking_gain", "smear", "ve_layer_scales", "ve_shared.scale") +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = CONTROL_TENSOR_NAME_PATTERNS +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_Q = 99.99984 / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +GPTQ_CLIP_PERCENTILES = [0.999, 0.9995, 0.9999, 0.99999, 1.0] + +def quantize_float_tensor_int6(t: Tensor): + return quantize_float_tensor_intN(t, max_val=31) + +def quantize_float_tensor_intN(t: Tensor, max_val: int = 31): + t32 = t.float() + if t32.ndim == 2: + best_q, best_scale, best_err = None, None, float('inf') + + for pct in GPTQ_CLIP_PERCENTILES: + if pct >= 1.0: + clip_abs = t32.abs().amax(dim=1).clamp_min(1e-8) + else: + clip_abs = torch.quantile(t32.abs(), pct, dim=1).clamp_min(1e-8) + scale = (clip_abs / max_val).clamp_min(1e-8).to(torch.float16) + clipped = t32.clamp(-clip_abs[:, None], clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), -max_val, max_val).to(torch.int8) + recon = q.float() * scale.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_scale, best_err = q, scale, err + + return best_q.contiguous(), best_scale.contiguous() + abs_max = t32.abs().max().clamp_min(1e-8).item() + scale = torch.tensor(abs_max / max_val, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -max_val, max_val).to(torch.int8) + return q, scale + +# ── Unbank/rebank for quantization ── + +def _unbank_state_dict(sd, num_layers): + out = {} + for name, tensor in sd.items(): + if name == "unified_bank": + for i in range(num_layers): + w = tensor[i] # [d, d] + d = w.shape[0] + comp = d // 3 + out[f"blocks.{i}.attn.W_seeking.weight"] = w[:comp, :] + out[f"blocks.{i}.attn.W_offering.weight"] = w[comp:2*comp, :] + out[f"blocks.{i}.attn.W_content.weight"] = w[2*comp:, :] + elif name == "output_bank": + for i in range(num_layers): 
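+                # Split the [K, d, comp] output bank into per-layer [d, comp]
+                # matrices (F.linear layout: [out, in]) so the quantizer can
+                # apply per-row scales to each layer independently.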
+ out[f"blocks.{i}.attn.W_output.weight"] = tensor[i] + elif name == "fc_bank": + for i in range(num_layers): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "proj_bank": + for i in range(num_layers): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd, num_layers, template_sd): + out = {} + consumed = set() + + unified_slices = [] + for i in range(num_layers): + sk = f"blocks.{i}.attn.W_seeking.weight" + ok = f"blocks.{i}.attn.W_offering.weight" + ck = f"blocks.{i}.attn.W_content.weight" + unified_slices.append(torch.cat([sd[sk], sd[ok], sd[ck]], dim=0)) + consumed.update([sk, ok, ck]) + out["unified_bank"] = torch.stack(unified_slices).to(dtype=template_sd["unified_bank"].dtype) + + for bank_name, key_template in [ + ("output_bank", "blocks.{i}.attn.W_output.weight"), + ("fc_bank", "blocks.{i}.mlp.fc.weight"), + ("proj_bank", "blocks.{i}.mlp.proj.weight"), + ]: + slices = [] + for i in range(num_layers): + k = key_template.format(i=i) + slices.append(sd[k]) + consumed.add(k) + out[bank_name] = torch.stack(slices).to(dtype=template_sd[bank_name].dtype) + + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +INT8_EMBED_PATTERNS = ("tok_emb.", "ve_shared.embed.") + +def quantize_state_dict_mixed(state_dict, int5_layers=None): + if int5_layers is None: + int5_layers = set() + result = {} + meta = {} + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + if not t.is_floating_point(): + result[name] = t + meta[name] = "passthrough" + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + if any(p in name for p in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + result[name] = t.float().contiguous() + meta[name] = "passthrough_ctrl" + else: + result[name] = t.to(torch.float16).contiguous() + meta[name] = "passthrough" + continue + is_embed = any(p in name for p in INT8_EMBED_PATTERNS) + if is_embed: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + else: + layer_idx = -1 + if "blocks." 
in name: + try: + layer_idx = int(name.split("blocks.")[1].split(".")[0]) + except (ValueError, IndexError): + pass + if layer_idx in int5_layers: + q, s = quantize_float_tensor_intN(t, max_val=15) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5"} + else: + q, s = quantize_float_tensor_int6(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + return result, meta + +def dequantize_state_dict_mixed(result, meta, template_sd=None): + """Dequantize flat-key mixed int6/int8 state dict back to float tensors.""" + out = {} + for name, info in meta.items(): + if info in ("passthrough", "passthrough_ctrl"): + t = result[name] + if template_sd is not None and name in template_sd: + orig_dtype = template_sd[name].dtype + if t.dtype != orig_dtype: + t = t.to(orig_dtype) + out[name] = t + continue + q = result[name + ".q"] + s = result[name + ".scale"] + if s.ndim > 0: + deq = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))) + else: + deq = q.float() * float(s.item()) + target_dtype = torch.bfloat16 + if template_sd is not None and name in template_sd: + target_dtype = template_sd[name].dtype + out[name] = deq.to(target_dtype).contiguous() + return out + + +def load_data_shard(file: Path) -> Tensor: + header = np.fromfile(file, dtype=" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._is_mlp = False # kept for compatibility + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self.training and _qat_active and w.numel() > INT8_KEEP_FLOAT_MAX_NUMEL: + w = _fake_quantize(w, _qat_bits) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + +# ── QAT globals (set during training) ── +_qat_active = False +_qat_bits = 6 + +def _fake_quantize(w: Tensor, bits: int) -> Tensor: + max_val = (1 << (bits - 1)) - 1 # e.g. 
int6: max_val = 31 + with torch.no_grad(): + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + scale = abs_max / max_val + w_q = (w / scale).round().clamp(-max_val, max_val) * scale + return w + (w_q - w).detach() + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): + if self._cos_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, target_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, target_dim, bias=False) if ve_dim != target_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class UnifiedAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, rope_base: float, + seeking_gain_init: float, rope_fraction: float = 1.0): + super().__init__() + assert dim % 3 == 0, f"dim={dim} must be divisible by 3" + self.dim = dim + self.num_heads = num_heads + self.component_dim = dim // 3 + self.head_dim = self.component_dim // num_heads + assert self.component_dim % num_heads == 0 + + self.rope_dim = int(self.head_dim * rope_fraction) + self.rope_dim = max(self.rope_dim - (self.rope_dim % 2), 2) + self.pass_dim = self.head_dim - self.rope_dim + + self.seeking_gain = nn.Parameter( + torch.full((num_heads,), seeking_gain_init, dtype=torch.float32) + ) + self.rotary = Rotary(self.rope_dim, base=rope_base) + + def forward(self, x: Tensor, unified_w: Tensor, output_w: Tensor, unified_delta=None, v_embed=None) -> Tensor: + bsz, seqlen, _ = x.shape + + unified = F.linear(x, unified_w.to(x.dtype)) + if unified_delta is not None: + unified = unified + 
unified_delta
+
+        seeking, offering, content = unified.split(self.component_dim, dim=-1)
+
+        if v_embed is not None:
+            content = content + v_embed
+
+        def to_heads(t):
+            return t.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+
+        seeking = to_heads(seeking)
+        offering = to_heads(offering)
+        content = to_heads(content)
+
+        seeking = F.rms_norm(seeking, (seeking.size(-1),))
+        offering = F.rms_norm(offering, (offering.size(-1),))
+
+        cos, sin = self.rotary(seqlen, x.device, seeking.dtype)
+        if self.pass_dim > 0:
+            s_rope, s_pass = seeking[..., :self.rope_dim], seeking[..., self.rope_dim:]
+            o_rope, o_pass = offering[..., :self.rope_dim], offering[..., self.rope_dim:]
+            s_rope = apply_rotary_emb(s_rope, cos, sin)
+            o_rope = apply_rotary_emb(o_rope, cos, sin)
+            seeking = torch.cat([s_rope, s_pass], dim=-1)
+            offering = torch.cat([o_rope, o_pass], dim=-1)
+        else:
+            seeking = apply_rotary_emb(seeking, cos, sin)
+            offering = apply_rotary_emb(offering, cos, sin)
+
+        seeking = seeking * self.seeking_gain.to(dtype=seeking.dtype)[None, :, None, None]
+
+        sq = seeking.transpose(1, 2)
+        of = offering.transpose(1, 2)
+        ct = content.transpose(1, 2)
+        dtype = sq.dtype
+        if dtype not in (torch.float16, torch.bfloat16):
+            sq, of, ct = sq.to(torch.bfloat16), of.to(torch.bfloat16), ct.to(torch.bfloat16)
+        hd = sq.size(-1)
+        pad_n = (8 - hd % 8) % 8
+        if pad_n > 0:
+            sq = F.pad(sq, (0, pad_n))
+            of = F.pad(of, (0, pad_n))
+            ct = F.pad(ct, (0, pad_n))
+        out = _flash_attn_func(sq, of, ct, causal=True)
+        y = out[0] if isinstance(out, tuple) else out
+        if pad_n > 0:
+            y = y[..., :hd]
+        if y.dtype != dtype:
+            y = y.to(dtype)
+
+        # FA3 output is already [B, T, H, D]; flatten heads directly
+        y = y.contiguous().reshape(bsz, seqlen, self.component_dim)
+        return F.linear(y, output_w.to(x.dtype))
+
+class SquaredReLUMLP(nn.Module):
+    """LeakyReLU(0.5)² MLP — weights passed from banks."""
+    def __init__(self, dim: int, mlp_mult: int):
+        super().__init__()
+
+    def forward(self, x: Tensor, fc_w: Tensor, proj_w: Tensor) -> Tensor:
+        return F.linear(
+            F.leaky_relu(F.linear(x, fc_w.to(x.dtype)), negative_slope=0.5).square(),
+            proj_w.to(x.dtype)
+        )
+
+class Block(nn.Module):
+    """Single transformer block with unified attention + MLP. 
Weights from banks.""" + def __init__(self, dim: int, num_heads: int, mlp_mult: int, rope_base: float, + seeking_gain_init: float, rope_fraction: float = 1.0, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = UnifiedAttention(dim, num_heads, rope_base, seeking_gain_init, rope_fraction) + self.mlp = SquaredReLUMLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, unified_w: Tensor, output_w: Tensor, + fc_w: Tensor, proj_w: Tensor, unified_delta_fn=None, v_embed=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) * self.ln_scale_factor + ud = unified_delta_fn(n) if unified_delta_fn is not None else None + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(n, unified_w, output_w, ud, v_embed=v_embed) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * self.ln_scale_factor, fc_w, proj_w) + return x + +class YoctoGPT(nn.Module): + def __init__(self, vocab_size: int, model_dim: int, num_heads: int, + num_unique_layers: int, num_recurrences: int, mlp_mult: int, + tie_embeddings: bool, tied_embed_init_std: float, + logit_softcap: float, rope_base: float, seeking_gain_init: float, + rope_fraction: float = 1.0, + ln_scale: bool = True, + ve_enabled: bool = True, ve_dim: int = 128, ve_layers: str = "8,9", + int5_layers: str = ""): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.num_unique_layers = num_unique_layers + self.num_recurrences = num_recurrences + self.int5_layer_set = set(int(x) for x in int5_layers.split(",") if x.strip()) + effective = num_unique_layers * num_recurrences + + comp_dim = model_dim // 3 + mlp_dim = mlp_mult * model_dim + + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = None + self.smear = SmearGate(model_dim) + + K = num_unique_layers + self.unified_bank = nn.Parameter(torch.empty(K, model_dim, model_dim)) # W_unified: d→d + self.output_bank = nn.Parameter(torch.empty(K, model_dim, comp_dim)) # W_output: comp→d (F.linear expects [out, in]) + self.fc_bank = nn.Parameter(torch.empty(K, mlp_dim, model_dim)) # MLP fc: d→mlp_dim + self.proj_bank = nn.Parameter(torch.empty(K, model_dim, mlp_dim)) # MLP proj: mlp_dim→d + + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, mlp_mult, rope_base, seeking_gain_init, rope_fraction, + layer_idx=k, ln_scale=ln_scale) + for k in range(num_unique_layers) + ]) + + self.num_encoder_layers = effective // 2 + self.num_decoder_layers = effective - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, comp_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + 
self.ve_layer_scales = nn.ParameterList() + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights(tied_embed_init_std) + + def _init_weights(self, std: float) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=std) + K = self.num_unique_layers + proj_scale = 1.0 / math.sqrt(2 * K * self.num_recurrences) + for i in range(K): + nn.init.orthogonal_(self.unified_bank.data[i], gain=1.0) + nn.init.zeros_(self.output_bank.data[i]) + self.output_bank.data[i].mul_(proj_scale) + nn.init.orthogonal_(self.fc_bank.data[i], gain=1.0) + nn.init.zeros_(self.proj_bank.data[i]) + self.proj_bank.data[i].mul_(proj_scale) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _qat_weight(self, w: Tensor, layer_idx: int = -1) -> Tensor: + if self.training and _qat_active: + bits = 5 if layer_idx in self.int5_layer_set else _qat_bits + return _fake_quantize(w, bits) + return w + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_cache['ve'] * self.ve_layer_scales[ve_idx].to(dtype=ve_cache['ve'].dtype) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + eff_layer_idx = 0 + for _r in range(self.num_recurrences): + for k in range(self.num_unique_layers): + is_encoder = eff_layer_idx < self.num_encoder_layers + + if not is_encoder and skips: + dec_idx = eff_layer_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + + ud_fn = lora.unified_loras[k] if (lora and lora.unified_loras is not None) else None + ve = self._get_ve(k, input_ids, ve_cache) + x = self.blocks[k](x, x0, + self._qat_weight(self.unified_bank[k], k), + self._qat_weight(self.output_bank[k], k), + self._qat_weight(self.fc_bank[k], k), + self._qat_weight(self.proj_bank[k], k), + ud_fn, v_embed=ve) + + if is_encoder: + skips.append(x) + + eff_layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + eff_layer_idx = 0 + for _r in 
range(self.num_recurrences): + for k in range(self.num_unique_layers): + is_encoder = eff_layer_idx < self.num_encoder_layers + + if not is_encoder and skips: + dec_idx = eff_layer_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + + ve = self._get_ve(k, input_ids, ve_cache) + x = self.blocks[k](x, x0, + self.unified_bank[k], self.output_bank[k], + self.fc_bank[k], self.proj_bank[k], + v_embed=ve) + + if is_encoder: + skips.append(x) + + eff_layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + +def eval_val_legal_ttt(args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log0=print): + seq_len = args.train_seq_len + stride = args.sliding_window_stride + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.legal_ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"legal_ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lr={args.legal_ttt_lr} epochs={args.legal_ttt_epochs} " + f"freeze_blocks={args.legal_ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + frozen_block_ids = set(range(min(args.legal_ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"legal_ttt:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.legal_ttt_lr, momentum=args.legal_ttt_momentum) + batch_seqs = args.legal_ttt_batch_seqs + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.legal_ttt_epochs > 0: + base_model.train() + for block in base_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.legal_ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.legal_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.legal_ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = 
loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" legal_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"legal_ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + +def prune_to_fit(result, meta, code_bytes, target_bytes=16_000_000, compress="lzma"): + """Selectively zero ±1 quantized values to fit artifact in budget.""" + buf = io.BytesIO() + torch.save({"w": result, "m": meta}, buf) + raw = buf.getvalue() + if compress == "lzma": + blob = lzma.compress(raw, preset=6) + else: + blob = zlib.compress(raw, level=9) + if len(blob) + code_bytes <= target_bytes: + return result, len(blob) + + candidates = [] + for name, info in meta.items(): + if isinstance(info, dict) and info.get("type") in ("int6", "int5"): + q = result[name + ".q"] + s = result[name + ".scale"] + for row in range(q.shape[0]): + mask = (q[row].abs() == 1) + if mask.any(): + scale_sq = float(s[row].float() ** 2) if s.ndim > 0 else float(s.float() ** 2) + count = int(mask.sum().item()) + candidates.append((scale_sq, name, row, count)) + + candidates.sort(key=lambda x: x[0]) + + batch_size = max(1, len(candidates) // 20) + for i in range(0, len(candidates), batch_size): + batch = candidates[i:i + batch_size] + for _, name, row, _ in batch: + q = result[name + ".q"] + mask = (q[row].abs() == 1) + q[row][mask] = 0 + + buf = io.BytesIO() + torch.save({"w": result, "m": meta}, buf) + raw = buf.getvalue() + if compress == "lzma": + blob = lzma.compress(raw, preset=6) + else: + blob = zlib.compress(raw, level=9) + if len(blob) + code_bytes <= target_bytes: + return result, len(blob) + + return result, len(blob) + + +def main() -> None: + + try: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + args.validate() + + # ── Distributed + CUDA ── + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + logger.info(f"Log file: {logfile}") + + def log0(msg: str, 
console: bool = True) -> None: + if not master_process: + return + if console: + logger.info(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + + # ── Tokenizer + Validation ── + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + + # ── Model ── + base_model = YoctoGPT( + vocab_size=args.vocab_size, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_unique_layers=args.num_unique_layers, + num_recurrences=args.num_recurrences, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + seeking_gain_init=args.seeking_gain_init, + rope_fraction=args.rope_fraction, + ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + int5_layers=args.int5_layers, + ).to(device).bfloat16() + + base_model.unified_bank.data = base_model.unified_bank.data.float() + base_model.output_bank.data = base_model.output_bank.data.float() + base_model.fc_bank.data = base_model.fc_bank.data.float() + base_model.proj_bank.data = base_model.proj_bank.data.float() + + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + + if master_process: + log_architecture(base_model, args) + + try: + _test_mod = torch.compile(lambda q, k, v: _flash_attn_func(q, k, v, causal=True), dynamic=False) + _tq = torch.randn(1, 8, 1, 48, dtype=torch.bfloat16, device=device) + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + _test_mod(_tq, _tq, _tq) + log0("torch.compile + FA3: COMPATIBLE") + compiled_model = torch.compile(base_model, dynamic=False) + model = compiled_model + except Exception as e: + log0(f"torch.compile + FA3: INCOMPATIBLE ({type(e).__name__}), running uncompiled") + model = base_model + + log0("attention_backend:fa3") + + # ── Optimizer: banks → Muon, rest → Adam/AdamW ── + matrix_params = [ + base_model.unified_bank, base_model.output_bank, + base_model.fc_bank, base_model.proj_bank, + ] + + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [p for name, p in block_named_params + if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.ve_shared is not None: + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_param_groups = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.ve_shared is not None: + 
tok_param_groups.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + optimizer_tok = torch.optim.AdamW(tok_param_groups, + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=args.muon_weight_decay) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar] + + replicated_params = [base_model.tok_emb.weight] + scalar_params + if base_model.ve_shared is not None: + replicated_params.append(base_model.ve_shared.embed.weight) + + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers.insert(1, optimizer_head) + replicated_params.append(base_model.lm_head.weight) + if base_model.bigram is not None: + bigram_params = list(base_model.bigram.parameters()) + optimizer_bigram = torch.optim.AdamW([{"params": bigram_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizers.append(optimizer_bigram) + replicated_params.extend(bigram_params) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} effective_depth:{args.num_effective_layers}") + if base_model.int5_layer_set: + log0(f"mixed_precision: int5_layers={sorted(base_model.int5_layer_set)} int6_layers={sorted(set(range(args.num_unique_layers)) - base_model.int5_layer_set)}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + + # ── Data loader + warmup ── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all(): + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + # ── EMA + SWA shadow weights ── + ema_state = None + swa_params = None + swa_count = 0 + if args.ema_decay > 0: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + log0(f"EMA enabled: decay={args.ema_decay}") + if args.swa_every > 0: + swa_params = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + log0(f"SWA enabled: every {args.swa_every} steps when lr_scale < {args.swa_threshold}") + + def update_ema_swa(step, lr_scale): + nonlocal swa_count + with torch.no_grad(): + if ema_state is not None: + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(args.ema_decay).add_(t.detach().float(), alpha=1.0 - args.ema_decay) + + if swa_params is not None and step > 0 and step % args.swa_every == 0: + if lr_scale < args.swa_threshold: + if swa_count == 0: + for name, t in base_model.state_dict().items(): + swa_params[name].copy_(t.detach().cpu()) + swa_count = 1 + log0(f"SWA started at step {step} (lr_scale={lr_scale:.4f})") + else: + for name, t in base_model.state_dict().items(): + swa_params[name] += t.detach().cpu() + swa_count += 1 + + def get_best_weights(): + 
"""Return best averaged weights. EMA preferred (per PR#401).""" + if ema_state is not None: + log0(f"Using EMA weights (decay={args.ema_decay})") + current_state = base_model.state_dict() + return {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + if swa_params is not None and swa_count >= 2: + log0(f"Using SWA weights ({swa_count} checkpoints)") + current_state = base_model.state_dict() + return {name: (t / swa_count).to(dtype=current_state[name].dtype) + for name, t in swa_params.items()} + return None + + def lr_mul(step, elapsed_ms): + if args.lr_warmup_steps > 0 and step < args.lr_warmup_steps: + return (step + 1) / args.lr_warmup_steps + + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad_all() + for ms in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + wl = model(x, y) + (wl * grad_scale).backward() + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if ws + 1 == args.warmup_steps or (ws + 1) % 10 == 0: + log0(f"warmup_step:{ws+1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ── + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # ── QAT activation check ── + global _qat_active, _qat_bits + if args.qat_bits > 0 and not _qat_active: + if 
max_wallclock_ms is not None and max_wallclock_ms > 0: + frac = elapsed_ms / max_wallclock_ms + else: + frac = step / max(args.iterations, 1) + if frac >= args.qat_start_fraction: + _qat_active = True + _qat_bits = args.qat_bits + log0(f"QAT enabled: int{args.qat_bits} at step {step} (fraction={frac:.2f})") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + + optimizer_muon.launch_reduce_scatters() + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + if opt is not optimizer_muon: + opt.step() + optimizer_muon.step() + + update_ema_swa(step, scale) + zero_grad_all() + + step += 1 + approx_time = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_time:.0f}ms step_avg:{approx_time / step:.2f}ms") + + reached_cap = max_wallclock_ms is not None and approx_time >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + rc = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(rc, op=dist.ReduceOp.MAX) + reached_cap = bool(rc.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + # ── Load best averaged weights (EMA > SWA > raw) ── + best_weights = get_best_weights() + if best_weights is not None: + base_model.load_state_dict(best_weights, strict=True) + + # ── Serialization ── + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Raw model: {model_bytes} bytes, code: {code_bytes} bytes") + + # ── Mixed int6/int8 quantization + roundtrip (if QAT was used) ── + if args.qat_bits == 6: + if master_process: + base_model.load_state_dict(torch.load("final_model.pt", map_location="cpu"), strict=True) + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_unique_layers) + int5_set = set(int(x) for x in args.int5_layers.split(",") if x.strip()) + mixed_result, mixed_meta = quantize_state_dict_mixed(unbanked_sd, int5_layers=int5_set) + code_bytes = len(code.encode("utf-8")) + mixed_result, _ = prune_to_fit(mixed_result, mixed_meta, code_bytes, + target_bytes=16_000_000, compress=args.compression) + mixed_buf = io.BytesIO() + torch.save({"w": mixed_result, "m": mixed_meta}, mixed_buf) + mixed_raw = mixed_buf.getvalue() + if args.compression == "lzma": + mixed_blob = 
lzma.compress(mixed_raw, preset=6) + mixed_label = "lzma-6" + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + mixed_blob = zstd_mod.ZstdCompressor(level=22).compress(mixed_raw) + mixed_label = "zstd-22" + except ImportError: + mixed_blob = zlib.compress(mixed_raw, level=9) + mixed_label = "zlib-9" + else: + mixed_blob = zlib.compress(mixed_raw, level=9) + mixed_label = "zlib-9" + if master_process: + with open("final_model.mixed.ptz", "wb") as f: + f.write(mixed_blob) + mixed_bytes = os.path.getsize("final_model.mixed.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"mixed_int6_int8+{mixed_label}: {mixed_bytes} bytes, total: {mixed_bytes + code_bytes} bytes") + if mixed_bytes + code_bytes > 16_000_000: + logger.warning(f"OVER BUDGET: {mixed_bytes + code_bytes} > 16,000,000") + else: + log0(f"FITS: {mixed_bytes + code_bytes} <= 16,000,000") + if distributed: + dist.barrier() + with open("final_model.mixed.ptz", "rb") as f: + mixed_qblob = f.read() + if args.compression == "lzma": + mixed_decompressed = lzma.decompress(mixed_qblob) + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + mixed_decompressed = zstd_mod.ZstdDecompressor().decompress(mixed_qblob) + except ImportError: + mixed_decompressed = zlib.decompress(mixed_qblob) + else: + mixed_decompressed = zlib.decompress(mixed_qblob) + quant_state = torch.load(io.BytesIO(mixed_decompressed), map_location="cpu") + deq_unbanked = dequantize_state_dict_mixed(quant_state["w"], quant_state["m"], unbanked_sd) + deq_sd = _rebank_state_dict(deq_unbanked, args.num_unique_layers, sd_cpu) + base_model.load_state_dict(deq_sd, strict=True) + torch.cuda.synchronize() + qm_val_loss, qm_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_mixed_{mixed_label}_roundtrip val_loss:{qm_val_loss:.4f} val_bpb:{qm_val_bpb:.4f}") + log0(f"final_mixed_{mixed_label}_roundtrip_exact val_loss:{qm_val_loss:.8f} val_bpb:{qm_val_bpb:.8f}") + + # ── Legal Score-First TTT eval ── + if args.legal_ttt_enabled: + best_weights = get_best_weights() + if best_weights is not None: + base_model.load_state_dict(best_weights, strict=True) + torch.cuda.synchronize() + t_legal = time.perf_counter() + legal_loss, legal_bpb = eval_val_legal_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log0=log0) + log0(f"final_legal_ttt val_loss:{legal_loss:.4f} val_bpb:{legal_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_legal):.0f}ms") + + if distributed: + dist.destroy_process_group() + + except Exception: + logger.error(f"FATAL ERROR:\n{traceback.format_exc()}") + raise + +if __name__ == "__main__": + main() +==================================================================================================== +torch.compile + FA3: COMPATIBLE +attention_backend:fa3 +model_params:23209295 effective_depth:11 +world_size:8 grad_accum_steps:1 +EMA enabled: decay=0.997 +SWA enabled: every 50 steps when lr_scale < 0.2 +warmup_step:10/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9291 val_bpb:4.1038 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9301 train_time:89ms step_avg:89.15ms +step:2/20000 train_loss:6.8909 train_time:104ms step_avg:52.21ms +step:3/20000 train_loss:6.7813 train_time:151ms step_avg:50.17ms +step:4/20000 train_loss:6.5785 train_time:198ms step_avg:49.38ms +step:5/20000 train_loss:6.2913 train_time:245ms 
step_avg:48.93ms +step:6/20000 train_loss:6.1285 train_time:292ms step_avg:48.68ms +step:7/20000 train_loss:5.8954 train_time:339ms step_avg:48.47ms +step:8/20000 train_loss:5.8447 train_time:387ms step_avg:48.39ms +step:9/20000 train_loss:5.7836 train_time:435ms step_avg:48.29ms +step:10/20000 train_loss:5.7433 train_time:482ms step_avg:48.25ms +step:500/20000 train_loss:2.4790 train_time:24112ms step_avg:48.22ms +step:1000/20000 train_loss:2.3614 train_time:48394ms step_avg:48.39ms +step:1500/20000 train_loss:2.2226 train_time:72754ms step_avg:48.50ms +QAT enabled: int6 at step 1851 (fraction=0.15) +step:2000/20000 train_loss:2.2113 train_time:105651ms step_avg:52.83ms +step:2500/20000 train_loss:2.2053 train_time:130068ms step_avg:52.03ms +step:3000/20000 train_loss:3.1783 train_time:154488ms step_avg:51.50ms +step:3000/20000 val_loss:2.1700 val_bpb:1.2852 train_time:154524ms step_avg:51.51ms +step:3500/20000 train_loss:2.3509 train_time:178970ms step_avg:51.13ms +step:4000/20000 train_loss:2.2542 train_time:203410ms step_avg:50.85ms +step:4500/20000 train_loss:1.8542 train_time:227894ms step_avg:50.64ms +step:5000/20000 train_loss:2.2167 train_time:252386ms step_avg:50.48ms +step:5500/20000 train_loss:2.2037 train_time:276826ms step_avg:50.33ms +step:6000/20000 train_loss:2.0930 train_time:301306ms step_avg:50.22ms +step:6000/20000 val_loss:2.1268 val_bpb:1.2596 train_time:301342ms step_avg:50.22ms +step:6500/20000 train_loss:2.0909 train_time:325793ms step_avg:50.12ms +step:7000/20000 train_loss:2.0839 train_time:350236ms step_avg:50.03ms +step:7500/20000 train_loss:2.1014 train_time:374729ms step_avg:49.96ms +step:8000/20000 train_loss:2.0509 train_time:399176ms step_avg:49.90ms +step:8500/20000 train_loss:2.2060 train_time:423676ms step_avg:49.84ms +step:9000/20000 train_loss:2.0885 train_time:448153ms step_avg:49.79ms +step:9000/20000 val_loss:2.0969 val_bpb:1.2419 train_time:448189ms step_avg:49.80ms +step:9500/20000 train_loss:2.0181 train_time:472597ms step_avg:49.75ms +step:10000/20000 train_loss:1.9749 train_time:497087ms step_avg:49.71ms +step:10500/20000 train_loss:1.9809 train_time:521568ms step_avg:49.67ms +step:11000/20000 train_loss:2.0080 train_time:545998ms step_avg:49.64ms +SWA started at step 11400 (lr_scale=0.1980) +step:11500/20000 train_loss:1.8789 train_time:570664ms step_avg:49.62ms +step:12000/20000 train_loss:1.9447 train_time:595601ms step_avg:49.63ms +step:12000/20000 val_loss:1.9673 val_bpb:1.1651 train_time:595637ms step_avg:49.64ms +step:12088/20000 val_loss:1.9654 val_bpb:1.1640 train_time:600056ms step_avg:49.64ms +stopping_early: wallclock_cap train_time:600056ms step:12088/20000 +peak memory: 12569 MiB +Using EMA weights (decay=0.997) +Raw model: 91514419 bytes, code: 75347 bytes +mixed_int6_int8+lzma-6: 15916340 bytes, total: 15991687 bytes +FITS: 15991687 <= 16,000,000 +final_mixed_lzma-6_roundtrip val_loss:1.9666 val_bpb:1.1647 +final_mixed_lzma-6_roundtrip_exact val_loss:1.96656704 val_bpb:1.16471177 +Using EMA weights (decay=0.997) +legal_ttt:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 lr=0.002 epochs=3 freeze_blocks=0 +legal_ttt:params unfrozen=23209295 frozen=0 + legal_ttt_chunk [1/1893] bpb=1.170985 time=0.4s + legal_ttt_chunk [11/1893] bpb=1.162486 time=2.5s + legal_ttt_chunk [21/1893] bpb=1.149097 time=4.7s + legal_ttt_chunk [31/1893] bpb=1.148442 time=6.8s + legal_ttt_chunk [41/1893] bpb=1.135569 time=9.0s + legal_ttt_chunk [51/1893] bpb=1.130076 time=11.2s + legal_ttt_chunk [61/1893] bpb=1.136868 time=13.3s + 
legal_ttt_chunk [71/1893] bpb=1.135863 time=15.5s + legal_ttt_chunk [81/1893] bpb=1.135409 time=17.6s + legal_ttt_chunk [91/1893] bpb=1.136404 time=19.8s + legal_ttt_chunk [101/1893] bpb=1.140255 time=22.0s + legal_ttt_chunk [111/1893] bpb=1.142536 time=24.1s + legal_ttt_chunk [121/1893] bpb=1.136087 time=26.3s + legal_ttt_chunk [131/1893] bpb=1.136380 time=28.5s + legal_ttt_chunk [141/1893] bpb=1.142068 time=30.7s + legal_ttt_chunk [151/1893] bpb=1.144141 time=32.9s + legal_ttt_chunk [161/1893] bpb=1.144054 time=35.0s + legal_ttt_chunk [171/1893] bpb=1.148608 time=37.2s + legal_ttt_chunk [181/1893] bpb=1.150775 time=39.3s + legal_ttt_chunk [191/1893] bpb=1.158164 time=41.5s + legal_ttt_chunk [201/1893] bpb=1.157240 time=43.7s + legal_ttt_chunk [211/1893] bpb=1.154882 time=45.9s + legal_ttt_chunk [221/1893] bpb=1.156393 time=48.0s + legal_ttt_chunk [231/1893] bpb=1.154893 time=50.2s + legal_ttt_chunk [241/1893] bpb=1.155349 time=52.3s + legal_ttt_chunk [251/1893] bpb=1.154864 time=54.5s + legal_ttt_chunk [261/1893] bpb=1.151812 time=56.6s + legal_ttt_chunk [271/1893] bpb=1.150679 time=58.8s + legal_ttt_chunk [281/1893] bpb=1.151907 time=60.9s + legal_ttt_chunk [291/1893] bpb=1.153669 time=63.1s + legal_ttt_chunk [301/1893] bpb=1.154392 time=65.2s + legal_ttt_chunk [311/1893] bpb=1.156514 time=67.4s + legal_ttt_chunk [321/1893] bpb=1.158533 time=69.5s + legal_ttt_chunk [331/1893] bpb=1.158551 time=71.7s + legal_ttt_chunk [341/1893] bpb=1.157548 time=73.8s + legal_ttt_chunk [351/1893] bpb=1.159807 time=76.0s + legal_ttt_chunk [361/1893] bpb=1.160187 time=78.1s + legal_ttt_chunk [371/1893] bpb=1.159513 time=80.3s + legal_ttt_chunk [381/1893] bpb=1.159610 time=82.4s + legal_ttt_chunk [391/1893] bpb=1.159436 time=84.6s + legal_ttt_chunk [401/1893] bpb=1.157346 time=86.7s + legal_ttt_chunk [411/1893] bpb=1.156282 time=88.9s + legal_ttt_chunk [421/1893] bpb=1.155294 time=91.0s + legal_ttt_chunk [431/1893] bpb=1.155104 time=93.2s + legal_ttt_chunk [441/1893] bpb=1.155399 time=95.4s + legal_ttt_chunk [451/1893] bpb=1.155675 time=97.5s + legal_ttt_chunk [461/1893] bpb=1.154516 time=99.7s + legal_ttt_chunk [471/1893] bpb=1.155079 time=101.8s + legal_ttt_chunk [481/1893] bpb=1.154746 time=104.0s + legal_ttt_chunk [491/1893] bpb=1.153714 time=106.1s + legal_ttt_chunk [501/1893] bpb=1.153270 time=108.3s + legal_ttt_chunk [511/1893] bpb=1.152617 time=110.4s + legal_ttt_chunk [521/1893] bpb=1.150500 time=112.6s + legal_ttt_chunk [531/1893] bpb=1.151635 time=114.7s + legal_ttt_chunk [541/1893] bpb=1.151990 time=116.9s + legal_ttt_chunk [551/1893] bpb=1.150899 time=119.0s + legal_ttt_chunk [561/1893] bpb=1.151392 time=121.2s + legal_ttt_chunk [571/1893] bpb=1.150305 time=123.3s + legal_ttt_chunk [581/1893] bpb=1.149546 time=125.5s + legal_ttt_chunk [591/1893] bpb=1.148876 time=127.6s + legal_ttt_chunk [601/1893] bpb=1.149425 time=129.8s + legal_ttt_chunk [611/1893] bpb=1.149296 time=131.9s + legal_ttt_chunk [621/1893] bpb=1.149095 time=134.1s + legal_ttt_chunk [631/1893] bpb=1.149807 time=136.2s + legal_ttt_chunk [641/1893] bpb=1.149609 time=138.4s + legal_ttt_chunk [651/1893] bpb=1.149763 time=140.5s + legal_ttt_chunk [661/1893] bpb=1.149272 time=142.7s + legal_ttt_chunk [671/1893] bpb=1.149650 time=144.8s + legal_ttt_chunk [681/1893] bpb=1.150314 time=147.0s + legal_ttt_chunk [691/1893] bpb=1.151343 time=149.1s + legal_ttt_chunk [701/1893] bpb=1.150779 time=151.3s + legal_ttt_chunk [711/1893] bpb=1.150805 time=153.4s + legal_ttt_chunk [721/1893] bpb=1.150484 time=155.6s + legal_ttt_chunk [731/1893] 
bpb=1.150522 time=157.7s + legal_ttt_chunk [741/1893] bpb=1.150576 time=159.9s + legal_ttt_chunk [751/1893] bpb=1.150443 time=162.1s + legal_ttt_chunk [761/1893] bpb=1.150335 time=164.2s + legal_ttt_chunk [771/1893] bpb=1.150048 time=166.4s + legal_ttt_chunk [781/1893] bpb=1.150804 time=168.5s + legal_ttt_chunk [791/1893] bpb=1.150410 time=170.7s + legal_ttt_chunk [801/1893] bpb=1.150667 time=172.8s + legal_ttt_chunk [811/1893] bpb=1.150443 time=175.0s + legal_ttt_chunk [821/1893] bpb=1.150206 time=177.1s + legal_ttt_chunk [831/1893] bpb=1.150040 time=179.3s + legal_ttt_chunk [841/1893] bpb=1.149365 time=181.4s + legal_ttt_chunk [851/1893] bpb=1.149121 time=183.6s + legal_ttt_chunk [861/1893] bpb=1.148861 time=185.7s + legal_ttt_chunk [871/1893] bpb=1.149100 time=187.9s + legal_ttt_chunk [881/1893] bpb=1.149245 time=190.1s + legal_ttt_chunk [891/1893] bpb=1.148811 time=192.2s + legal_ttt_chunk [901/1893] bpb=1.148548 time=194.4s + legal_ttt_chunk [911/1893] bpb=1.148671 time=196.5s + legal_ttt_chunk [921/1893] bpb=1.149155 time=198.7s + legal_ttt_chunk [931/1893] bpb=1.149167 time=200.8s + legal_ttt_chunk [941/1893] bpb=1.148819 time=203.0s + legal_ttt_chunk [951/1893] bpb=1.149223 time=205.1s + legal_ttt_chunk [961/1893] bpb=1.149278 time=207.3s + legal_ttt_chunk [971/1893] bpb=1.150125 time=209.5s + legal_ttt_chunk [981/1893] bpb=1.150190 time=211.6s + legal_ttt_chunk [991/1893] bpb=1.150217 time=213.8s + legal_ttt_chunk [1001/1893] bpb=1.150178 time=215.9s + legal_ttt_chunk [1011/1893] bpb=1.149965 time=218.1s + legal_ttt_chunk [1021/1893] bpb=1.150300 time=220.2s + legal_ttt_chunk [1031/1893] bpb=1.150753 time=222.4s + legal_ttt_chunk [1041/1893] bpb=1.150395 time=224.5s + legal_ttt_chunk [1051/1893] bpb=1.150137 time=226.7s + legal_ttt_chunk [1061/1893] bpb=1.150165 time=228.8s + legal_ttt_chunk [1071/1893] bpb=1.150796 time=231.0s + legal_ttt_chunk [1081/1893] bpb=1.151032 time=233.2s + legal_ttt_chunk [1091/1893] bpb=1.151775 time=235.3s + legal_ttt_chunk [1101/1893] bpb=1.151787 time=237.5s + legal_ttt_chunk [1111/1893] bpb=1.151667 time=239.6s + legal_ttt_chunk [1121/1893] bpb=1.151463 time=241.8s + legal_ttt_chunk [1131/1893] bpb=1.151382 time=244.0s + legal_ttt_chunk [1141/1893] bpb=1.151073 time=246.1s + legal_ttt_chunk [1151/1893] bpb=1.151100 time=248.3s + legal_ttt_chunk [1161/1893] bpb=1.150765 time=250.4s + legal_ttt_chunk [1171/1893] bpb=1.151070 time=252.6s + legal_ttt_chunk [1181/1893] bpb=1.150328 time=254.7s + legal_ttt_chunk [1191/1893] bpb=1.150217 time=256.9s + legal_ttt_chunk [1201/1893] bpb=1.150623 time=259.1s + legal_ttt_chunk [1211/1893] bpb=1.150150 time=261.2s + legal_ttt_chunk [1221/1893] bpb=1.149838 time=263.4s + legal_ttt_chunk [1231/1893] bpb=1.149561 time=265.6s + legal_ttt_chunk [1241/1893] bpb=1.149222 time=267.7s + legal_ttt_chunk [1251/1893] bpb=1.148612 time=269.9s + legal_ttt_chunk [1261/1893] bpb=1.148597 time=272.0s + legal_ttt_chunk [1271/1893] bpb=1.148218 time=274.2s + legal_ttt_chunk [1281/1893] bpb=1.148034 time=276.3s + legal_ttt_chunk [1291/1893] bpb=1.147815 time=278.5s + legal_ttt_chunk [1301/1893] bpb=1.147215 time=280.7s + legal_ttt_chunk [1311/1893] bpb=1.146845 time=282.8s + legal_ttt_chunk [1321/1893] bpb=1.146522 time=285.0s + legal_ttt_chunk [1331/1893] bpb=1.146454 time=287.1s + legal_ttt_chunk [1341/1893] bpb=1.146330 time=289.3s + legal_ttt_chunk [1351/1893] bpb=1.146264 time=291.5s + legal_ttt_chunk [1361/1893] bpb=1.146313 time=293.6s + legal_ttt_chunk [1371/1893] bpb=1.146193 time=295.8s + legal_ttt_chunk [1381/1893] 
bpb=1.146195 time=297.9s + legal_ttt_chunk [1391/1893] bpb=1.145789 time=300.1s + legal_ttt_chunk [1401/1893] bpb=1.145761 time=302.2s + legal_ttt_chunk [1411/1893] bpb=1.145890 time=304.4s + legal_ttt_chunk [1421/1893] bpb=1.146138 time=306.6s + legal_ttt_chunk [1431/1893] bpb=1.145853 time=308.7s + legal_ttt_chunk [1441/1893] bpb=1.146377 time=310.9s + legal_ttt_chunk [1451/1893] bpb=1.146714 time=313.0s + legal_ttt_chunk [1461/1893] bpb=1.146269 time=315.2s + legal_ttt_chunk [1471/1893] bpb=1.147306 time=317.3s + legal_ttt_chunk [1481/1893] bpb=1.146857 time=319.5s + legal_ttt_chunk [1491/1893] bpb=1.146675 time=321.7s + legal_ttt_chunk [1501/1893] bpb=1.146611 time=323.8s + legal_ttt_chunk [1511/1893] bpb=1.146618 time=326.0s + legal_ttt_chunk [1521/1893] bpb=1.146672 time=328.1s + legal_ttt_chunk [1531/1893] bpb=1.146161 time=330.3s + legal_ttt_chunk [1541/1893] bpb=1.146014 time=332.5s + legal_ttt_chunk [1551/1893] bpb=1.146312 time=334.6s + legal_ttt_chunk [1561/1893] bpb=1.146304 time=336.8s + legal_ttt_chunk [1571/1893] bpb=1.146153 time=338.9s + legal_ttt_chunk [1581/1893] bpb=1.146303 time=341.1s + legal_ttt_chunk [1591/1893] bpb=1.146159 time=343.2s + legal_ttt_chunk [1601/1893] bpb=1.146328 time=345.4s + legal_ttt_chunk [1611/1893] bpb=1.146260 time=347.6s + legal_ttt_chunk [1621/1893] bpb=1.145868 time=349.7s + legal_ttt_chunk [1631/1893] bpb=1.146182 time=351.9s + legal_ttt_chunk [1641/1893] bpb=1.146205 time=354.0s + legal_ttt_chunk [1651/1893] bpb=1.146154 time=356.2s + legal_ttt_chunk [1661/1893] bpb=1.146040 time=358.4s + legal_ttt_chunk [1671/1893] bpb=1.146508 time=360.5s + legal_ttt_chunk [1681/1893] bpb=1.146663 time=362.7s + legal_ttt_chunk [1691/1893] bpb=1.146496 time=364.8s + legal_ttt_chunk [1701/1893] bpb=1.146625 time=367.0s + legal_ttt_chunk [1711/1893] bpb=1.146608 time=369.1s + legal_ttt_chunk [1721/1893] bpb=1.146593 time=371.3s + legal_ttt_chunk [1731/1893] bpb=1.146455 time=373.5s + legal_ttt_chunk [1741/1893] bpb=1.146270 time=375.6s + legal_ttt_chunk [1751/1893] bpb=1.146098 time=377.8s + legal_ttt_chunk [1761/1893] bpb=1.146243 time=379.9s + legal_ttt_chunk [1771/1893] bpb=1.146143 time=382.1s + legal_ttt_chunk [1781/1893] bpb=1.146166 time=384.2s + legal_ttt_chunk [1791/1893] bpb=1.145764 time=386.4s + legal_ttt_chunk [1801/1893] bpb=1.145639 time=388.6s + legal_ttt_chunk [1811/1893] bpb=1.145525 time=390.7s + legal_ttt_chunk [1821/1893] bpb=1.145579 time=392.9s + legal_ttt_chunk [1831/1893] bpb=1.144993 time=395.1s + legal_ttt_chunk [1841/1893] bpb=1.145124 time=397.2s + legal_ttt_chunk [1851/1893] bpb=1.144920 time=399.4s + legal_ttt_chunk [1861/1893] bpb=1.144544 time=401.5s + legal_ttt_chunk [1871/1893] bpb=1.144517 time=403.7s + legal_ttt_chunk [1881/1893] bpb=1.144085 time=405.8s + legal_ttt_chunk [1891/1893] bpb=1.143853 time=408.0s + legal_ttt_chunk [1893/1893] bpb=1.143898 time=408.3s +legal_ttt:done val_loss=1.927514 val_bpb=1.141585 elapsed=408.3s +final_legal_ttt val_loss:1.9275 val_bpb:1.1416 eval_time:408770ms diff --git a/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed2025.log b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed2025.log new file mode 100644 index 0000000000..299ea54a45 --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed2025.log @@ -0,0 +1,1899 @@ +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import traceback 
+import uuid +import zlib +import lzma +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_func +except ImportError: + raise ImportError( + "Flash Attention 3 (Hopper) is required. Install with:\n" + " pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch280\n" + "Or see requirements.txt for details." + ) + + +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)-5s | %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("yocto-golf") + +def log_architecture(model, args): + n = sum(p.numel() for p in model.parameters()) + logger.info(f"YOCTO d={args.model_dim} K={args.num_unique_layers} heads={args.num_heads} params={n:,}") + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # ── Yocto architecture ── + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 552)) + num_heads = int(os.environ.get("NUM_HEADS", 4)) + num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 10)) + num_recurrences = int(os.environ.get("NUM_RECURRENCES", 1)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + use_swiglu = bool(int(os.environ.get("USE_SWIGLU", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + seeking_gain_init = float(os.environ.get("SEEKING_GAIN_INIT", 1.5)) + rope_fraction = float(os.environ.get("ROPE_FRACTION", 1.0)) # 1.0 = full RoPE, 0.5 = half partial RoPE + + # ── Optimizer ── + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = 
float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + + # ── LR warmup (actual learning rate ramp, separate from compile warmup) ── + lr_warmup_steps = int(os.environ.get("LR_WARMUP_STEPS", 100)) + + # ── EMA ── + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) # 0 = disabled, 0.997 = SOTA setting + + # ── SWA (Stochastic Weight Averaging) ── + swa_every = int(os.environ.get("SWA_EVERY", 50)) # 0 = disabled, 50 = SOTA setting + swa_threshold = float(os.environ.get("SWA_THRESHOLD", 0.2)) # only SWA when lr_scale < this + + # ── Compression ── + compression = os.environ.get("COMPRESSION", "lzma") # "zlib", "zstd", or "lzma" + + # ── QAT (Quantization-Aware Training) ── + qat_bits = int(os.environ.get("QAT_BITS", 6)) # 0 = disabled, 6 = int6 QAT + qat_start_fraction = float(os.environ.get("QAT_START_FRACTION", 0.15)) # when to start QAT + + # ── Mixed precision quantization ── + int5_layers = os.environ.get("INT5_LAYERS", "") # e.g. "2,3,4,5,6,7,8" + + sliding_window_stride = int(os.environ.get("SLIDING_WINDOW_STRIDE", 64)) + + # ── LN Scale ── + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) # 1/sqrt(layer_idx+1) on norm outputs + + # ── Value Embedding (VE128) ── + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "8,9") # last 2 of 10 layers + + # ── TTT LoRA ── + + # ── Legal Score-First TTT ── + legal_ttt_enabled = bool(int(os.environ.get("LEGAL_TTT_ENABLED", "1"))) + legal_ttt_lr = float(os.environ.get("LEGAL_TTT_LR", 0.002)) + legal_ttt_epochs = int(os.environ.get("LEGAL_TTT_EPOCHS", 3)) + legal_ttt_chunk_tokens = int(os.environ.get("LEGAL_TTT_CHUNK_TOKENS", 32768)) + legal_ttt_freeze_blocks = int(os.environ.get("LEGAL_TTT_FREEZE_BLOCKS", 0)) + legal_ttt_momentum = float(os.environ.get("LEGAL_TTT_MOMENTUM", 0.9)) + legal_ttt_batch_seqs = int(os.environ.get("LEGAL_TTT_BATCH_SEQS", 32)) + legal_ttt_grad_clip = float(os.environ.get("LEGAL_TTT_GRAD_CLIP", 1.0)) + + @property + def num_effective_layers(self) -> int: + return self.num_unique_layers * self.num_recurrences + + def validate(self) -> None: + """Check all divisibility constraints.""" + d = self.model_dim + assert d % 3 == 0, f"model_dim={d} must be divisible by 3 for unified attention split" + comp = d // 3 + assert comp % self.num_heads == 0, ( + f"component_dim={comp} (model_dim/3) must be divisible by num_heads={self.num_heads}" + ) + head_dim = comp // self.num_heads + assert head_dim % 2 == 0, f"head_dim={head_dim} must be even for RoPE" + assert head_dim >= 16, f"head_dim={head_dim} must be >= 16 for useful RoPE (got {head_dim})" + assert self.logit_softcap > 0, f"logit_softcap must be positive" + logger.info(f"Architecture constraints validated: d={d}, comp={comp}, heads={self.num_heads}, " + f"head_dim={head_dim}, RoPE_pairs={head_dim//2}") + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. 
G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, batched NS5, all-gather.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + 
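+                # hand the orthogonalized shard to an async all-gather below;
+                # bank i's parameter update is applied one loop iteration later,
+                # while bank i+1 runs Newton-Schulz, overlapping comm with compute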
if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split too short for seq_len={seq_len}") + return tokens[: usable + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += 
(has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +CONTROL_TENSOR_NAME_PATTERNS = ("attn_scale", "mlp_scale", "resid_mix", "skip_weight", "seeking_gain", "smear", "ve_layer_scales", "ve_shared.scale") +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = CONTROL_TENSOR_NAME_PATTERNS +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_Q = 99.99984 / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +GPTQ_CLIP_PERCENTILES = [0.999, 0.9995, 0.9999, 0.99999, 1.0] + +def quantize_float_tensor_int6(t: Tensor): + return quantize_float_tensor_intN(t, max_val=31) + +def quantize_float_tensor_intN(t: Tensor, max_val: int = 31): + t32 = t.float() + if t32.ndim == 2: + best_q, best_scale, best_err = None, None, float('inf') + + for pct in GPTQ_CLIP_PERCENTILES: + if pct >= 1.0: + clip_abs = t32.abs().amax(dim=1).clamp_min(1e-8) + else: + clip_abs = torch.quantile(t32.abs(), pct, dim=1).clamp_min(1e-8) + scale = (clip_abs / max_val).clamp_min(1e-8).to(torch.float16) + clipped = t32.clamp(-clip_abs[:, None], clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), -max_val, max_val).to(torch.int8) + recon = q.float() * scale.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_scale, best_err = q, scale, err + + return best_q.contiguous(), best_scale.contiguous() + abs_max = t32.abs().max().clamp_min(1e-8).item() + scale = torch.tensor(abs_max / max_val, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -max_val, max_val).to(torch.int8) + return q, scale + +# ── Unbank/rebank for quantization ── + +def _unbank_state_dict(sd, num_layers): + out = {} + for name, tensor in sd.items(): + if name == "unified_bank": + for i in range(num_layers): + w = tensor[i] # [d, d] + d = w.shape[0] + comp = d // 3 + out[f"blocks.{i}.attn.W_seeking.weight"] = w[:comp, :] + out[f"blocks.{i}.attn.W_offering.weight"] = w[comp:2*comp, :] + out[f"blocks.{i}.attn.W_content.weight"] = w[2*comp:, :] + elif name == "output_bank": + for i in range(num_layers): 
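+                # remaining banks unbank the same way: one [out, in] matrix per
+                # layer, so the per-row int6/int5 quantizers see plain 2D weights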
+ out[f"blocks.{i}.attn.W_output.weight"] = tensor[i] + elif name == "fc_bank": + for i in range(num_layers): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "proj_bank": + for i in range(num_layers): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd, num_layers, template_sd): + out = {} + consumed = set() + + unified_slices = [] + for i in range(num_layers): + sk = f"blocks.{i}.attn.W_seeking.weight" + ok = f"blocks.{i}.attn.W_offering.weight" + ck = f"blocks.{i}.attn.W_content.weight" + unified_slices.append(torch.cat([sd[sk], sd[ok], sd[ck]], dim=0)) + consumed.update([sk, ok, ck]) + out["unified_bank"] = torch.stack(unified_slices).to(dtype=template_sd["unified_bank"].dtype) + + for bank_name, key_template in [ + ("output_bank", "blocks.{i}.attn.W_output.weight"), + ("fc_bank", "blocks.{i}.mlp.fc.weight"), + ("proj_bank", "blocks.{i}.mlp.proj.weight"), + ]: + slices = [] + for i in range(num_layers): + k = key_template.format(i=i) + slices.append(sd[k]) + consumed.add(k) + out[bank_name] = torch.stack(slices).to(dtype=template_sd[bank_name].dtype) + + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +INT8_EMBED_PATTERNS = ("tok_emb.", "ve_shared.embed.") + +def quantize_state_dict_mixed(state_dict, int5_layers=None): + if int5_layers is None: + int5_layers = set() + result = {} + meta = {} + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + if not t.is_floating_point(): + result[name] = t + meta[name] = "passthrough" + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + if any(p in name for p in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + result[name] = t.float().contiguous() + meta[name] = "passthrough_ctrl" + else: + result[name] = t.to(torch.float16).contiguous() + meta[name] = "passthrough" + continue + is_embed = any(p in name for p in INT8_EMBED_PATTERNS) + if is_embed: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + else: + layer_idx = -1 + if "blocks." 
in name:
+                try:
+                    layer_idx = int(name.split("blocks.")[1].split(".")[0])
+                except (ValueError, IndexError):
+                    pass
+            if layer_idx in int5_layers:
+                q, s = quantize_float_tensor_intN(t, max_val=15)
+                result[name + ".q"] = q
+                result[name + ".scale"] = s
+                meta[name] = {"type": "int5"}
+            else:
+                q, s = quantize_float_tensor_int6(t)
+                result[name + ".q"] = q
+                result[name + ".scale"] = s
+                meta[name] = {"type": "int6"}
+    return result, meta
+
+def dequantize_state_dict_mixed(result, meta, template_sd=None):
+    """Dequantize flat-key mixed int6/int8 state dict back to float tensors."""
+    out = {}
+    for name, info in meta.items():
+        if info in ("passthrough", "passthrough_ctrl"):
+            t = result[name]
+            if template_sd is not None and name in template_sd:
+                orig_dtype = template_sd[name].dtype
+                if t.dtype != orig_dtype:
+                    t = t.to(orig_dtype)
+            out[name] = t
+            continue
+        q = result[name + ".q"]
+        s = result[name + ".scale"]
+        if s.ndim > 0:
+            deq = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1))))
+        else:
+            deq = q.float() * float(s.item())
+        target_dtype = torch.bfloat16
+        if template_sd is not None and name in template_sd:
+            target_dtype = template_sd[name].dtype
+        out[name] = deq.to(target_dtype).contiguous()
+    return out
+
+
+def load_data_shard(file: Path) -> Tensor:
+    # shard layout assumption: 256 x int32 header with header[2] = token
+    # count, followed by little-endian uint16 token ids
+    header = np.fromfile(file, dtype="<i4", count=256)
+    num_tokens = int(header[2])
+    tokens = np.fromfile(file, dtype="<u2", count=num_tokens, offset=256 * 4)
+    return torch.from_numpy(tokens.astype(np.int32))
+
+class TokenStream:
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank, self.world_size, self.device = rank, world_size, device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+class CastedLinear(nn.Linear):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._is_mlp = False  # kept for compatibility
+
+    def forward(self, x: Tensor) -> Tensor:
+        w = self.weight
+        if self.training and _qat_active and w.numel() > INT8_KEEP_FLOAT_MAX_NUMEL:
+            w = _fake_quantize(w, _qat_bits)
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, w.to(x.dtype), bias)
+
+# ── QAT globals (set during training) ──
+_qat_active = False
+_qat_bits = 6
+
+def _fake_quantize(w: Tensor, bits: int) -> Tensor:
+    max_val = (1 << (bits - 1)) - 1  # e.g.
int6: max_val = 31 + with torch.no_grad(): + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + scale = abs_max / max_val + w_q = (w / scale).round().clamp(-max_val, max_val) * scale + return w + (w_q - w).detach() + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): + if self._cos_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, target_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, target_dim, bias=False) if ve_dim != target_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class UnifiedAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, rope_base: float, + seeking_gain_init: float, rope_fraction: float = 1.0): + super().__init__() + assert dim % 3 == 0, f"dim={dim} must be divisible by 3" + self.dim = dim + self.num_heads = num_heads + self.component_dim = dim // 3 + self.head_dim = self.component_dim // num_heads + assert self.component_dim % num_heads == 0 + + self.rope_dim = int(self.head_dim * rope_fraction) + self.rope_dim = max(self.rope_dim - (self.rope_dim % 2), 2) + self.pass_dim = self.head_dim - self.rope_dim + + self.seeking_gain = nn.Parameter( + torch.full((num_heads,), seeking_gain_init, dtype=torch.float32) + ) + self.rotary = Rotary(self.rope_dim, base=rope_base) + + def forward(self, x: Tensor, unified_w: Tensor, output_w: Tensor, unified_delta=None, v_embed=None) -> Tensor: + bsz, seqlen, _ = x.shape + + unified = F.linear(x, unified_w.to(x.dtype)) + if unified_delta is not None: + unified = unified + 
unified_delta
+
+        seeking, offering, content = unified.split(self.component_dim, dim=-1)
+
+        if v_embed is not None:
+            content = content + v_embed
+
+        def to_heads(t):
+            return t.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+
+        seeking = to_heads(seeking)
+        offering = to_heads(offering)
+        content = to_heads(content)
+
+        seeking = F.rms_norm(seeking, (seeking.size(-1),))
+        offering = F.rms_norm(offering, (offering.size(-1),))
+
+        cos, sin = self.rotary(seqlen, x.device, seeking.dtype)
+        if self.pass_dim > 0:
+            s_rope, s_pass = seeking[..., :self.rope_dim], seeking[..., self.rope_dim:]
+            o_rope, o_pass = offering[..., :self.rope_dim], offering[..., self.rope_dim:]
+            s_rope = apply_rotary_emb(s_rope, cos, sin)
+            o_rope = apply_rotary_emb(o_rope, cos, sin)
+            seeking = torch.cat([s_rope, s_pass], dim=-1)
+            offering = torch.cat([o_rope, o_pass], dim=-1)
+        else:
+            seeking = apply_rotary_emb(seeking, cos, sin)
+            offering = apply_rotary_emb(offering, cos, sin)
+
+        seeking = seeking * self.seeking_gain.to(dtype=seeking.dtype)[None, :, None, None]
+
+        # FA3 expects [B, T, H, D]; transpose back from the [B, H, T, D] head layout
+        sq = seeking.transpose(1, 2)
+        of = offering.transpose(1, 2)
+        ct = content.transpose(1, 2)
+        dtype = sq.dtype
+        if dtype not in (torch.float16, torch.bfloat16):
+            sq, of, ct = sq.to(torch.bfloat16), of.to(torch.bfloat16), ct.to(torch.bfloat16)
+        hd = sq.size(-1)
+        pad_n = (8 - hd % 8) % 8  # zero-pad head_dim to a multiple of 8 for FA3 (lossless)
+        if pad_n > 0:
+            sq = F.pad(sq, (0, pad_n))
+            of = F.pad(of, (0, pad_n))
+            ct = F.pad(ct, (0, pad_n))
+        out = _flash_attn_func(sq, of, ct, causal=True)
+        y = out[0] if isinstance(out, tuple) else out
+        if pad_n > 0:
+            y = y[..., :hd]
+        if y.dtype != dtype:
+            y = y.to(dtype)
+
+        # y is already [B, T, H, D]; fold the heads straight back into component_dim
+        y = y.contiguous().reshape(bsz, seqlen, self.component_dim)
+        return F.linear(y, output_w.to(x.dtype))
+
+class SquaredReLUMLP(nn.Module):
+    """LeakyReLU(0.5)² MLP — weights passed from banks."""
+    def __init__(self, dim: int, mlp_mult: int):
+        super().__init__()
+
+    def forward(self, x: Tensor, fc_w: Tensor, proj_w: Tensor) -> Tensor:
+        return F.linear(
+            F.leaky_relu(F.linear(x, fc_w.to(x.dtype)), negative_slope=0.5).square(),
+            proj_w.to(x.dtype)
+        )
+
+class Block(nn.Module):
+    """Single transformer block with unified attention + MLP.
Weights from banks.""" + def __init__(self, dim: int, num_heads: int, mlp_mult: int, rope_base: float, + seeking_gain_init: float, rope_fraction: float = 1.0, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = UnifiedAttention(dim, num_heads, rope_base, seeking_gain_init, rope_fraction) + self.mlp = SquaredReLUMLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, unified_w: Tensor, output_w: Tensor, + fc_w: Tensor, proj_w: Tensor, unified_delta_fn=None, v_embed=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) * self.ln_scale_factor + ud = unified_delta_fn(n) if unified_delta_fn is not None else None + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(n, unified_w, output_w, ud, v_embed=v_embed) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * self.ln_scale_factor, fc_w, proj_w) + return x + +class YoctoGPT(nn.Module): + def __init__(self, vocab_size: int, model_dim: int, num_heads: int, + num_unique_layers: int, num_recurrences: int, mlp_mult: int, + tie_embeddings: bool, tied_embed_init_std: float, + logit_softcap: float, rope_base: float, seeking_gain_init: float, + rope_fraction: float = 1.0, + ln_scale: bool = True, + ve_enabled: bool = True, ve_dim: int = 128, ve_layers: str = "8,9", + int5_layers: str = ""): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.num_unique_layers = num_unique_layers + self.num_recurrences = num_recurrences + self.int5_layer_set = set(int(x) for x in int5_layers.split(",") if x.strip()) + effective = num_unique_layers * num_recurrences + + comp_dim = model_dim // 3 + mlp_dim = mlp_mult * model_dim + + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = None + self.smear = SmearGate(model_dim) + + K = num_unique_layers + self.unified_bank = nn.Parameter(torch.empty(K, model_dim, model_dim)) # W_unified: d→d + self.output_bank = nn.Parameter(torch.empty(K, model_dim, comp_dim)) # W_output: comp→d (F.linear expects [out, in]) + self.fc_bank = nn.Parameter(torch.empty(K, mlp_dim, model_dim)) # MLP fc: d→mlp_dim + self.proj_bank = nn.Parameter(torch.empty(K, model_dim, mlp_dim)) # MLP proj: mlp_dim→d + + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, mlp_mult, rope_base, seeking_gain_init, rope_fraction, + layer_idx=k, ln_scale=ln_scale) + for k in range(num_unique_layers) + ]) + + self.num_encoder_layers = effective // 2 + self.num_decoder_layers = effective - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, comp_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + 
self.ve_layer_scales = nn.ParameterList() + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights(tied_embed_init_std) + + def _init_weights(self, std: float) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=std) + K = self.num_unique_layers + proj_scale = 1.0 / math.sqrt(2 * K * self.num_recurrences) + for i in range(K): + nn.init.orthogonal_(self.unified_bank.data[i], gain=1.0) + nn.init.zeros_(self.output_bank.data[i]) + self.output_bank.data[i].mul_(proj_scale) + nn.init.orthogonal_(self.fc_bank.data[i], gain=1.0) + nn.init.zeros_(self.proj_bank.data[i]) + self.proj_bank.data[i].mul_(proj_scale) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _qat_weight(self, w: Tensor, layer_idx: int = -1) -> Tensor: + if self.training and _qat_active: + bits = 5 if layer_idx in self.int5_layer_set else _qat_bits + return _fake_quantize(w, bits) + return w + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_cache['ve'] * self.ve_layer_scales[ve_idx].to(dtype=ve_cache['ve'].dtype) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + eff_layer_idx = 0 + for _r in range(self.num_recurrences): + for k in range(self.num_unique_layers): + is_encoder = eff_layer_idx < self.num_encoder_layers + + if not is_encoder and skips: + dec_idx = eff_layer_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + + ud_fn = lora.unified_loras[k] if (lora and lora.unified_loras is not None) else None + ve = self._get_ve(k, input_ids, ve_cache) + x = self.blocks[k](x, x0, + self._qat_weight(self.unified_bank[k], k), + self._qat_weight(self.output_bank[k], k), + self._qat_weight(self.fc_bank[k], k), + self._qat_weight(self.proj_bank[k], k), + ud_fn, v_embed=ve) + + if is_encoder: + skips.append(x) + + eff_layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + eff_layer_idx = 0 + for _r in 
range(self.num_recurrences): + for k in range(self.num_unique_layers): + is_encoder = eff_layer_idx < self.num_encoder_layers + + if not is_encoder and skips: + dec_idx = eff_layer_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + + ve = self._get_ve(k, input_ids, ve_cache) + x = self.blocks[k](x, x0, + self.unified_bank[k], self.output_bank[k], + self.fc_bank[k], self.proj_bank[k], + v_embed=ve) + + if is_encoder: + skips.append(x) + + eff_layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + +def eval_val_legal_ttt(args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log0=print): + seq_len = args.train_seq_len + stride = args.sliding_window_stride + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.legal_ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"legal_ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lr={args.legal_ttt_lr} epochs={args.legal_ttt_epochs} " + f"freeze_blocks={args.legal_ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + frozen_block_ids = set(range(min(args.legal_ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name:
+                freeze = True
+                break
+        if freeze:
+            p.requires_grad_(False)
+        else:
+            p.requires_grad_(True)
+            ttt_params.append(p)
+
+    log0(f"legal_ttt:params unfrozen={sum(p.numel() for p in ttt_params)} "
+         f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}")
+
+    optimizer = torch.optim.SGD(ttt_params, lr=args.legal_ttt_lr, momentum=args.legal_ttt_momentum)
+    batch_seqs = args.legal_ttt_batch_seqs
+    t0 = time.perf_counter()
+
+    for ci in range(num_chunks):
+        windows = chunk_windows[ci]
+        if not windows:
+            continue
+
+        my_s = (len(windows) * rank) // world_size
+        my_e = (len(windows) * (rank + 1)) // world_size
+        my_windows = windows[my_s:my_e]
+
+        # SCORE first: evaluate this chunk under no_grad before any weight
+        # update from it, so scoring only ever sees a model adapted on
+        # strictly earlier, already-scored chunks
+        base_model.eval()
+        with torch.no_grad():
+            for bi in range(0, len(my_windows), batch_seqs):
+                batch_ws = my_windows[bi:bi + batch_seqs]
+                bsz = len(batch_ws)
+                x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+                y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+                wlens: list[int] = []
+                for i, ws in enumerate(batch_ws):
+                    end = min(ws + seq_len, total_tokens)
+                    wlen = end - ws
+                    wlens.append(wlen)
+                    chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
+                    x_batch[i, :wlen] = chunk_tok[:-1]
+                    y_batch[i, :wlen] = chunk_tok[1:]
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                    logits = base_model.forward_logits(x_batch)
+                nll = F.cross_entropy(
+                    logits.reshape(-1, logits.size(-1)).float(),
+                    y_batch.reshape(-1), reduction="none",
+                ).reshape(bsz, seq_len)
+                for i, ws in enumerate(batch_ws):
+                    wlen = wlens[i]
+                    s = 0 if ws == 0 else max(wlen - stride, 0)
+                    scored_nll = nll[i, s:wlen].to(torch.float64)
+                    loss_sum += scored_nll.sum()
+                    token_count += float(wlen - s)
+                    tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen]
+                    tb = base_bytes_lut[tgt].to(torch.float64)
+                    tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
+                    byte_count += tb.sum()
+
+        # TRAIN second: adapt on the chunk that was just scored. The final
+        # chunk is skipped, since nothing remains to be scored after it.
+        is_last_chunk = (ci == num_chunks - 1)
+        if not is_last_chunk and args.legal_ttt_epochs > 0:
+            base_model.train()
+            # reset rotary caches so stale inference-mode tensors are not
+            # reused inside the grad-enabled forward passes
+            for block in base_model.blocks:
+                block.attn.rotary._cos_cached = None
+                block.attn.rotary._sin_cached = None
+            chunk_start = ci * ttt_chunk
+            chunk_end = min((ci + 1) * ttt_chunk, total_tokens)
+            chunk_seqs = (chunk_end - chunk_start) // seq_len
+            if chunk_seqs > 0:
+                # cosine decay of the TTT learning rate across chunks
+                cos_lr = args.legal_ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1)))
+                for pg in optimizer.param_groups:
+                    pg['lr'] = cos_lr
+                my_seq_s = (chunk_seqs * rank) // world_size
+                my_seq_e = (chunk_seqs * (rank + 1)) // world_size
+                my_chunk_seqs = my_seq_e - my_seq_s
+                for _ep in range(args.legal_ttt_epochs):
+                    for bs in range(0, my_chunk_seqs, batch_seqs):
+                        be = min(bs + batch_seqs, my_chunk_seqs)
+                        actual_bs = my_seq_s + bs
+                        start_tok = chunk_start + actual_bs * seq_len
+                        end_tok = chunk_start + (my_seq_s + be) * seq_len + 1
+                        if end_tok > val_tokens.numel():
+                            continue
+                        local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64)
+                        x = local[:-1].reshape(-1, seq_len)
+                        y = local[1:].reshape(-1, seq_len)
+                        optimizer.zero_grad(set_to_none=True)
+                        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                            loss = base_model(x, y)
+                        loss.backward()
+                        if world_size > 1:
+                            for p in ttt_params:
+                                if p.grad is not None:
+                                    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
+                        torch.nn.utils.clip_grad_norm_(ttt_params, args.legal_ttt_grad_clip)
+                        optimizer.step()
+
+        if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1):
+            elapsed = time.perf_counter() - t0
+            rl =
loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" legal_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"legal_ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + +def prune_to_fit(result, meta, code_bytes, target_bytes=16_000_000, compress="lzma"): + """Selectively zero ±1 quantized values to fit artifact in budget.""" + buf = io.BytesIO() + torch.save({"w": result, "m": meta}, buf) + raw = buf.getvalue() + if compress == "lzma": + blob = lzma.compress(raw, preset=6) + else: + blob = zlib.compress(raw, level=9) + if len(blob) + code_bytes <= target_bytes: + return result, len(blob) + + candidates = [] + for name, info in meta.items(): + if isinstance(info, dict) and info.get("type") in ("int6", "int5"): + q = result[name + ".q"] + s = result[name + ".scale"] + for row in range(q.shape[0]): + mask = (q[row].abs() == 1) + if mask.any(): + scale_sq = float(s[row].float() ** 2) if s.ndim > 0 else float(s.float() ** 2) + count = int(mask.sum().item()) + candidates.append((scale_sq, name, row, count)) + + candidates.sort(key=lambda x: x[0]) + + batch_size = max(1, len(candidates) // 20) + for i in range(0, len(candidates), batch_size): + batch = candidates[i:i + batch_size] + for _, name, row, _ in batch: + q = result[name + ".q"] + mask = (q[row].abs() == 1) + q[row][mask] = 0 + + buf = io.BytesIO() + torch.save({"w": result, "m": meta}, buf) + raw = buf.getvalue() + if compress == "lzma": + blob = lzma.compress(raw, preset=6) + else: + blob = zlib.compress(raw, level=9) + if len(blob) + code_bytes <= target_bytes: + return result, len(blob) + + return result, len(blob) + + +def main() -> None: + + try: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + args.validate() + + # ── Distributed + CUDA ── + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + logger.info(f"Log file: {logfile}") + + def log0(msg: str, 
console: bool = True) -> None: + if not master_process: + return + if console: + logger.info(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + + # ── Tokenizer + Validation ── + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + + # ── Model ── + base_model = YoctoGPT( + vocab_size=args.vocab_size, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_unique_layers=args.num_unique_layers, + num_recurrences=args.num_recurrences, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + seeking_gain_init=args.seeking_gain_init, + rope_fraction=args.rope_fraction, + ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + int5_layers=args.int5_layers, + ).to(device).bfloat16() + + base_model.unified_bank.data = base_model.unified_bank.data.float() + base_model.output_bank.data = base_model.output_bank.data.float() + base_model.fc_bank.data = base_model.fc_bank.data.float() + base_model.proj_bank.data = base_model.proj_bank.data.float() + + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + + if master_process: + log_architecture(base_model, args) + + try: + _test_mod = torch.compile(lambda q, k, v: _flash_attn_func(q, k, v, causal=True), dynamic=False) + _tq = torch.randn(1, 8, 1, 48, dtype=torch.bfloat16, device=device) + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + _test_mod(_tq, _tq, _tq) + log0("torch.compile + FA3: COMPATIBLE") + compiled_model = torch.compile(base_model, dynamic=False) + model = compiled_model + except Exception as e: + log0(f"torch.compile + FA3: INCOMPATIBLE ({type(e).__name__}), running uncompiled") + model = base_model + + log0("attention_backend:fa3") + + # ── Optimizer: banks → Muon, rest → Adam/AdamW ── + matrix_params = [ + base_model.unified_bank, base_model.output_bank, + base_model.fc_bank, base_model.proj_bank, + ] + + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [p for name, p in block_named_params + if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.ve_shared is not None: + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_param_groups = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.ve_shared is not None: + 
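+        # the shared VE token table joins optimizer_tok with the tied-embedding LR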
tok_param_groups.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + optimizer_tok = torch.optim.AdamW(tok_param_groups, + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=args.muon_weight_decay) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar] + + replicated_params = [base_model.tok_emb.weight] + scalar_params + if base_model.ve_shared is not None: + replicated_params.append(base_model.ve_shared.embed.weight) + + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers.insert(1, optimizer_head) + replicated_params.append(base_model.lm_head.weight) + if base_model.bigram is not None: + bigram_params = list(base_model.bigram.parameters()) + optimizer_bigram = torch.optim.AdamW([{"params": bigram_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.muon_weight_decay, fused=True) + optimizers.append(optimizer_bigram) + replicated_params.extend(bigram_params) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} effective_depth:{args.num_effective_layers}") + if base_model.int5_layer_set: + log0(f"mixed_precision: int5_layers={sorted(base_model.int5_layer_set)} int6_layers={sorted(set(range(args.num_unique_layers)) - base_model.int5_layer_set)}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + + # ── Data loader + warmup ── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all(): + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + # ── EMA + SWA shadow weights ── + ema_state = None + swa_params = None + swa_count = 0 + if args.ema_decay > 0: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + log0(f"EMA enabled: decay={args.ema_decay}") + if args.swa_every > 0: + swa_params = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + log0(f"SWA enabled: every {args.swa_every} steps when lr_scale < {args.swa_threshold}") + + def update_ema_swa(step, lr_scale): + nonlocal swa_count + with torch.no_grad(): + if ema_state is not None: + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(args.ema_decay).add_(t.detach().float(), alpha=1.0 - args.ema_decay) + + if swa_params is not None and step > 0 and step % args.swa_every == 0: + if lr_scale < args.swa_threshold: + if swa_count == 0: + for name, t in base_model.state_dict().items(): + swa_params[name].copy_(t.detach().cpu()) + swa_count = 1 + log0(f"SWA started at step {step} (lr_scale={lr_scale:.4f})") + else: + for name, t in base_model.state_dict().items(): + swa_params[name] += t.detach().cpu() + swa_count += 1 + + def get_best_weights(): + 
"""Return best averaged weights. EMA preferred (per PR#401).""" + if ema_state is not None: + log0(f"Using EMA weights (decay={args.ema_decay})") + current_state = base_model.state_dict() + return {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + if swa_params is not None and swa_count >= 2: + log0(f"Using SWA weights ({swa_count} checkpoints)") + current_state = base_model.state_dict() + return {name: (t / swa_count).to(dtype=current_state[name].dtype) + for name, t in swa_params.items()} + return None + + def lr_mul(step, elapsed_ms): + if args.lr_warmup_steps > 0 and step < args.lr_warmup_steps: + return (step + 1) / args.lr_warmup_steps + + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad_all() + for ms in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + wl = model(x, y) + (wl * grad_scale).backward() + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if ws + 1 == args.warmup_steps or (ws + 1) % 10 == 0: + log0(f"warmup_step:{ws+1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ── + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # ── QAT activation check ── + global _qat_active, _qat_bits + if args.qat_bits > 0 and not _qat_active: + if 
max_wallclock_ms is not None and max_wallclock_ms > 0: + frac = elapsed_ms / max_wallclock_ms + else: + frac = step / max(args.iterations, 1) + if frac >= args.qat_start_fraction: + _qat_active = True + _qat_bits = args.qat_bits + log0(f"QAT enabled: int{args.qat_bits} at step {step} (fraction={frac:.2f})") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + + optimizer_muon.launch_reduce_scatters() + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + if opt is not optimizer_muon: + opt.step() + optimizer_muon.step() + + update_ema_swa(step, scale) + zero_grad_all() + + step += 1 + approx_time = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_time:.0f}ms step_avg:{approx_time / step:.2f}ms") + + reached_cap = max_wallclock_ms is not None and approx_time >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + rc = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(rc, op=dist.ReduceOp.MAX) + reached_cap = bool(rc.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + # ── Load best averaged weights (EMA > SWA > raw) ── + best_weights = get_best_weights() + if best_weights is not None: + base_model.load_state_dict(best_weights, strict=True) + + # ── Serialization ── + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Raw model: {model_bytes} bytes, code: {code_bytes} bytes") + + # ── Mixed int6/int8 quantization + roundtrip (if QAT was used) ── + if args.qat_bits == 6: + if master_process: + base_model.load_state_dict(torch.load("final_model.pt", map_location="cpu"), strict=True) + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_unique_layers) + int5_set = set(int(x) for x in args.int5_layers.split(",") if x.strip()) + mixed_result, mixed_meta = quantize_state_dict_mixed(unbanked_sd, int5_layers=int5_set) + code_bytes = len(code.encode("utf-8")) + mixed_result, _ = prune_to_fit(mixed_result, mixed_meta, code_bytes, + target_bytes=16_000_000, compress=args.compression) + mixed_buf = io.BytesIO() + torch.save({"w": mixed_result, "m": mixed_meta}, mixed_buf) + mixed_raw = mixed_buf.getvalue() + if args.compression == "lzma": + mixed_blob = 
lzma.compress(mixed_raw, preset=6) + mixed_label = "lzma-6" + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + mixed_blob = zstd_mod.ZstdCompressor(level=22).compress(mixed_raw) + mixed_label = "zstd-22" + except ImportError: + mixed_blob = zlib.compress(mixed_raw, level=9) + mixed_label = "zlib-9" + else: + mixed_blob = zlib.compress(mixed_raw, level=9) + mixed_label = "zlib-9" + if master_process: + with open("final_model.mixed.ptz", "wb") as f: + f.write(mixed_blob) + mixed_bytes = os.path.getsize("final_model.mixed.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"mixed_int6_int8+{mixed_label}: {mixed_bytes} bytes, total: {mixed_bytes + code_bytes} bytes") + if mixed_bytes + code_bytes > 16_000_000: + logger.warning(f"OVER BUDGET: {mixed_bytes + code_bytes} > 16,000,000") + else: + log0(f"FITS: {mixed_bytes + code_bytes} <= 16,000,000") + if distributed: + dist.barrier() + with open("final_model.mixed.ptz", "rb") as f: + mixed_qblob = f.read() + if args.compression == "lzma": + mixed_decompressed = lzma.decompress(mixed_qblob) + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + mixed_decompressed = zstd_mod.ZstdDecompressor().decompress(mixed_qblob) + except ImportError: + mixed_decompressed = zlib.decompress(mixed_qblob) + else: + mixed_decompressed = zlib.decompress(mixed_qblob) + quant_state = torch.load(io.BytesIO(mixed_decompressed), map_location="cpu") + deq_unbanked = dequantize_state_dict_mixed(quant_state["w"], quant_state["m"], unbanked_sd) + deq_sd = _rebank_state_dict(deq_unbanked, args.num_unique_layers, sd_cpu) + base_model.load_state_dict(deq_sd, strict=True) + torch.cuda.synchronize() + qm_val_loss, qm_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_mixed_{mixed_label}_roundtrip val_loss:{qm_val_loss:.4f} val_bpb:{qm_val_bpb:.4f}") + log0(f"final_mixed_{mixed_label}_roundtrip_exact val_loss:{qm_val_loss:.8f} val_bpb:{qm_val_bpb:.8f}") + + # ── Legal Score-First TTT eval ── + if args.legal_ttt_enabled: + best_weights = get_best_weights() + if best_weights is not None: + base_model.load_state_dict(best_weights, strict=True) + torch.cuda.synchronize() + t_legal = time.perf_counter() + legal_loss, legal_bpb = eval_val_legal_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log0=log0) + log0(f"final_legal_ttt val_loss:{legal_loss:.4f} val_bpb:{legal_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_legal):.0f}ms") + + if distributed: + dist.destroy_process_group() + + except Exception: + logger.error(f"FATAL ERROR:\n{traceback.format_exc()}") + raise + +if __name__ == "__main__": + main() +==================================================================================================== +torch.compile + FA3: COMPATIBLE +attention_backend:fa3 +model_params:23209295 effective_depth:11 +world_size:8 grad_accum_steps:1 +EMA enabled: decay=0.997 +SWA enabled: every 50 steps when lr_scale < 0.2 +warmup_step:10/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9321 val_bpb:4.1056 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9330 train_time:89ms step_avg:88.58ms +step:2/20000 train_loss:6.8946 train_time:104ms step_avg:52.18ms +step:3/20000 train_loss:6.7843 train_time:151ms step_avg:50.23ms +step:4/20000 train_loss:6.5824 train_time:197ms step_avg:49.35ms +step:5/20000 train_loss:6.2987 train_time:246ms 
step_avg:49.11ms +step:6/20000 train_loss:6.1347 train_time:292ms step_avg:48.70ms +step:7/20000 train_loss:5.8964 train_time:339ms step_avg:48.49ms +step:8/20000 train_loss:5.8436 train_time:387ms step_avg:48.35ms +step:9/20000 train_loss:5.7850 train_time:434ms step_avg:48.24ms +step:10/20000 train_loss:5.7466 train_time:482ms step_avg:48.16ms +step:500/20000 train_loss:2.4800 train_time:24113ms step_avg:48.23ms +step:1000/20000 train_loss:2.3689 train_time:48367ms step_avg:48.37ms +step:1500/20000 train_loss:2.2344 train_time:72778ms step_avg:48.52ms +QAT enabled: int6 at step 1852 (fraction=0.15) +step:2000/20000 train_loss:2.2053 train_time:105581ms step_avg:52.79ms +step:2500/20000 train_loss:2.2021 train_time:129951ms step_avg:51.98ms +step:3000/20000 train_loss:3.1688 train_time:154334ms step_avg:51.44ms +step:3000/20000 val_loss:2.1694 val_bpb:1.2848 train_time:154369ms step_avg:51.46ms +step:3500/20000 train_loss:2.3468 train_time:178777ms step_avg:51.08ms +step:4000/20000 train_loss:2.2508 train_time:203171ms step_avg:50.79ms +step:4500/20000 train_loss:1.8515 train_time:227602ms step_avg:50.58ms +step:5000/20000 train_loss:2.2155 train_time:252054ms step_avg:50.41ms +step:5500/20000 train_loss:2.2023 train_time:276469ms step_avg:50.27ms +step:6000/20000 train_loss:2.0951 train_time:300917ms step_avg:50.15ms +step:6000/20000 val_loss:2.1248 val_bpb:1.2584 train_time:300953ms step_avg:50.16ms +step:6500/20000 train_loss:2.0885 train_time:325370ms step_avg:50.06ms +step:7000/20000 train_loss:2.0812 train_time:349779ms step_avg:49.97ms +step:7500/20000 train_loss:2.0987 train_time:374236ms step_avg:49.90ms +step:8000/20000 train_loss:2.0449 train_time:398650ms step_avg:49.83ms +step:8500/20000 train_loss:2.2019 train_time:423118ms step_avg:49.78ms +step:9000/20000 train_loss:2.0883 train_time:447575ms step_avg:49.73ms +step:9000/20000 val_loss:2.0962 val_bpb:1.2415 train_time:447610ms step_avg:49.73ms +step:9500/20000 train_loss:2.0107 train_time:471976ms step_avg:49.68ms +step:10000/20000 train_loss:1.9743 train_time:496428ms step_avg:49.64ms +step:10500/20000 train_loss:1.9789 train_time:520872ms step_avg:49.61ms +step:11000/20000 train_loss:2.0102 train_time:545276ms step_avg:49.57ms +SWA started at step 11450 (lr_scale=0.1884) +step:11500/20000 train_loss:1.8758 train_time:569852ms step_avg:49.55ms +step:12000/20000 train_loss:1.9428 train_time:594795ms step_avg:49.57ms +step:12000/20000 val_loss:1.9664 val_bpb:1.1646 train_time:594831ms step_avg:49.57ms +step:12103/20000 val_loss:1.9638 val_bpb:1.1631 train_time:600040ms step_avg:49.58ms +stopping_early: wallclock_cap train_time:600040ms step:12103/20000 +peak memory: 12569 MiB +Using EMA weights (decay=0.997) +Raw model: 91514419 bytes, code: 75347 bytes +mixed_int6_int8+lzma-6: 15887168 bytes, total: 15962515 bytes +FITS: 15962515 <= 16,000,000 +final_mixed_lzma-6_roundtrip val_loss:1.9645 val_bpb:1.1635 +final_mixed_lzma-6_roundtrip_exact val_loss:1.96449328 val_bpb:1.16348357 +Using EMA weights (decay=0.997) +legal_ttt:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 lr=0.002 epochs=3 freeze_blocks=0 +legal_ttt:params unfrozen=23209295 frozen=0 + legal_ttt_chunk [1/1893] bpb=1.171454 time=0.4s + legal_ttt_chunk [11/1893] bpb=1.162058 time=2.5s + legal_ttt_chunk [21/1893] bpb=1.148567 time=4.7s + legal_ttt_chunk [31/1893] bpb=1.147828 time=6.8s + legal_ttt_chunk [41/1893] bpb=1.134481 time=9.0s + legal_ttt_chunk [51/1893] bpb=1.128362 time=11.1s + legal_ttt_chunk [61/1893] bpb=1.135281 time=13.3s + 
legal_ttt_chunk [71/1893] bpb=1.134624 time=15.4s + legal_ttt_chunk [81/1893] bpb=1.134320 time=17.6s + legal_ttt_chunk [91/1893] bpb=1.135335 time=19.7s + legal_ttt_chunk [101/1893] bpb=1.139056 time=21.9s + legal_ttt_chunk [111/1893] bpb=1.141180 time=24.1s + legal_ttt_chunk [121/1893] bpb=1.134581 time=26.2s + legal_ttt_chunk [131/1893] bpb=1.134792 time=28.4s + legal_ttt_chunk [141/1893] bpb=1.140411 time=30.6s + legal_ttt_chunk [151/1893] bpb=1.142447 time=32.8s + legal_ttt_chunk [161/1893] bpb=1.142210 time=34.9s + legal_ttt_chunk [171/1893] bpb=1.146717 time=37.1s + legal_ttt_chunk [181/1893] bpb=1.148995 time=39.2s + legal_ttt_chunk [191/1893] bpb=1.156239 time=41.4s + legal_ttt_chunk [201/1893] bpb=1.155479 time=43.5s + legal_ttt_chunk [211/1893] bpb=1.153247 time=45.7s + legal_ttt_chunk [221/1893] bpb=1.154722 time=47.8s + legal_ttt_chunk [231/1893] bpb=1.153329 time=50.0s + legal_ttt_chunk [241/1893] bpb=1.153828 time=52.1s + legal_ttt_chunk [251/1893] bpb=1.153348 time=54.3s + legal_ttt_chunk [261/1893] bpb=1.150403 time=56.5s + legal_ttt_chunk [271/1893] bpb=1.149198 time=58.6s + legal_ttt_chunk [281/1893] bpb=1.150554 time=60.8s + legal_ttt_chunk [291/1893] bpb=1.152311 time=62.9s + legal_ttt_chunk [301/1893] bpb=1.152973 time=65.1s + legal_ttt_chunk [311/1893] bpb=1.155149 time=67.2s + legal_ttt_chunk [321/1893] bpb=1.157076 time=69.4s + legal_ttt_chunk [331/1893] bpb=1.157081 time=71.5s + legal_ttt_chunk [341/1893] bpb=1.156067 time=73.7s + legal_ttt_chunk [351/1893] bpb=1.158401 time=75.9s + legal_ttt_chunk [361/1893] bpb=1.158667 time=78.0s + legal_ttt_chunk [371/1893] bpb=1.157959 time=80.2s + legal_ttt_chunk [381/1893] bpb=1.158066 time=82.3s + legal_ttt_chunk [391/1893] bpb=1.157912 time=84.5s + legal_ttt_chunk [401/1893] bpb=1.155772 time=86.6s + legal_ttt_chunk [411/1893] bpb=1.154661 time=88.8s + legal_ttt_chunk [421/1893] bpb=1.153668 time=90.9s + legal_ttt_chunk [431/1893] bpb=1.153529 time=93.1s + legal_ttt_chunk [441/1893] bpb=1.153843 time=95.3s + legal_ttt_chunk [451/1893] bpb=1.154150 time=97.4s + legal_ttt_chunk [461/1893] bpb=1.153027 time=99.6s + legal_ttt_chunk [471/1893] bpb=1.153693 time=101.7s + legal_ttt_chunk [481/1893] bpb=1.153365 time=103.9s + legal_ttt_chunk [491/1893] bpb=1.152338 time=106.0s + legal_ttt_chunk [501/1893] bpb=1.151912 time=108.2s + legal_ttt_chunk [511/1893] bpb=1.151251 time=110.3s + legal_ttt_chunk [521/1893] bpb=1.149107 time=112.5s + legal_ttt_chunk [531/1893] bpb=1.150281 time=114.6s + legal_ttt_chunk [541/1893] bpb=1.150604 time=116.8s + legal_ttt_chunk [551/1893] bpb=1.149519 time=119.0s + legal_ttt_chunk [561/1893] bpb=1.149960 time=121.1s + legal_ttt_chunk [571/1893] bpb=1.148886 time=123.3s + legal_ttt_chunk [581/1893] bpb=1.148101 time=125.4s + legal_ttt_chunk [591/1893] bpb=1.147435 time=127.6s + legal_ttt_chunk [601/1893] bpb=1.147941 time=129.7s + legal_ttt_chunk [611/1893] bpb=1.147859 time=131.9s + legal_ttt_chunk [621/1893] bpb=1.147708 time=134.1s + legal_ttt_chunk [631/1893] bpb=1.148390 time=136.2s + legal_ttt_chunk [641/1893] bpb=1.148194 time=138.4s + legal_ttt_chunk [651/1893] bpb=1.148359 time=140.5s + legal_ttt_chunk [661/1893] bpb=1.147884 time=142.7s + legal_ttt_chunk [671/1893] bpb=1.148237 time=144.8s + legal_ttt_chunk [681/1893] bpb=1.148893 time=147.0s + legal_ttt_chunk [691/1893] bpb=1.149910 time=149.1s + legal_ttt_chunk [701/1893] bpb=1.149373 time=151.3s + legal_ttt_chunk [711/1893] bpb=1.149388 time=153.5s + legal_ttt_chunk [721/1893] bpb=1.149044 time=155.6s + legal_ttt_chunk [731/1893] 
bpb=1.149104 time=157.8s + legal_ttt_chunk [741/1893] bpb=1.149175 time=159.9s + legal_ttt_chunk [751/1893] bpb=1.149046 time=162.1s + legal_ttt_chunk [761/1893] bpb=1.148932 time=164.2s + legal_ttt_chunk [771/1893] bpb=1.148640 time=166.4s + legal_ttt_chunk [781/1893] bpb=1.149410 time=168.6s + legal_ttt_chunk [791/1893] bpb=1.149039 time=170.7s + legal_ttt_chunk [801/1893] bpb=1.149322 time=172.9s + legal_ttt_chunk [811/1893] bpb=1.149079 time=175.0s + legal_ttt_chunk [821/1893] bpb=1.148844 time=177.2s + legal_ttt_chunk [831/1893] bpb=1.148671 time=179.3s + legal_ttt_chunk [841/1893] bpb=1.148026 time=181.5s + legal_ttt_chunk [851/1893] bpb=1.147778 time=183.6s + legal_ttt_chunk [861/1893] bpb=1.147547 time=185.8s + legal_ttt_chunk [871/1893] bpb=1.147792 time=187.9s + legal_ttt_chunk [881/1893] bpb=1.147966 time=190.1s + legal_ttt_chunk [891/1893] bpb=1.147577 time=192.3s + legal_ttt_chunk [901/1893] bpb=1.147320 time=194.4s + legal_ttt_chunk [911/1893] bpb=1.147431 time=196.6s + legal_ttt_chunk [921/1893] bpb=1.147910 time=198.7s + legal_ttt_chunk [931/1893] bpb=1.147911 time=200.9s + legal_ttt_chunk [941/1893] bpb=1.147595 time=203.0s + legal_ttt_chunk [951/1893] bpb=1.147992 time=205.2s + legal_ttt_chunk [961/1893] bpb=1.148035 time=207.4s + legal_ttt_chunk [971/1893] bpb=1.148906 time=209.5s + legal_ttt_chunk [981/1893] bpb=1.148964 time=211.7s + legal_ttt_chunk [991/1893] bpb=1.148994 time=213.8s + legal_ttt_chunk [1001/1893] bpb=1.148963 time=216.0s + legal_ttt_chunk [1011/1893] bpb=1.148746 time=218.1s + legal_ttt_chunk [1021/1893] bpb=1.149079 time=220.3s + legal_ttt_chunk [1031/1893] bpb=1.149547 time=222.4s + legal_ttt_chunk [1041/1893] bpb=1.149201 time=224.6s + legal_ttt_chunk [1051/1893] bpb=1.148909 time=226.8s + legal_ttt_chunk [1061/1893] bpb=1.148926 time=228.9s + legal_ttt_chunk [1071/1893] bpb=1.149531 time=231.1s + legal_ttt_chunk [1081/1893] bpb=1.149768 time=233.2s + legal_ttt_chunk [1091/1893] bpb=1.150534 time=235.4s + legal_ttt_chunk [1101/1893] bpb=1.150532 time=237.5s + legal_ttt_chunk [1111/1893] bpb=1.150390 time=239.7s + legal_ttt_chunk [1121/1893] bpb=1.150201 time=241.8s + legal_ttt_chunk [1131/1893] bpb=1.150117 time=244.0s + legal_ttt_chunk [1141/1893] bpb=1.149802 time=246.1s + legal_ttt_chunk [1151/1893] bpb=1.149843 time=248.3s + legal_ttt_chunk [1161/1893] bpb=1.149504 time=250.5s + legal_ttt_chunk [1171/1893] bpb=1.149815 time=252.6s + legal_ttt_chunk [1181/1893] bpb=1.149044 time=254.8s + legal_ttt_chunk [1191/1893] bpb=1.148923 time=256.9s + legal_ttt_chunk [1201/1893] bpb=1.149321 time=259.1s + legal_ttt_chunk [1211/1893] bpb=1.148859 time=261.2s + legal_ttt_chunk [1221/1893] bpb=1.148548 time=263.4s + legal_ttt_chunk [1231/1893] bpb=1.148258 time=265.5s + legal_ttt_chunk [1241/1893] bpb=1.147906 time=267.7s + legal_ttt_chunk [1251/1893] bpb=1.147319 time=269.9s + legal_ttt_chunk [1261/1893] bpb=1.147315 time=272.0s + legal_ttt_chunk [1271/1893] bpb=1.146960 time=274.2s + legal_ttt_chunk [1281/1893] bpb=1.146775 time=276.3s + legal_ttt_chunk [1291/1893] bpb=1.146535 time=278.5s + legal_ttt_chunk [1301/1893] bpb=1.145943 time=280.6s + legal_ttt_chunk [1311/1893] bpb=1.145547 time=282.8s + legal_ttt_chunk [1321/1893] bpb=1.145223 time=284.9s + legal_ttt_chunk [1331/1893] bpb=1.145150 time=287.1s + legal_ttt_chunk [1341/1893] bpb=1.145031 time=289.2s + legal_ttt_chunk [1351/1893] bpb=1.144977 time=291.4s + legal_ttt_chunk [1361/1893] bpb=1.145016 time=293.5s + legal_ttt_chunk [1371/1893] bpb=1.144884 time=295.7s + legal_ttt_chunk [1381/1893] 
bpb=1.144889 time=297.9s + legal_ttt_chunk [1391/1893] bpb=1.144502 time=300.0s + legal_ttt_chunk [1401/1893] bpb=1.144472 time=302.2s + legal_ttt_chunk [1411/1893] bpb=1.144606 time=304.3s + legal_ttt_chunk [1421/1893] bpb=1.144847 time=306.5s + legal_ttt_chunk [1431/1893] bpb=1.144557 time=308.6s + legal_ttt_chunk [1441/1893] bpb=1.145088 time=310.8s + legal_ttt_chunk [1451/1893] bpb=1.145423 time=312.9s + legal_ttt_chunk [1461/1893] bpb=1.144991 time=315.1s + legal_ttt_chunk [1471/1893] bpb=1.146016 time=317.2s + legal_ttt_chunk [1481/1893] bpb=1.145557 time=319.4s + legal_ttt_chunk [1491/1893] bpb=1.145361 time=321.5s + legal_ttt_chunk [1501/1893] bpb=1.145302 time=323.7s + legal_ttt_chunk [1511/1893] bpb=1.145320 time=325.9s + legal_ttt_chunk [1521/1893] bpb=1.145381 time=328.0s + legal_ttt_chunk [1531/1893] bpb=1.144872 time=330.2s + legal_ttt_chunk [1541/1893] bpb=1.144731 time=332.3s + legal_ttt_chunk [1551/1893] bpb=1.145022 time=334.5s + legal_ttt_chunk [1561/1893] bpb=1.145022 time=336.6s + legal_ttt_chunk [1571/1893] bpb=1.144860 time=338.8s + legal_ttt_chunk [1581/1893] bpb=1.145004 time=340.9s + legal_ttt_chunk [1591/1893] bpb=1.144852 time=343.1s + legal_ttt_chunk [1601/1893] bpb=1.145040 time=345.2s + legal_ttt_chunk [1611/1893] bpb=1.144969 time=347.4s + legal_ttt_chunk [1621/1893] bpb=1.144595 time=349.5s + legal_ttt_chunk [1631/1893] bpb=1.144901 time=351.7s + legal_ttt_chunk [1641/1893] bpb=1.144928 time=353.8s + legal_ttt_chunk [1651/1893] bpb=1.144873 time=356.0s + legal_ttt_chunk [1661/1893] bpb=1.144750 time=358.2s + legal_ttt_chunk [1671/1893] bpb=1.145223 time=360.3s + legal_ttt_chunk [1681/1893] bpb=1.145367 time=362.5s + legal_ttt_chunk [1691/1893] bpb=1.145196 time=364.6s + legal_ttt_chunk [1701/1893] bpb=1.145322 time=366.8s + legal_ttt_chunk [1711/1893] bpb=1.145313 time=368.9s + legal_ttt_chunk [1721/1893] bpb=1.145291 time=371.1s + legal_ttt_chunk [1731/1893] bpb=1.145162 time=373.2s + legal_ttt_chunk [1741/1893] bpb=1.144965 time=375.4s + legal_ttt_chunk [1751/1893] bpb=1.144785 time=377.5s + legal_ttt_chunk [1761/1893] bpb=1.144925 time=379.7s + legal_ttt_chunk [1771/1893] bpb=1.144826 time=381.9s + legal_ttt_chunk [1781/1893] bpb=1.144846 time=384.0s + legal_ttt_chunk [1791/1893] bpb=1.144440 time=386.2s + legal_ttt_chunk [1801/1893] bpb=1.144312 time=388.3s + legal_ttt_chunk [1811/1893] bpb=1.144223 time=390.5s + legal_ttt_chunk [1821/1893] bpb=1.144287 time=392.6s + legal_ttt_chunk [1831/1893] bpb=1.143693 time=394.8s + legal_ttt_chunk [1841/1893] bpb=1.143803 time=396.9s + legal_ttt_chunk [1851/1893] bpb=1.143596 time=399.1s + legal_ttt_chunk [1861/1893] bpb=1.143223 time=401.2s + legal_ttt_chunk [1871/1893] bpb=1.143195 time=403.4s + legal_ttt_chunk [1881/1893] bpb=1.142755 time=405.5s + legal_ttt_chunk [1891/1893] bpb=1.142514 time=407.7s + legal_ttt_chunk [1893/1893] bpb=1.142561 time=408.0s +legal_ttt:done val_loss=1.925313 val_bpb=1.140282 elapsed=408.0s +final_legal_ttt val_loss:1.9253 val_bpb:1.1403 eval_time:408444ms diff --git a/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed42.log b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed42.log new file mode 100644 index 0000000000..c5b9664be7 --- /dev/null +++ b/records/track_10min_16mb/2026-03-31_UnifiedAttention_FA3_LegalTTT/train_seed42.log @@ -0,0 +1,2760 @@ +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import traceback +import 
uuid +import zlib +import lzma +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_func +except ImportError: + raise ImportError( + "Flash Attention 3 (Hopper) is required. Install with:\n" + " pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch280\n" + "Or see requirements.txt for details." + ) + +# ───────────────────────────────────────────────────────────── +# LOGGING UTILITIES +# ───────────────────────────────────────────────────────────── + +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)-5s | %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("yocto-golf") + + +def log_architecture(model: nn.Module, args: "Hyperparameters") -> None: + """Log full architecture details for debugging.""" + n_params = sum(p.numel() for p in model.parameters()) + n_unique = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info("=" * 60) + logger.info("YOCTO PARAMETER GOLF — ARCHITECTURE") + logger.info("=" * 60) + logger.info(f" embed_dim (d): {args.model_dim}") + logger.info(f" d/3 (component_dim): {args.model_dim // 3}") + logger.info(f" num_heads: {args.num_heads}") + logger.info(f" head_dim: {args.model_dim // 3 // args.num_heads}") + rope_dim = int(args.model_dim // 3 // args.num_heads * args.rope_fraction) + rope_dim = max(rope_dim - (rope_dim % 2), 2) + logger.info(f" RoPE fraction: {args.rope_fraction} ({rope_dim}/{args.model_dim // 3 // args.num_heads} dims)") + logger.info(f" unique_layers (K): {args.num_unique_layers}") + logger.info(f" recurrences (R): {args.num_recurrences}") + logger.info(f" effective_depth: {args.num_unique_layers * args.num_recurrences}") + logger.info(f" FFN type: LeakyReLU(0.5)²") + logger.info(f" FFN hidden_dim: {args.mlp_mult * args.model_dim}") + logger.info(f" vocab_size: {args.vocab_size}") + logger.info(f" tie_embeddings: {args.tie_embeddings}") + logger.info(f" total_params: {n_params:,}") + logger.info(f" trainable_params: {n_unique:,}") + logger.info("=" * 60) + + # Per-layer breakdown + for name, mod in model.named_modules(): + if isinstance(mod, UnifiedAttention): + attn_params = sum(p.numel() for p in mod.parameters()) + logger.info(f" UnifiedAttention params: {attn_params:,}") + break + for name, mod in model.named_modules(): + if isinstance(mod, SquaredReLUMLP): + mlp_params = sum(p.numel() for p in mod.parameters()) + logger.info(f" MLP params: {mlp_params:,}") + break + + +# ───────────────────────────────────────────────────────────── +# HYPERPARAMETERS +# ───────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = 
int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # ── Yocto architecture ── + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 552)) + num_heads = int(os.environ.get("NUM_HEADS", 4)) + num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 10)) + num_recurrences = int(os.environ.get("NUM_RECURRENCES", 1)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + use_swiglu = bool(int(os.environ.get("USE_SWIGLU", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + seeking_gain_init = float(os.environ.get("SEEKING_GAIN_INIT", 1.5)) + rope_fraction = float(os.environ.get("ROPE_FRACTION", 1.0)) # 1.0 = full RoPE, 0.5 = half partial RoPE + + # ── Optimizer ── + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + + # ── LR warmup (actual learning rate ramp, separate from compile warmup) ── + lr_warmup_steps = int(os.environ.get("LR_WARMUP_STEPS", 100)) + + # ── EMA ── + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) # 0 = disabled, 0.997 = SOTA setting + + # ── SWA (Stochastic Weight Averaging) ── + swa_every = int(os.environ.get("SWA_EVERY", 50)) # 0 = disabled, 50 = SOTA setting + swa_threshold = float(os.environ.get("SWA_THRESHOLD", 0.2)) # only SWA when lr_scale < this + + # ── Compression ── + compression = os.environ.get("COMPRESSION", "lzma") # "zlib", "zstd", or "lzma" + + # ── QAT (Quantization-Aware Training) ── + qat_bits = int(os.environ.get("QAT_BITS", 6)) # 0 = disabled, 6 = int6 QAT + qat_start_fraction = float(os.environ.get("QAT_START_FRACTION", 0.15)) # when to start QAT + + # ── Mixed precision quantization ── + # Layers listed here use int5 (max_val=15); all others use int6 (max_val=31) + # For K=11: middle layers 2-8 are least sensitive, first/last/VE layers stay int6 + int5_layers = os.environ.get("INT5_LAYERS", "") # e.g. "2,3,4,5,6,7,8" + + # ── BigramHash ── + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 0)) # 0 = disabled, e.g. 
3072 + bigram_dim = int(os.environ.get("BIGRAM_DIM", 112)) # embedding dim per bigram entry + + # ── Sliding window eval ── + sliding_window_eval = bool(int(os.environ.get("SLIDING_WINDOW_EVAL", "1"))) + sliding_window_stride = int(os.environ.get("SLIDING_WINDOW_STRIDE", 64)) + sliding_window_seq_len = int(os.environ.get("SLIDING_WINDOW_SEQ_LEN", 0)) # 0 = use train_seq_len + + # ── LN Scale ── + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) # 1/sqrt(layer_idx+1) on norm outputs + + # ── Value Embedding (VE128) ── + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "8,9") # last 2 of 10 layers + + # ── TTT LoRA ── + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_lora_attn = bool(int(os.environ.get("TTT_LORA_ATTN", "1"))) # 0 = LM head only + + # ── Legal Score-First TTT ── + legal_ttt_enabled = bool(int(os.environ.get("LEGAL_TTT_ENABLED", "1"))) + legal_ttt_lr = float(os.environ.get("LEGAL_TTT_LR", 0.002)) + legal_ttt_epochs = int(os.environ.get("LEGAL_TTT_EPOCHS", 3)) + legal_ttt_chunk_tokens = int(os.environ.get("LEGAL_TTT_CHUNK_TOKENS", 32768)) + legal_ttt_freeze_blocks = int(os.environ.get("LEGAL_TTT_FREEZE_BLOCKS", 0)) + legal_ttt_momentum = float(os.environ.get("LEGAL_TTT_MOMENTUM", 0.9)) + legal_ttt_batch_seqs = int(os.environ.get("LEGAL_TTT_BATCH_SEQS", 32)) + legal_ttt_grad_clip = float(os.environ.get("LEGAL_TTT_GRAD_CLIP", 1.0)) + + @property + def num_effective_layers(self) -> int: + return self.num_unique_layers * self.num_recurrences + + def validate(self) -> None: + """Check all divisibility constraints.""" + d = self.model_dim + assert d % 3 == 0, f"model_dim={d} must be divisible by 3 for unified attention split" + comp = d // 3 + assert comp % self.num_heads == 0, ( + f"component_dim={comp} (model_dim/3) must be divisible by num_heads={self.num_heads}" + ) + head_dim = comp // self.num_heads + assert head_dim % 2 == 0, f"head_dim={head_dim} must be even for RoPE" + assert head_dim >= 16, f"head_dim={head_dim} must be >= 16 for useful RoPE (got {head_dim})" + assert self.logit_softcap > 0, f"logit_softcap must be positive" + logger.info(f"Architecture constraints validated: d={d}, comp={comp}, heads={self.num_heads}, " + f"head_dim={head_dim}, RoPE_pairs={head_dim//2}") + + +# ───────────────────────────────────────────────────────────── +# MUON OPTIMIZER (from baseline, unchanged) +# ───────────────────────────────────────────────────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +class Muon(torch.optim.Optimizer): + """Parallel Muon: batched Newton-Schulz on 3D parameter banks. 
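+    Parameters arrive as 3D banks (one [K, rows, cols] tensor per weight
+    type), so Newton-Schulz runs batched over the bank dimension instead of
+    once per 2D matrix.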
+ + Phase 1: launch_reduce_scatters() — async reduce-scatter for all banks + Phase 2: (caller does Adam on small params while RS is in-flight) + Phase 3: step() — wait for RS, local NS5, async all-gather + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending — launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, batched NS5, all-gather.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + # Wait for previous all-gather and apply update + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + # Get gradient (sharded or full) + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + # Momentum + Nesterov + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + # Batched Newton-Schulz on 3D tensor + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle 
= dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + # Apply last all-gather + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + + +# ───────────────────────────────────────────────────────────── +# BPB EVALUATION (from baseline, unchanged) +# ───────────────────────────────────────────────────────────── + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split too short for seq_len={seq_len}") + return tokens[: usable + 1] + + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * 
batch_token_count + val_token_count += batch_token_count + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window(args, model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + """Sliding window evaluation: overlapping windows, score only the stride region.""" + sw_seq_len = args.sliding_window_seq_len if args.sliding_window_seq_len > 0 else args.train_seq_len + stride = args.sliding_window_stride + total_tokens = val_tokens.numel() - 1 + + # Distribute token range across ranks + rank_start = (total_tokens * rank) // world_size + rank_end = (total_tokens * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Get underlying model (unwrap DDP/compile) + base = model.module if hasattr(model, 'module') else model + if hasattr(base, '_orig_mod'): + base = base._orig_mod # unwrap torch.compile + + base.eval() + with torch.inference_mode(): + pos = rank_start + while pos < rank_end: + # Context window ends at pos + stride, starts up to sw_seq_len before that + score_end_global = min(pos + stride, rank_end) + ctx_start_global = max(0, score_end_global - sw_seq_len) + ctx_len = score_end_global - ctx_start_global + + # Where in this window do we start scoring? 
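+            # E.g. sw_seq_len=2048, stride=64, pos=10000: score_end_global=10064,
+            # ctx_start_global=8016, so a full 2048-token window is run but loss
+            # is kept only for the final 64 positions (score_start_local=1984).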
+ score_start_local = pos - ctx_start_global + score_end_local = score_end_global - ctx_start_global + score_len = score_end_local - score_start_local + + if score_len <= 0 or ctx_len < 1: + pos += stride + continue + + # Load tokens: need ctx_len + 1 for (input, target) pairs + chunk = val_tokens[ctx_start_global: ctx_start_global + ctx_len + 1].to( + device=device, dtype=torch.int64, non_blocking=True) + if chunk.numel() < 2: + pos += stride + continue + + x = chunk[:-1].unsqueeze(0) # [1, ctx_len] + y = chunk[1:].unsqueeze(0) # [1, ctx_len] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base.forward_logits(x) + + # Per-token cross entropy + per_token_loss = F.cross_entropy( + logits.float().reshape(-1, logits.size(-1)), + y.reshape(-1), + reduction="none" + ) # [ctx_len] + + # Score only the stride region + scored_losses = per_token_loss[score_start_local:score_end_local] + scored_x = x[0, score_start_local:score_end_local] + scored_y = y[0, score_start_local:score_end_local] + + val_loss_sum += scored_losses.to(torch.float64).sum() + val_token_count += score_len + + tok_bytes = base_bytes_lut[scored_y].to(torch.float64) + tok_bytes += (has_leading_space_lut[scored_y] & ~is_boundary_token_lut[scored_x]).to(torch.float64) + val_byte_count += tok_bytes.sum() + + pos += stride + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ───────────────────────────────────────────────────────────── +# QUANTIZATION (from baseline, unchanged) +# ───────────────────────────────────────────────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = ("attn_scale", "mlp_scale", "resid_mix", "skip_weight", "seeking_gain", "smear", "ve_layer_scales", "ve_shared.scale") +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = CONTROL_TENSOR_NAME_PATTERNS +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_Q = 99.99984 / 100.0 + + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def quantize_float_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def quantize_state_dict_int8(state_dict): + quantized, scales, dtypes, passthrough = {}, {}, {}, {} + passthrough_orig_dtypes, qmeta = {}, {} + stats = dict.fromkeys(("param_count", "num_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), 0) + for name, 
tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + if any(p in name for p in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + kept = t.float().contiguous() + elif t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + kept = t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + else: + kept = t + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj = {"__quant_format__": "int8_clean_per_row_v1", "quantized": quantized, "scales": scales, + "dtypes": dtypes, "passthrough": passthrough} + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int8(obj): + out = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ───────────────────────────────────────────────────────────── +# INT6 QUANTIZATION — with GPTQ-lite optimal clip search +# ───────────────────────────────────────────────────────────── + +GPTQ_CLIP_PERCENTILES = [0.999, 0.9995, 0.9999, 0.99999, 1.0] + + +def quantize_float_tensor_int6(t: Tensor): + """Per-row int6 quantization with GPTQ-lite clip search. + Tries multiple clip percentiles and picks the one minimizing + GLOBAL reconstruction MSE (same percentile for all rows). + Range [-31, 31], stored in int8. + """ + return quantize_float_tensor_intN(t, max_val=31) + + +def quantize_float_tensor_intN(t: Tensor, max_val: int = 31): + """Per-row intN quantization with GPTQ-lite clip search. + max_val=31 for int6, max_val=15 for int5, max_val=7 for int4. 
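+    E.g. for int6 (max_val=31), a row whose chosen clip value is 0.62 gets
+    scale = 0.62 / 31 = 0.02, and entries round to integers in [-31, 31]
+    (stored in int8 until bit-packing).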
+ """ + t32 = t.float() + if t32.ndim == 2: + best_q, best_scale, best_err = None, None, float('inf') + + for pct in GPTQ_CLIP_PERCENTILES: + if pct >= 1.0: + clip_abs = t32.abs().amax(dim=1).clamp_min(1e-8) + else: + clip_abs = torch.quantile(t32.abs(), pct, dim=1).clamp_min(1e-8) + scale = (clip_abs / max_val).clamp_min(1e-8).to(torch.float16) + clipped = t32.clamp(-clip_abs[:, None], clip_abs[:, None]) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), -max_val, max_val).to(torch.int8) + # Global MSE (single percentile for entire matrix) + recon = q.float() * scale.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_scale, best_err = q, scale, err + + return best_q.contiguous(), best_scale.contiguous() + # 1D fallback + abs_max = t32.abs().max().clamp_min(1e-8).item() + scale = torch.tensor(abs_max / max_val, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -max_val, max_val).to(torch.int8) + return q, scale + + +# ── Unbank/rebank for quantization ── + +def _unbank_state_dict(sd, num_layers): + """Convert 3D bank tensors into individual 2D tensors for quantization. + CRITICAL: splits unified_bank into 3 separate band matrices (seeking, offering, content) + so each band gets its own optimal clip percentile during GPTQ-lite quantization.""" + out = {} + for name, tensor in sd.items(): + if name == "unified_bank": + for i in range(num_layers): + w = tensor[i] # [d, d] + d = w.shape[0] + comp = d // 3 + # Split into three functional bands for independent quantization + out[f"blocks.{i}.attn.W_seeking.weight"] = w[:comp, :] + out[f"blocks.{i}.attn.W_offering.weight"] = w[comp:2*comp, :] + out[f"blocks.{i}.attn.W_content.weight"] = w[2*comp:, :] + elif name == "output_bank": + for i in range(num_layers): + out[f"blocks.{i}.attn.W_output.weight"] = tensor[i] + elif name == "fc_bank": + for i in range(num_layers): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "proj_bank": + for i in range(num_layers): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + + +def _rebank_state_dict(sd, num_layers, template_sd): + """Convert individual 2D tensors back into 3D bank tensors. + Recombines three band matrices back into unified_bank.""" + out = {} + consumed = set() + + # Reconstruct unified_bank from three bands + unified_slices = [] + for i in range(num_layers): + sk = f"blocks.{i}.attn.W_seeking.weight" + ok = f"blocks.{i}.attn.W_offering.weight" + ck = f"blocks.{i}.attn.W_content.weight" + unified_slices.append(torch.cat([sd[sk], sd[ok], sd[ck]], dim=0)) + consumed.update([sk, ok, ck]) + out["unified_bank"] = torch.stack(unified_slices).to(dtype=template_sd["unified_bank"].dtype) + + for bank_name, key_template in [ + ("output_bank", "blocks.{i}.attn.W_output.weight"), + ("fc_bank", "blocks.{i}.mlp.fc.weight"), + ("proj_bank", "blocks.{i}.mlp.proj.weight"), + ]: + slices = [] + for i in range(num_layers): + k = key_template.format(i=i) + slices.append(sd[k]) + consumed.add(k) + out[bank_name] = torch.stack(slices).to(dtype=template_sd[bank_name].dtype) + + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + + +# Embedding name patterns — these stay at int8, everything else gets int6 +INT8_EMBED_PATTERNS = ("tok_emb.", "ve_shared.embed.") + + +def pack_int6_raw(q: Tensor) -> bytes: + """Pack int6 values (stored in int8, range [-31,31]) into 6-bit packed bytes. 
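+    Values are offset by +31 into the unsigned range [0, 62] before packing;
+    e.g. the four values [-31, 0, 5, 31] become [0, 31, 36, 62], which pack
+    into the three bytes (1, 249, 62).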
4 values → 3 bytes: [aaaaaabb|bbbbcccc|ccdddddd]"""
+    flat = q.reshape(-1).numpy().astype(np.int8)
+    # Offset to unsigned: [-31,31] -> [0,62]
+    flat_u = (flat.astype(np.int16) + 31).astype(np.uint8)
+    # Pad to multiple of 4
+    pad = (4 - len(flat_u) % 4) % 4
+    if pad > 0:
+        flat_u = np.concatenate([flat_u, np.zeros(pad, dtype=np.uint8)])
+    a, b, c, d = flat_u[0::4], flat_u[1::4], flat_u[2::4], flat_u[3::4]
+    byte0 = ((a & 0x3F) << 2) | ((b >> 4) & 0x03)
+    byte1 = ((b & 0x0F) << 4) | ((c >> 2) & 0x0F)
+    byte2 = ((c & 0x03) << 6) | (d & 0x3F)
+    packed = np.empty(len(a) * 3, dtype=np.uint8)
+    packed[0::3] = byte0
+    packed[1::3] = byte1
+    packed[2::3] = byte2
+    return packed.tobytes()
+
+
+def unpack_int6_raw(data: bytes, numel: int) -> Tensor:
+    """Unpack 6-bit packed bytes back to int8 tensor (range [-31,31])."""
+    packed = np.frombuffer(data, dtype=np.uint8)
+    a = (packed[0::3] >> 2) & 0x3F
+    b = ((packed[0::3] & 0x03) << 4) | ((packed[1::3] >> 4) & 0x0F)
+    c = ((packed[1::3] & 0x0F) << 2) | ((packed[2::3] >> 6) & 0x03)
+    d = packed[2::3] & 0x3F
+    flat = np.empty(len(a) * 4, dtype=np.uint8)
+    flat[0::4], flat[1::4], flat[2::4], flat[3::4] = a, b, c, d
+    flat_s = flat[:numel].astype(np.int16) - 31
+    return torch.tensor(flat_s.astype(np.int8), dtype=torch.int8)
+
+
+def serialize_artifact_raw(state_dict, compress="lzma"):
+    """Serialize model to raw binary format with int6 bit-packing.
+    No torch.save, no pickle, no ZIP overhead.
+
+    Format per tensor entry:
+      [2B name_len][name][1B quant_type][1B ndim][4B×ndim shape][4B data_len][data]
+      quant_type: 0=int6_packed, 1=int8, 2=fp16, 3=fp32
+      For quantized (type 0,1): [4B scale_len][scale_data] follows
+    """
+    import struct
+
+    buf = io.BytesIO()
+    # Magic + count (".scale" companions are folded into their parent entry)
+    entries = [(n, t) for n, t in state_dict.items() if not n.endswith(".scale")]
+    buf.write(struct.pack('<4sI', b'YCT1', len(entries)))
+
+    for name, tensor in entries:
+        t = tensor.detach().to("cpu").contiguous()
+        name_bytes = name.encode('utf-8')
+        buf.write(struct.pack('<H', len(name_bytes)))
+        buf.write(name_bytes)
+        if t.dtype == torch.int8:
+            # int6-packed when values fit in [-31, 31], raw int8 otherwise
+            qtype = 0 if (t.numel() == 0 or int(t.abs().max()) <= 31) else 1
+            data = pack_int6_raw(t) if qtype == 0 else t.numpy().tobytes()
+        elif t.dtype == torch.float16:
+            qtype, data = 2, t.numpy().tobytes()
+        else:
+            qtype, data = 3, t.float().numpy().tobytes()
+        buf.write(struct.pack('<BB', qtype, t.ndim))
+        buf.write(struct.pack(f'<{t.ndim}i', *t.shape))
+        buf.write(struct.pack('<I', len(data)))
+        buf.write(data)
+        if qtype in (0, 1):
+            scale = state_dict[name + ".scale"].detach().cpu().to(torch.float16).contiguous()
+            scale_data = scale.numpy().tobytes()
+            buf.write(struct.pack('<I', len(scale_data)))
+            buf.write(scale_data)
+    return buf.getvalue()
+
+
+def deserialize_artifact_raw(blob: bytes):
+    """Parse the raw format written by serialize_artifact_raw back into a
+    state dict, dequantizing int6/int8 payloads to bfloat16."""
+    import struct
+
+    buf = io.BytesIO(blob)
+    magic, count = struct.unpack('<4sI', buf.read(8))
+    assert magic == b'YCT1', f"bad magic: {magic!r}"
+    state_dict = {}
+    for _ in range(count):
+        name_len = struct.unpack('<H', buf.read(2))[0]
+        name = buf.read(name_len).decode('utf-8')
+        qtype, ndim = struct.unpack('<BB', buf.read(2))
+        shape = struct.unpack(f'<{ndim}i', buf.read(4 * ndim))
+        data_len = struct.unpack('<I', buf.read(4))[0]
+        data = buf.read(data_len)
+        if qtype == 0:  # int6_packed
+            scale_len = struct.unpack('<I', buf.read(4))[0]
+            scale = torch.frombuffer(bytearray(buf.read(scale_len)), dtype=torch.float16)
+            if scale_len == 2:  # single scalar scale
+                scale = scale.reshape(())
+            q = unpack_int6_raw(data, int(np.prod(shape))).reshape(shape)
+            if scale.ndim > 0 and len(shape) == 2:
+                state_dict[name] = (q.float() * scale.float().view(shape[0], 1)).to(torch.bfloat16).contiguous()
+            else:
+                state_dict[name] = (q.float() * scale.float()).to(torch.bfloat16).contiguous()
+        elif qtype == 1:  # int8
+            scale_len = struct.unpack('<I', buf.read(4))[0]
+            scale = torch.frombuffer(bytearray(buf.read(scale_len)), dtype=torch.float16)
+            if scale_len == 2:  # single scalar scale
+                scale = scale.reshape(())
+            q = torch.frombuffer(bytearray(data), dtype=torch.int8).reshape(shape)
+            if scale.ndim > 0 and len(shape) == 2:
+                state_dict[name] = (q.float() * scale.float().view(shape[0], 1)).to(torch.bfloat16).contiguous()
+            else:
+                state_dict[name] = (q.float() * scale.float()).to(torch.bfloat16).contiguous()
+        elif qtype == 2:  # fp16
+            state_dict[name] = torch.frombuffer(bytearray(data), dtype=torch.float16).clone().reshape(shape).contiguous()
+        elif qtype == 3:  # fp32
+            state_dict[name] = torch.frombuffer(bytearray(data), dtype=torch.float32).clone().reshape(shape).contiguous()
+
+    return state_dict
+
+
+def quantize_state_dict_mixed(state_dict, int5_layers=None):
+    """Mixed int5/int6/int8 quantization with flat key format.
+    int5_layers: set of layer indices that use int5 (max_val=15).
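+    Each quantized weight W is emitted as two flat keys, "W.q" (int8 payload)
+    and "W.scale" (fp16 per-row scales); meta[W] records the scheme used.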
+ Int8 for embeddings, int6 or int5 for block weights depending on layer.""" + if int5_layers is None: + int5_layers = set() + result = {} + meta = {} + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + if not t.is_floating_point(): + result[name] = t + meta[name] = "passthrough" + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + if any(p in name for p in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + result[name] = t.float().contiguous() + meta[name] = "passthrough_ctrl" + else: + result[name] = t.to(torch.float16).contiguous() + meta[name] = "passthrough" + continue + # Int8 for embeddings + is_embed = any(p in name for p in INT8_EMBED_PATTERNS) + if is_embed: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + else: + # Check if this weight belongs to an int5 layer + # Unbanked names: blocks.{i}.attn.W_unified.weight, blocks.{i}.mlp.fc.weight, etc. + layer_idx = -1 + if "blocks." in name: + try: + layer_idx = int(name.split("blocks.")[1].split(".")[0]) + except (ValueError, IndexError): + pass + if layer_idx in int5_layers: + q, s = quantize_float_tensor_intN(t, max_val=15) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5"} + else: + q, s = quantize_float_tensor_int6(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + return result, meta + + +def dequantize_state_dict_mixed(result, meta, template_sd=None): + """Dequantize flat-key mixed int6/int8 state dict back to float tensors.""" + out = {} + for name, info in meta.items(): + if info in ("passthrough", "passthrough_ctrl"): + t = result[name] + # Restore original dtype if template available + if template_sd is not None and name in template_sd: + orig_dtype = template_sd[name].dtype + if t.dtype != orig_dtype: + t = t.to(orig_dtype) + out[name] = t + continue + q = result[name + ".q"] + s = result[name + ".scale"] + if s.ndim > 0: + deq = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))) + else: + deq = q.float() * float(s.item()) + # Restore to original dtype + target_dtype = torch.bfloat16 + if template_sd is not None and name in template_sd: + target_dtype = template_sd[name].dtype + out[name] = deq.to(target_dtype).contiguous() + return out + + + +# ───────────────────────────────────────────────────────────── +# DATA LOADING (from baseline, unchanged) +# ───────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + header = np.fromfile(file, dtype=" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * 
self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ───────────────────────────────────────────────────────────── +# TRANSFORMER MODULES — YOCTO UNIFIED ATTENTION +# ───────────────────────────────────────────────────────────── + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._is_mlp = False # kept for compatibility + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self.training and _qat_active and w.numel() > INT8_KEEP_FLOAT_MAX_NUMEL: + w = _fake_quantize(w, _qat_bits) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +# ── QAT globals (set during training) ── +_qat_active = False +_qat_bits = 6 + + +def _fake_quantize(w: Tensor, bits: int) -> Tensor: + """ + Straight-through estimator for quantization-aware training. + Forward: returns quantized-then-dequantized weights (simulates int-N). + Backward: gradient passes through as if no quantization (STE). + Per-row symmetric quantization to match our compression scheme. + """ + max_val = (1 << (bits - 1)) - 1 # e.g. int6: max_val = 31 + with torch.no_grad(): + # Per-row scale + abs_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + scale = abs_max / max_val + # Quantize and dequantize + w_q = (w / scale).round().clamp(-max_val, max_val) * scale + # STE: use quantized values in forward, but gradient flows to original w + return w + (w_q - w).detach() + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): + if self._cos_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class SmearGate(nn.Module): + """Position-mixing gate: blends each position with its predecessor. 
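+    Output per channel is (1 - g) * x + g * x_prev with g = sigmoid(gate).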
+ Sigmoid gate initialized at 0 → starts as 50/50 blend, learns optimal per-dim.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Shared embedding table maps vocab tokens to low-dim vectors, projected to target dim.""" + def __init__(self, vocab_size: int, ve_dim: int, target_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, target_dim, bias=False) if ve_dim != target_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class UnifiedAttention(nn.Module): + """ + Yocto Unified Attention: single W_unified → split into [seeking|offering|content]. + 67% fewer attention parameters than standard Q/K/V. + + Weights are passed in from parameter banks (not owned by this module). + """ + def __init__(self, dim: int, num_heads: int, rope_base: float, + seeking_gain_init: float, rope_fraction: float = 1.0): + super().__init__() + assert dim % 3 == 0, f"dim={dim} must be divisible by 3" + self.dim = dim + self.num_heads = num_heads + self.component_dim = dim // 3 + self.head_dim = self.component_dim // num_heads + assert self.component_dim % num_heads == 0 + + # Partial RoPE: only rotate rope_dim dimensions per head + self.rope_dim = int(self.head_dim * rope_fraction) + self.rope_dim = max(self.rope_dim - (self.rope_dim % 2), 2) + self.pass_dim = self.head_dim - self.rope_dim + + # NO weight matrices here — they come from banks + # Learnable per-head gain for seeking + self.seeking_gain = nn.Parameter( + torch.full((num_heads,), seeking_gain_init, dtype=torch.float32) + ) + self.rotary = Rotary(self.rope_dim, base=rope_base) + + def forward(self, x: Tensor, unified_w: Tensor, output_w: Tensor, unified_delta=None, v_embed=None) -> Tensor: + bsz, seqlen, _ = x.shape + + # Unified projection using bank weight + unified = F.linear(x, unified_w.to(x.dtype)) + if unified_delta is not None: + unified = unified + unified_delta + + # Split into three bands: [seeking | offering | content] + seeking, offering, content = unified.split(self.component_dim, dim=-1) + + # Value embedding: reinject token identity into content band + if v_embed is not None: + content = content + v_embed + + def to_heads(t): + return t.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + + seeking = to_heads(seeking) + offering = to_heads(offering) + content = to_heads(content) + + seeking = F.rms_norm(seeking, (seeking.size(-1),)) + offering = F.rms_norm(offering, (offering.size(-1),)) + + cos, sin = self.rotary(seqlen, x.device, seeking.dtype) + if self.pass_dim > 0: + s_rope, s_pass = seeking[..., :self.rope_dim], seeking[..., self.rope_dim:] + o_rope, o_pass = offering[..., :self.rope_dim], offering[..., self.rope_dim:] + s_rope = apply_rotary_emb(s_rope, cos, sin) + o_rope = apply_rotary_emb(o_rope, cos, sin) + seeking = torch.cat([s_rope, s_pass], dim=-1) + 
offering = torch.cat([o_rope, o_pass], dim=-1)
+        else:
+            seeking = apply_rotary_emb(seeking, cos, sin)
+            offering = apply_rotary_emb(offering, cos, sin)
+
+        seeking = seeking * self.seeking_gain.to(dtype=seeking.dtype)[None, :, None, None]
+
+        # FA3 (Hopper) expects [B, T, H, D] with head_dim multiple of 8
+        sq = seeking.transpose(1, 2)
+        of = offering.transpose(1, 2)
+        ct = content.transpose(1, 2)
+        dtype = sq.dtype
+        if dtype not in (torch.float16, torch.bfloat16):
+            sq, of, ct = sq.to(torch.bfloat16), of.to(torch.bfloat16), ct.to(torch.bfloat16)
+        # Pad head_dim to multiple of 8 if needed
+        hd = sq.size(-1)
+        pad_n = (8 - hd % 8) % 8
+        if pad_n > 0:
+            sq = F.pad(sq, (0, pad_n))
+            of = F.pad(of, (0, pad_n))
+            ct = F.pad(ct, (0, pad_n))
+        out = _flash_attn_func(sq, of, ct, causal=True)
+        y = out[0] if isinstance(out, tuple) else out
+        if pad_n > 0:
+            y = y[..., :hd]
+        if y.dtype != dtype:
+            y = y.to(dtype)
+
+        # FA3 output is already [B, T, H, D]: flatten heads directly
+        y = y.contiguous().reshape(bsz, seqlen, self.component_dim)
+        return F.linear(y, output_w.to(x.dtype))
+
+
+class SquaredReLUMLP(nn.Module):
+    """LeakyReLU(0.5)² MLP — weights passed from banks."""
+    def __init__(self, dim: int, mlp_mult: int):
+        super().__init__()
+        # NO weight matrices here — they come from banks
+
+    def forward(self, x: Tensor, fc_w: Tensor, proj_w: Tensor) -> Tensor:
+        return F.linear(
+            F.leaky_relu(F.linear(x, fc_w.to(x.dtype)), negative_slope=0.5).square(),
+            proj_w.to(x.dtype)
+        )
+
+
+class Block(nn.Module):
+    """Single transformer block with unified attention + MLP. Weights from banks."""
+    def __init__(self, dim: int, num_heads: int, mlp_mult: int, rope_base: float,
+                 seeking_gain_init: float, rope_fraction: float = 1.0,
+                 layer_idx: int = 0, ln_scale: bool = False):
+        super().__init__()
+        self.attn_norm = RMSNorm()
+        self.mlp_norm = RMSNorm()
+        self.attn = UnifiedAttention(dim, num_heads, rope_base, seeking_gain_init, rope_fraction)
+        self.mlp = SquaredReLUMLP(dim, mlp_mult)
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
+
+    def forward(self, x: Tensor, x0: Tensor, unified_w: Tensor, output_w: Tensor,
+                fc_w: Tensor, proj_w: Tensor, unified_delta_fn=None, v_embed=None) -> Tensor:
+        mix = self.resid_mix.to(dtype=x.dtype)
+        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        n = self.attn_norm(x) * self.ln_scale_factor
+        ud = unified_delta_fn(n) if unified_delta_fn is not None else None
+        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(n, unified_w, output_w, ud, v_embed=v_embed)
+        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * self.ln_scale_factor, fc_w, proj_w)
+        return x
+
+
+class BigramHash(nn.Module):
+    """
+    Hash-based bigram embedding: maps (prev_token, current_token) pairs to
+    a learned embedding via hashing. Provides the model with immediate
+    predecessor context without needing attention.
+
+    hash(prev, curr) = (prev * 104729 + curr) % bigram_vocab_size
+    Output is projected from bigram_dim to model_dim.
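+
+    Worked example (illustrative values only, not this run's config):
+    with bigram_vocab_size=3072, prev=17, curr=905:
+        (17 * 104729 + 905) % 3072 = 1781298 % 3072 = 2610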
+ """ + PRIME = 104729 # Large prime for hash mixing + + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.table = nn.Embedding(bigram_vocab_size, bigram_dim) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + # Initialize small so bigram signal doesn't dominate early training + nn.init.normal_(self.table.weight, mean=0.0, std=0.01) + if self.proj is not None: + nn.init.normal_(self.proj.weight, mean=0.0, std=0.01) + + def forward(self, input_ids: Tensor) -> Tensor: + """ + input_ids: [bsz, seq_len] + Returns: [bsz, seq_len, model_dim] bigram embedding to add to token embedding + """ + bsz, seq_len = input_ids.shape + # Shift input_ids to get previous tokens (first position has no predecessor, use 0) + prev_ids = torch.zeros_like(input_ids) + prev_ids[:, 1:] = input_ids[:, :-1] + # Hash: (prev * PRIME + curr) % table_size + hash_ids = (prev_ids.long() * self.PRIME + input_ids.long()) % self.bigram_vocab_size + emb = self.table(hash_ids) # [bsz, seq, bigram_dim] + if self.proj is not None: + emb = self.proj(emb) + return emb + + +class YoctoGPT(nn.Module): + """ + Yocto GPT with unified attention, depth recurrence, and parameter banks. + K unique blocks recycled R times with U-Net skip connections. + 4 parameter banks for batched Newton-Schulz in Parallel Muon. + """ + def __init__(self, vocab_size: int, model_dim: int, num_heads: int, + num_unique_layers: int, num_recurrences: int, mlp_mult: int, + tie_embeddings: bool, tied_embed_init_std: float, + logit_softcap: float, rope_base: float, seeking_gain_init: float, + rope_fraction: float = 1.0, + bigram_vocab_size: int = 0, bigram_dim: int = 112, + ln_scale: bool = True, + ve_enabled: bool = True, ve_dim: int = 128, ve_layers: str = "8,9", + int5_layers: str = ""): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.num_unique_layers = num_unique_layers + self.num_recurrences = num_recurrences + self.int5_layer_set = set(int(x) for x in int5_layers.split(",") if x.strip()) + effective = num_unique_layers * num_recurrences + + comp_dim = model_dim // 3 + mlp_dim = mlp_mult * model_dim + + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHash(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + + # Parameter banks: contiguous 3D tensors for batched Muon + K = num_unique_layers + self.unified_bank = nn.Parameter(torch.empty(K, model_dim, model_dim)) # W_unified: d→d + self.output_bank = nn.Parameter(torch.empty(K, model_dim, comp_dim)) # W_output: comp→d (F.linear expects [out, in]) + self.fc_bank = nn.Parameter(torch.empty(K, mlp_dim, model_dim)) # MLP fc: d→mlp_dim + self.proj_bank = nn.Parameter(torch.empty(K, model_dim, mlp_dim)) # MLP proj: mlp_dim→d + + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, mlp_mult, rope_base, seeking_gain_init, rope_fraction, + layer_idx=k, ln_scale=ln_scale) + for k in range(num_unique_layers) + ]) + + # U-Net skip connections over effective depth + self.num_encoder_layers = effective // 2 + self.num_decoder_layers = effective - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Value Embedding (VE128): shared embedding + per-layer scales + 
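+        # A single ve_dim-wide table is shared by every VE layer; only the
+        # scalar per-layer gains below differ, so each extra VE layer costs
+        # one float instead of another vocab_size × ve_dim table.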
self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, comp_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights(tied_embed_init_std) + + def _init_weights(self, std: float) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=std) + K = self.num_unique_layers + proj_scale = 1.0 / math.sqrt(2 * K * self.num_recurrences) + for i in range(K): + nn.init.orthogonal_(self.unified_bank.data[i], gain=1.0) + nn.init.zeros_(self.output_bank.data[i]) + self.output_bank.data[i].mul_(proj_scale) + nn.init.orthogonal_(self.fc_bank.data[i], gain=1.0) + nn.init.zeros_(self.proj_bank.data[i]) + self.proj_bank.data[i].mul_(proj_scale) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _qat_weight(self, w: Tensor, layer_idx: int = -1) -> Tensor: + """Apply QAT fake quantization to a bank weight slice if active. + Uses int5 (bits=5) for layers in int5_layer_set, int6 otherwise.""" + if self.training and _qat_active: + bits = 5 if layer_idx in self.int5_layer_set else _qat_bits + return _fake_quantize(w, bits) + return w + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_cache['ve'] * self.ve_layer_scales[ve_idx].to(dtype=ve_cache['ve'].dtype) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + eff_layer_idx = 0 + for _r in range(self.num_recurrences): + for k in range(self.num_unique_layers): + is_encoder = eff_layer_idx < self.num_encoder_layers + + if not is_encoder and skips: + dec_idx = eff_layer_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + + ud_fn = lora.unified_loras[k] if (lora and lora.unified_loras is not None) else None + ve = self._get_ve(k, input_ids, ve_cache) + x = self.blocks[k](x, x0, + self._qat_weight(self.unified_bank[k], k), + self._qat_weight(self.output_bank[k], k), + self._qat_weight(self.fc_bank[k], k), + self._qat_weight(self.proj_bank[k], k), + ud_fn, v_embed=ve) + + if is_encoder: + skips.append(x) + + eff_layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), 
target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + eff_layer_idx = 0 + for _r in range(self.num_recurrences): + for k in range(self.num_unique_layers): + is_encoder = eff_layer_idx < self.num_encoder_layers + + if not is_encoder and skips: + dec_idx = eff_layer_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + + ve = self._get_ve(k, input_ids, ve_cache) + x = self.blocks[k](x, x0, + self.unified_bank[k], self.output_bank[k], + self.fc_bank[k], self.proj_bank[k], + v_embed=ve) + + if is_encoder: + skips.append(x) + + eff_layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + +# ───────────────────────────────────────────────────────────── +# TTT LORA — adapted for unified attention +# ───────────────────────────────────────────────────────────── + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) + self.B.zero_() + + +class BatchedTTTLoRA(nn.Module): + """ + LoRA for Yocto: adapts W_unified at each UNIQUE block + LM head. + + Structurally correct for depth recurrence: the same LoRA is applied + to a block across all its recurrences, matching how the base weights + are shared. This prevents different recurrence passes from pulling + the shared block in conflicting directions. 
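+
+    (Because the same LoRA module is reused at every recurrence, gradients
+    from all recurrence passes of block k accumulate into LoRA k, mirroring
+    how the shared bank weights receive them.)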
+ + Previous (broken): 12 LoRAs for 12 effective layers (4 blocks × 3 recurrences) + Fixed: 4 LoRAs for 4 unique blocks, reused across recurrences + + use_attn_lora=False: LM head LoRA only (no attention adaptation) + """ + def __init__(self, bsz: int, model: YoctoGPT, rank: int, use_attn_lora: bool = True): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + num_unique = model.num_unique_layers + self.num_unique_layers = num_unique + self.use_attn_lora = use_attn_lora + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + # One LoRA per UNIQUE block (not per effective layer) + if use_attn_lora: + self.unified_loras = nn.ModuleList([ + BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_unique) + ]) + else: + self.unified_loras = None + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: + continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + + +def _build_ttt_optimizer(lora, args): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + + +def _find_docs(all_tokens, include_next_bos=True): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb(ptl, x, y, batch_i, chunk_offset, chunk_len, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + loss_sum, byte_sum, token_count): + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] + tok_bytes = base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + + +def eval_val_ttt_lora(args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size, eval_seq_len = args.ttt_chunk_size, args.ttt_eval_seq_len + batch_size, lora_rank = args.ttt_batch_size, args.ttt_lora_rank + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + use_attn = getattr(args, 'ttt_lora_attn', True) + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank, use_attn_lora=use_attn).to(device) + opt = _build_ttt_optimizer(lora, args) + loss_sum = 
torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank, use_attn_lora=use_attn).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [dl - 1 for _, dl in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, chunk_offset = chunk_stats[1], chunk_stats[2] + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1].to(dtype=torch.int64, device=device) + x[b, :wl] = chunk[:-1] + y[b, :wl] = chunk[1:] + doc_info.append((co, cl)) + + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb(ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + + +# ───────────────────────────────────────────────────────────── +# LEGAL SCORE-FIRST TTT +# ───────────────────────────────────────────────────────────── + +def eval_val_legal_ttt(args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log0=print): + seq_len = args.train_seq_len + stride = args.sliding_window_stride + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.legal_ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + 
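+        # A window is owned by the chunk that contains its *scored* region,
+        # so each chunk is fully evaluated before any training touches it.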
chunk_windows[ci].append(ws) + + log0(f"legal_ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"lr={args.legal_ttt_lr} epochs={args.legal_ttt_epochs} " + f"freeze_blocks={args.legal_ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + frozen_block_ids = set(range(min(args.legal_ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"legal_ttt:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.legal_ttt_lr, momentum=args.legal_ttt_momentum) + batch_seqs = args.legal_ttt_batch_seqs + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.legal_ttt_epochs > 0: + base_model.train() + # Clear Rotary cache — inference tensors can't be used in backward + for block in base_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.legal_ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.legal_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + 
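+                        # Map this rank's local sequence index back to
+                        # absolute token offsets within the current chunk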
start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.legal_ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" legal_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"legal_ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +def prune_to_fit(result, meta, code_bytes, target_bytes=16_000_000, compress="lzma"): + """Selectively zero ±1 quantized values to fit artifact in budget.""" + buf = io.BytesIO() + torch.save({"w": result, "m": meta}, buf) + raw = buf.getvalue() + if compress == "lzma": + blob = lzma.compress(raw, preset=6) + else: + blob = zlib.compress(raw, level=9) + if len(blob) + code_bytes <= target_bytes: + return result, len(blob) + + candidates = [] + for name, info in meta.items(): + if isinstance(info, dict) and info.get("type") in ("int6", "int5"): + q = result[name + ".q"] + s = result[name + ".scale"] + for row in range(q.shape[0]): + mask = (q[row].abs() == 1) + if mask.any(): + scale_sq = float(s[row].float() ** 2) if s.ndim > 0 else float(s.float() ** 2) + count = int(mask.sum().item()) + candidates.append((scale_sq, name, row, count)) + + candidates.sort(key=lambda x: x[0]) + + batch_size = max(1, len(candidates) // 20) + for i in range(0, len(candidates), batch_size): + batch = candidates[i:i + batch_size] + for _, name, row, _ in batch: + q = result[name + ".q"] + mask = (q[row].abs() == 1) + q[row][mask] = 0 + + buf = io.BytesIO() + torch.save({"w": result, "m": meta}, buf) + raw = buf.getvalue() + if compress == "lzma": + blob = lzma.compress(raw, preset=6) + else: + blob = zlib.compress(raw, level=9) + if len(blob) + code_bytes <= target_bytes: + return result, len(blob) + + return result, len(blob) + + +# ───────────────────────────────────────────────────────────── +# TRAINING +# ───────────────────────────────────────────────────────────── + +def main() -> None: + + try: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + args.validate() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + + # ── Distributed + CUDA ── + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + 
local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + logger.info(f"Log file: {logfile}") + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + logger.info(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + + # ── Tokenizer + Validation ── + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + + # ── Model ── + base_model = YoctoGPT( + vocab_size=args.vocab_size, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_unique_layers=args.num_unique_layers, + num_recurrences=args.num_recurrences, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + seeking_gain_init=args.seeking_gain_init, + rope_fraction=args.rope_fraction, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + int5_layers=args.int5_layers, + ).to(device).bfloat16() + + # Banks stay FP32, cast to BF16 in forward + base_model.unified_bank.data = base_model.unified_bank.data.float() + base_model.output_bank.data = base_model.output_bank.data.float() + base_model.fc_bank.data = base_model.fc_bank.data.float() + base_model.proj_bank.data = base_model.proj_bank.data.float() + + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + + if master_process: + log_architecture(base_model, args) + + try: + _test_mod = torch.compile(lambda q, k, v: _flash_attn_func(q, k, v, causal=True), dynamic=False) + _tq = torch.randn(1, 8, 1, 48, dtype=torch.bfloat16, device=device) + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + _test_mod(_tq, _tq, _tq) + log0("torch.compile + FA3: COMPATIBLE") + compiled_model = torch.compile(base_model, dynamic=False) + model = compiled_model + except Exception as e: + log0(f"torch.compile + FA3: 
INCOMPATIBLE ({type(e).__name__}), running uncompiled")
+            model = base_model
+
+        log0("attention_backend:fa3")
+
+        # ── Optimizer: banks → Muon, rest → Adam/AdamW ──
+        matrix_params = [
+            base_model.unified_bank, base_model.output_bank,
+            base_model.fc_bank, base_model.proj_bank,
+        ]
+
+        block_named_params = list(base_model.blocks.named_parameters())
+        scalar_params = [p for name, p in block_named_params
+                         if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)]
+        if base_model.skip_weights.numel() > 0:
+            scalar_params.append(base_model.skip_weights)
+        # SmearGate gate parameter → scalar (AdamW)
+        scalar_params.append(base_model.smear.gate)
+        # VE per-layer scales + shared scale → scalar (AdamW)
+        if base_model.ve_shared is not None:
+            scalar_params.append(base_model.ve_shared.scale)
+            for s in base_model.ve_layer_scales:
+                scalar_params.append(s)
+            # VE projection → scalar (too small for Muon); guarded under the
+            # ve_shared check so ve_enabled=False cannot hit None.proj
+            if base_model.ve_shared.proj is not None:
+                scalar_params.append(base_model.ve_shared.proj.weight)
+
+        token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+        tok_param_groups = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
+        # VE embedding → token optimizer (AdamW with token_lr)
+        if base_model.ve_shared is not None:
+            tok_param_groups.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr})
+        optimizer_tok = torch.optim.AdamW(tok_param_groups,
+                                          betas=(args.beta1, args.beta2), eps=args.adam_eps,
+                                          weight_decay=args.muon_weight_decay, fused=True)
+        optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum,
+                              backend_steps=args.muon_backend_steps, weight_decay=args.muon_weight_decay)
+        for group in optimizer_muon.param_groups:
+            group["base_lr"] = args.matrix_lr
+        optimizer_scalar = torch.optim.AdamW([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
+                                             betas=(args.beta1, args.beta2), eps=args.adam_eps,
+                                             weight_decay=args.muon_weight_decay, fused=True)
+        optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar]
+
+        # Non-bank params that need manual all-reduce
+        replicated_params = [base_model.tok_emb.weight] + scalar_params
+        # VE embedding also needs all-reduce
+        if base_model.ve_shared is not None:
+            replicated_params.append(base_model.ve_shared.embed.weight)
+
+        if base_model.lm_head is not None:
+            optimizer_head = torch.optim.Adam([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
+                                              betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
+            optimizers.insert(1, optimizer_head)
+            replicated_params.append(base_model.lm_head.weight)
+        if base_model.bigram is not None:
+            bigram_params = list(base_model.bigram.parameters())
+            optimizer_bigram = torch.optim.AdamW([{"params": bigram_params, "lr": token_lr, "base_lr": token_lr}],
+                                                 betas=(args.beta1, args.beta2), eps=args.adam_eps,
+                                                 weight_decay=args.muon_weight_decay, fused=True)
+            optimizers.append(optimizer_bigram)
+            replicated_params.extend(bigram_params)
+
+        n_params = sum(p.numel() for p in base_model.parameters())
+        log0(f"model_params:{n_params} effective_depth:{args.num_effective_layers}")
+        if base_model.int5_layer_set:
+            log0(f"mixed_precision: int5_layers={sorted(base_model.int5_layer_set)} int6_layers={sorted(set(range(args.num_unique_layers)) - base_model.int5_layer_set)}")
+        log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
+
+        # ── Data loader + warmup ──
+        train_loader = DistributedTokenLoader(args.train_files,
rank, world_size, device) + + def zero_grad_all(): + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + # ── EMA + SWA shadow weights ── + ema_state = None + swa_params = None + swa_count = 0 + if args.ema_decay > 0: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + log0(f"EMA enabled: decay={args.ema_decay}") + if args.swa_every > 0: + swa_params = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + log0(f"SWA enabled: every {args.swa_every} steps when lr_scale < {args.swa_threshold}") + + def update_ema_swa(step, lr_scale): + """EMA via state_dict() (matches SOTA, zero overhead vs named_parameters). + SWA: additive accumulation on CPU during late warmdown.""" + nonlocal swa_count + with torch.no_grad(): + if ema_state is not None: + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(args.ema_decay).add_(t.detach().float(), alpha=1.0 - args.ema_decay) + + if swa_params is not None and step > 0 and step % args.swa_every == 0: + if lr_scale < args.swa_threshold: + if swa_count == 0: + for name, t in base_model.state_dict().items(): + swa_params[name].copy_(t.detach().cpu()) + swa_count = 1 + log0(f"SWA started at step {step} (lr_scale={lr_scale:.4f})") + else: + for name, t in base_model.state_dict().items(): + swa_params[name] += t.detach().cpu() + swa_count += 1 + + def get_best_weights(): + """Return best averaged weights. EMA preferred (per PR#401).""" + if ema_state is not None: + log0(f"Using EMA weights (decay={args.ema_decay})") + current_state = base_model.state_dict() + return {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + if swa_params is not None and swa_count >= 2: + log0(f"Using SWA weights ({swa_count} checkpoints)") + current_state = base_model.state_dict() + return {name: (t / swa_count).to(dtype=current_state[name].dtype) + for name, t in swa_params.items()} + return None + + def lr_mul(step, elapsed_ms): + # Phase 1: LR warmup (linear ramp from 0 to 1) + if args.lr_warmup_steps > 0 and step < args.lr_warmup_steps: + return (step + 1) / args.lr_warmup_steps + + # Phase 2: Warmdown + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup (compile paths) + if args.warmup_steps > 0: + initial_model_state = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad_all() + for ms in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + wl = model(x, y) + (wl * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + 
opt.step() + zero_grad_all() + if ws + 1 == args.warmup_steps or (ws + 1) % 10 == 0: + log0(f"warmup_step:{ws+1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ── + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # ── QAT activation check ── + global _qat_active, _qat_bits + if args.qat_bits > 0 and not _qat_active: + if max_wallclock_ms is not None and max_wallclock_ms > 0: + frac = elapsed_ms / max_wallclock_ms + else: + frac = step / max(args.iterations, 1) + if frac >= args.qat_start_fraction: + _qat_active = True + _qat_bits = args.qat_bits + log0(f"QAT enabled: int{args.qat_bits} at step {step} (fraction={frac:.2f})") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks + optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + if opt is not optimizer_muon: + opt.step() + # Phase 3: Wait for RS, local NS5, all-gather + optimizer_muon.step() + + update_ema_swa(step, scale) + zero_grad_all() + + step += 1 + approx_time = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % 
args.train_log_every == 0): + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_time:.0f}ms step_avg:{approx_time / step:.2f}ms") + + reached_cap = max_wallclock_ms is not None and approx_time >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + rc = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(rc, op=dist.ReduceOp.MAX) + reached_cap = bool(rc.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + # ── Load best averaged weights (EMA > SWA > raw) ── + best_weights = get_best_weights() + if best_weights is not None: + base_model.load_state_dict(best_weights, strict=True) + + # ── Serialization ── + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Raw model: {model_bytes} bytes, code: {code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + + # Compress with selected algorithm + if args.compression == "lzma": + quant_blob = lzma.compress(quant_raw, preset=6) + compress_label = "lzma-6" + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + compressor = zstd_mod.ZstdCompressor(level=22) + quant_blob = compressor.compress(quant_raw) + compress_label = "zstd-22" + except ImportError: + logger.warning("zstandard not installed, falling back to zlib") + quant_blob = zlib.compress(quant_raw, level=9) + compress_label = "zlib-9" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_label = "zlib-9" + + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"int8+{compress_label}: {quant_bytes} bytes, total: {quant_bytes + code_bytes} bytes") + if quant_bytes + code_bytes > 16_000_000: + logger.warning(f"OVER BUDGET: {quant_bytes + code_bytes} > 16,000,000") + + # ── Roundtrip validation ── + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + qblob = f.read() + if args.compression == "lzma": + qblob_decompressed = lzma.decompress(qblob) + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + decompressor = zstd_mod.ZstdDecompressor() + qblob_decompressed = decompressor.decompress(qblob) + except ImportError: + qblob_decompressed = zlib.decompress(qblob) + else: + qblob_decompressed = zlib.decompress(qblob) + base_model.load_state_dict(dequantize_state_dict_int8( + torch.load(io.BytesIO(qblob_decompressed), map_location="cpu")), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_int8_{compress_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") + log0(f"final_int8_{compress_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # ── Mixed int6/int8 quantization + roundtrip (if QAT was used) ── + if args.qat_bits == 6: + # Reload bf16 model first (int8 roundtrip modified it) + if master_process: + base_model.load_state_dict(torch.load("final_model.pt", map_location="cpu"), strict=True) + 
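+            # Every rank rebuilds the mixed-precision artifact below; only
+            # rank 0 writes final_model.mixed.ptz, and all ranks reload it
+            # from disk for the roundtrip eval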
sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + # Unbank 3D tensors for better per-row quantization + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_unique_layers) + int5_set = set(int(x) for x in args.int5_layers.split(",") if x.strip()) + mixed_result, mixed_meta = quantize_state_dict_mixed(unbanked_sd, int5_layers=int5_set) + code_bytes = len(code.encode("utf-8")) + mixed_result, _ = prune_to_fit(mixed_result, mixed_meta, code_bytes, + target_bytes=16_000_000, compress=args.compression) + mixed_buf = io.BytesIO() + torch.save({"w": mixed_result, "m": mixed_meta}, mixed_buf) + mixed_raw = mixed_buf.getvalue() + if args.compression == "lzma": + mixed_blob = lzma.compress(mixed_raw, preset=6) + mixed_label = "lzma-6" + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + mixed_blob = zstd_mod.ZstdCompressor(level=22).compress(mixed_raw) + mixed_label = "zstd-22" + except ImportError: + mixed_blob = zlib.compress(mixed_raw, level=9) + mixed_label = "zlib-9" + else: + mixed_blob = zlib.compress(mixed_raw, level=9) + mixed_label = "zlib-9" + if master_process: + with open("final_model.mixed.ptz", "wb") as f: + f.write(mixed_blob) + mixed_bytes = os.path.getsize("final_model.mixed.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"mixed_int6_int8+{mixed_label}: {mixed_bytes} bytes, total: {mixed_bytes + code_bytes} bytes") + if mixed_bytes + code_bytes > 16_000_000: + logger.warning(f"OVER BUDGET: {mixed_bytes + code_bytes} > 16,000,000") + else: + log0(f"FITS: {mixed_bytes + code_bytes} <= 16,000,000") + # Mixed roundtrip validation + if distributed: + dist.barrier() + with open("final_model.mixed.ptz", "rb") as f: + mixed_qblob = f.read() + if args.compression == "lzma": + mixed_decompressed = lzma.decompress(mixed_qblob) + elif args.compression == "zstd": + try: + import zstandard as zstd_mod + mixed_decompressed = zstd_mod.ZstdDecompressor().decompress(mixed_qblob) + except ImportError: + mixed_decompressed = zlib.decompress(mixed_qblob) + else: + mixed_decompressed = zlib.decompress(mixed_qblob) + quant_state = torch.load(io.BytesIO(mixed_decompressed), map_location="cpu") + deq_unbanked = dequantize_state_dict_mixed(quant_state["w"], quant_state["m"], unbanked_sd) + # Rebank individual 2D tensors back into 3D banks + deq_sd = _rebank_state_dict(deq_unbanked, args.num_unique_layers, sd_cpu) + base_model.load_state_dict(deq_sd, strict=True) + torch.cuda.synchronize() + qm_val_loss, qm_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_mixed_{mixed_label}_roundtrip val_loss:{qm_val_loss:.4f} val_bpb:{qm_val_bpb:.4f}") + log0(f"final_mixed_{mixed_label}_roundtrip_exact val_loss:{qm_val_loss:.8f} val_bpb:{qm_val_bpb:.8f}") + + # ── Sliding window eval (if enabled) ── + if args.sliding_window_eval: + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding_window( + args, model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms") + + # ── TTT LoRA eval ── + torch._dynamo.reset() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_ttt_lora(args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + 
log0(f"final_int8_ttt_lora val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + + # ── Legal Score-First TTT eval ── + if args.legal_ttt_enabled: + # Reload EMA weights for clean start + best_weights = get_best_weights() + if best_weights is not None: + base_model.load_state_dict(best_weights, strict=True) + torch._dynamo.reset() + torch.cuda.synchronize() + t_legal = time.perf_counter() + legal_loss, legal_bpb = eval_val_legal_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log0=log0) + log0(f"final_legal_ttt val_loss:{legal_loss:.4f} val_bpb:{legal_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_legal):.0f}ms") + + if distributed: + dist.destroy_process_group() + + except Exception: + logger.error(f"FATAL ERROR:\n{traceback.format_exc()}") + raise + + +if __name__ == "__main__": + main() +==================================================================================================== +torch.compile + FA3: COMPATIBLE +attention_backend:fa3 +model_params:23209295 effective_depth:11 +world_size:8 grad_accum_steps:1 +EMA enabled: decay=0.997 +SWA enabled: every 50 steps when lr_scale < 0.2 +warmup_step:10/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9322 val_bpb:4.1056 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9336 train_time:88ms step_avg:88.42ms +step:2/20000 train_loss:6.8951 train_time:104ms step_avg:52.13ms +step:3/20000 train_loss:6.7893 train_time:150ms step_avg:50.10ms +step:4/20000 train_loss:6.5916 train_time:197ms step_avg:49.28ms +step:5/20000 train_loss:6.3097 train_time:244ms step_avg:48.89ms +step:6/20000 train_loss:6.1390 train_time:292ms step_avg:48.70ms +step:7/20000 train_loss:5.8967 train_time:340ms step_avg:48.54ms +step:8/20000 train_loss:5.8412 train_time:387ms step_avg:48.39ms +step:9/20000 train_loss:5.7845 train_time:435ms step_avg:48.29ms +step:10/20000 train_loss:5.7473 train_time:482ms step_avg:48.22ms +step:500/20000 train_loss:2.4780 train_time:24091ms step_avg:48.18ms +step:1000/20000 train_loss:2.3700 train_time:48358ms step_avg:48.36ms +step:1500/20000 train_loss:2.2271 train_time:72663ms step_avg:48.44ms +QAT enabled: int6 at step 1855 (fraction=0.15) +step:2000/20000 train_loss:2.2132 train_time:105216ms step_avg:52.61ms +step:2500/20000 train_loss:2.2041 train_time:129670ms step_avg:51.87ms +step:3000/20000 train_loss:3.1737 train_time:154050ms step_avg:51.35ms +step:3000/20000 val_loss:2.1691 val_bpb:1.2847 train_time:154086ms step_avg:51.36ms +step:3500/20000 train_loss:2.3497 train_time:178493ms step_avg:51.00ms +step:4000/20000 train_loss:2.2522 train_time:202892ms step_avg:50.72ms +step:4500/20000 train_loss:1.8575 train_time:227338ms step_avg:50.52ms +step:5000/20000 train_loss:2.2180 train_time:251801ms step_avg:50.36ms +step:5500/20000 train_loss:2.1981 train_time:276222ms step_avg:50.22ms +step:6000/20000 train_loss:2.0970 train_time:300676ms step_avg:50.11ms +step:6000/20000 val_loss:2.1283 val_bpb:1.2605 train_time:300711ms step_avg:50.12ms +step:6500/20000 train_loss:2.0931 train_time:325123ms step_avg:50.02ms +step:7000/20000 train_loss:2.0826 train_time:349538ms step_avg:49.93ms +step:7500/20000 train_loss:2.0961 train_time:373994ms step_avg:49.87ms +step:8000/20000 train_loss:2.0502 train_time:398403ms step_avg:49.80ms +step:8500/20000 train_loss:2.2085 train_time:422863ms step_avg:49.75ms +step:9000/20000 train_loss:2.0880 train_time:447320ms step_avg:49.70ms 
+step:9000/20000 val_loss:2.1000 val_bpb:1.2437 train_time:447355ms step_avg:49.71ms
+step:9500/20000 train_loss:2.0181 train_time:471736ms step_avg:49.66ms
+step:10000/20000 train_loss:1.9751 train_time:496192ms step_avg:49.62ms
+step:10500/20000 train_loss:1.9844 train_time:520645ms step_avg:49.59ms
+step:11000/20000 train_loss:2.0134 train_time:545051ms step_avg:49.55ms
+SWA started at step 11450 (lr_scale=0.1897)
+step:11500/20000 train_loss:1.8805 train_time:569591ms step_avg:49.53ms
+step:12000/20000 train_loss:1.9484 train_time:594539ms step_avg:49.54ms
+step:12000/20000 val_loss:1.9684 val_bpb:1.1658 train_time:594574ms step_avg:49.55ms
+step:12109/20000 val_loss:1.9655 val_bpb:1.1641 train_time:600075ms step_avg:49.56ms
+stopping_early: wallclock_cap train_time:600075ms step:12109/20000
+peak memory: 12569 MiB
+Using EMA weights (decay=0.997)
+Raw model: 91514419 bytes, code: 117180 bytes
+int8+lzma-6: 16382876 bytes, total: 16500056 bytes
+final_int8_lzma-6_roundtrip val_loss:1.9668 val_bpb:1.1649
+final_int8_lzma-6_roundtrip_exact val_loss:1.96683854 val_bpb:1.16487257
+mixed_int6_int8+lzma-6: 15871736 bytes, total: 15988916 bytes
+FITS: 15988916 <= 16,000,000
+final_mixed_lzma-6_roundtrip val_loss:1.9666 val_bpb:1.1647
+final_mixed_lzma-6_roundtrip_exact val_loss:1.96660765 val_bpb:1.16473582
+final_int8_ttt_lora val_loss:1.9424 val_bpb:1.1504 eval_time:38631ms
+Using EMA weights (decay=0.997)
+legal_ttt:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 lr=0.002 epochs=3 freeze_blocks=0
+legal_ttt:params unfrozen=23209295 frozen=0
+ legal_ttt_chunk [1/1893] bpb=1.168510 time=0.3s
+ legal_ttt_chunk [11/1893] bpb=1.162496 time=2.4s
+ legal_ttt_chunk [21/1893] bpb=1.148210 time=4.6s
+ legal_ttt_chunk [31/1893] bpb=1.147254 time=6.8s
+ legal_ttt_chunk [41/1893] bpb=1.134474 time=8.9s
+ legal_ttt_chunk [51/1893] bpb=1.129083 time=11.1s
+ legal_ttt_chunk [61/1893] bpb=1.136654 time=13.2s
+ legal_ttt_chunk [71/1893] bpb=1.135944 time=15.4s
+ legal_ttt_chunk [81/1893] bpb=1.135866 time=17.5s
+ legal_ttt_chunk [91/1893] bpb=1.136568 time=19.7s
+ legal_ttt_chunk [101/1893] bpb=1.140282 time=21.8s
+ legal_ttt_chunk [111/1893] bpb=1.142431 time=24.0s
+ legal_ttt_chunk [121/1893] bpb=1.135699 time=26.1s
+ legal_ttt_chunk [131/1893] bpb=1.136006 time=28.3s
+ legal_ttt_chunk [141/1893] bpb=1.141742 time=30.4s
+ legal_ttt_chunk [151/1893] bpb=1.143717 time=32.6s
+ legal_ttt_chunk [161/1893] bpb=1.143572 time=34.7s
+ legal_ttt_chunk [171/1893] bpb=1.147994 time=36.9s
+ legal_ttt_chunk [181/1893] bpb=1.150513 time=39.0s
+ legal_ttt_chunk [191/1893] bpb=1.157987 time=41.2s
+ legal_ttt_chunk [201/1893] bpb=1.156968 time=43.3s
+ legal_ttt_chunk [211/1893] bpb=1.154645 time=45.5s
+ legal_ttt_chunk [221/1893] bpb=1.156197 time=47.6s
+ legal_ttt_chunk [231/1893] bpb=1.154895 time=49.8s
+ legal_ttt_chunk [241/1893] bpb=1.155382 time=51.9s
+ legal_ttt_chunk [251/1893] bpb=1.154922 time=54.1s
+ legal_ttt_chunk [261/1893] bpb=1.151850 time=56.2s
+ legal_ttt_chunk [271/1893] bpb=1.150642 time=58.4s
+ legal_ttt_chunk [281/1893] bpb=1.151924 time=60.5s
+ legal_ttt_chunk [291/1893] bpb=1.153732 time=62.7s
+ legal_ttt_chunk [301/1893] bpb=1.154502 time=64.8s
+ legal_ttt_chunk [311/1893] bpb=1.156614 time=67.0s
+ legal_ttt_chunk [321/1893] bpb=1.158502 time=69.1s
+ legal_ttt_chunk [331/1893] bpb=1.158528 time=71.3s
+ legal_ttt_chunk [341/1893] bpb=1.157471 time=73.4s
+ legal_ttt_chunk [351/1893] bpb=1.159790 time=75.6s
+ legal_ttt_chunk [361/1893] bpb=1.160106 time=77.7s
+ legal_ttt_chunk [371/1893] bpb=1.159404 time=79.9s
+ legal_ttt_chunk [381/1893] bpb=1.159510 time=82.0s
+ legal_ttt_chunk [391/1893] bpb=1.159353 time=84.2s
+ legal_ttt_chunk [401/1893] bpb=1.157271 time=86.3s
+ legal_ttt_chunk [411/1893] bpb=1.156196 time=88.5s
+ legal_ttt_chunk [421/1893] bpb=1.155281 time=90.6s
+ legal_ttt_chunk [431/1893] bpb=1.155191 time=92.8s
+ legal_ttt_chunk [441/1893] bpb=1.155506 time=94.9s
+ legal_ttt_chunk [451/1893] bpb=1.155771 time=97.1s
+ legal_ttt_chunk [461/1893] bpb=1.154653 time=99.2s
+ legal_ttt_chunk [471/1893] bpb=1.155253 time=101.4s
+ legal_ttt_chunk [481/1893] bpb=1.154948 time=103.5s
+ legal_ttt_chunk [491/1893] bpb=1.153835 time=105.7s
+ legal_ttt_chunk [501/1893] bpb=1.153351 time=107.8s
+ legal_ttt_chunk [511/1893] bpb=1.152688 time=110.0s
+ legal_ttt_chunk [521/1893] bpb=1.150582 time=112.1s
+ legal_ttt_chunk [531/1893] bpb=1.151712 time=114.3s
+ legal_ttt_chunk [541/1893] bpb=1.152073 time=116.4s
+ legal_ttt_chunk [551/1893] bpb=1.150977 time=118.6s
+ legal_ttt_chunk [561/1893] bpb=1.151477 time=120.7s
+ legal_ttt_chunk [571/1893] bpb=1.150385 time=122.9s
+ legal_ttt_chunk [581/1893] bpb=1.149599 time=125.0s
+ legal_ttt_chunk [591/1893] bpb=1.148918 time=127.2s
+ legal_ttt_chunk [601/1893] bpb=1.149410 time=129.4s
+ legal_ttt_chunk [611/1893] bpb=1.149364 time=131.5s
+ legal_ttt_chunk [621/1893] bpb=1.149166 time=133.7s
+ legal_ttt_chunk [631/1893] bpb=1.149835 time=135.8s
+ legal_ttt_chunk [641/1893] bpb=1.149604 time=138.0s
+ legal_ttt_chunk [651/1893] bpb=1.149775 time=140.1s
+ legal_ttt_chunk [661/1893] bpb=1.149263 time=142.3s
+ legal_ttt_chunk [671/1893] bpb=1.149620 time=144.4s
+ legal_ttt_chunk [681/1893] bpb=1.150272 time=146.5s
+ legal_ttt_chunk [691/1893] bpb=1.151318 time=148.7s
+ legal_ttt_chunk [701/1893] bpb=1.150774 time=150.8s
+ legal_ttt_chunk [711/1893] bpb=1.150784 time=153.0s
+ legal_ttt_chunk [721/1893] bpb=1.150442 time=155.1s
+ legal_ttt_chunk [731/1893] bpb=1.150476 time=157.3s
+ legal_ttt_chunk [741/1893] bpb=1.150529 time=159.5s
+ legal_ttt_chunk [751/1893] bpb=1.150396 time=161.6s
+ legal_ttt_chunk [761/1893] bpb=1.150276 time=163.8s
+ legal_ttt_chunk [771/1893] bpb=1.149951 time=165.9s
+ legal_ttt_chunk [781/1893] bpb=1.150735 time=168.1s
+ legal_ttt_chunk [791/1893] bpb=1.150361 time=170.2s
+ legal_ttt_chunk [801/1893] bpb=1.150653 time=172.4s
+ legal_ttt_chunk [811/1893] bpb=1.150431 time=174.5s
+ legal_ttt_chunk [821/1893] bpb=1.150175 time=176.6s
+ legal_ttt_chunk [831/1893] bpb=1.149983 time=178.8s
+ legal_ttt_chunk [841/1893] bpb=1.149292 time=180.9s
+ legal_ttt_chunk [851/1893] bpb=1.149029 time=183.1s
+ legal_ttt_chunk [861/1893] bpb=1.148776 time=185.2s
+ legal_ttt_chunk [871/1893] bpb=1.149047 time=187.4s
+ legal_ttt_chunk [881/1893] bpb=1.149246 time=189.5s
+ legal_ttt_chunk [891/1893] bpb=1.148833 time=191.7s
+ legal_ttt_chunk [901/1893] bpb=1.148608 time=193.9s
+ legal_ttt_chunk [911/1893] bpb=1.148714 time=196.0s
+ legal_ttt_chunk [921/1893] bpb=1.149189 time=198.2s
+ legal_ttt_chunk [931/1893] bpb=1.149189 time=200.3s
+ legal_ttt_chunk [941/1893] bpb=1.148877 time=202.5s
+ legal_ttt_chunk [951/1893] bpb=1.149269 time=204.6s
+ legal_ttt_chunk [961/1893] bpb=1.149356 time=206.7s
+ legal_ttt_chunk [971/1893] bpb=1.150186 time=208.9s
+ legal_ttt_chunk [981/1893] bpb=1.150236 time=211.0s
+ legal_ttt_chunk [991/1893] bpb=1.150272 time=213.2s
+ legal_ttt_chunk [1001/1893] bpb=1.150241 time=215.3s
+ legal_ttt_chunk [1011/1893] bpb=1.150027 time=217.5s
+ legal_ttt_chunk [1021/1893] bpb=1.150362 time=219.6s
+ legal_ttt_chunk [1031/1893] bpb=1.150825 time=221.8s
+ legal_ttt_chunk [1041/1893] bpb=1.150457 time=223.9s
+ legal_ttt_chunk [1051/1893] bpb=1.150187 time=226.1s
+ legal_ttt_chunk [1061/1893] bpb=1.150204 time=228.3s
+ legal_ttt_chunk [1071/1893] bpb=1.150822 time=230.4s
+ legal_ttt_chunk [1081/1893] bpb=1.151080 time=232.6s
+ legal_ttt_chunk [1091/1893] bpb=1.151828 time=234.7s
+ legal_ttt_chunk [1101/1893] bpb=1.151836 time=236.9s
+ legal_ttt_chunk [1111/1893] bpb=1.151710 time=239.0s
+ legal_ttt_chunk [1121/1893] bpb=1.151524 time=241.2s
+ legal_ttt_chunk [1131/1893] bpb=1.151427 time=243.3s
+ legal_ttt_chunk [1141/1893] bpb=1.151126 time=245.5s
+ legal_ttt_chunk [1151/1893] bpb=1.151148 time=247.6s
+ legal_ttt_chunk [1161/1893] bpb=1.150816 time=249.7s
+ legal_ttt_chunk [1171/1893] bpb=1.151134 time=251.9s
+ legal_ttt_chunk [1181/1893] bpb=1.150385 time=254.0s
+ legal_ttt_chunk [1191/1893] bpb=1.150256 time=256.2s
+ legal_ttt_chunk [1201/1893] bpb=1.150664 time=258.4s
+ legal_ttt_chunk [1211/1893] bpb=1.150194 time=260.5s
+ legal_ttt_chunk [1221/1893] bpb=1.149869 time=262.7s
+ legal_ttt_chunk [1231/1893] bpb=1.149588 time=264.8s
+ legal_ttt_chunk [1241/1893] bpb=1.149246 time=267.0s
+ legal_ttt_chunk [1251/1893] bpb=1.148653 time=269.1s
+ legal_ttt_chunk [1261/1893] bpb=1.148647 time=271.3s
+ legal_ttt_chunk [1271/1893] bpb=1.148296 time=273.4s
+ legal_ttt_chunk [1281/1893] bpb=1.148122 time=275.6s
+ legal_ttt_chunk [1291/1893] bpb=1.147900 time=277.7s
+ legal_ttt_chunk [1301/1893] bpb=1.147299 time=279.9s
+ legal_ttt_chunk [1311/1893] bpb=1.146922 time=282.0s
+ legal_ttt_chunk [1321/1893] bpb=1.146588 time=284.2s
+ legal_ttt_chunk [1331/1893] bpb=1.146513 time=286.3s
+ legal_ttt_chunk [1341/1893] bpb=1.146400 time=288.5s
+ legal_ttt_chunk [1351/1893] bpb=1.146358 time=290.6s
+ legal_ttt_chunk [1361/1893] bpb=1.146407 time=292.8s
+ legal_ttt_chunk [1371/1893] bpb=1.146272 time=294.9s
+ legal_ttt_chunk [1381/1893] bpb=1.146258 time=297.1s
+ legal_ttt_chunk [1391/1893] bpb=1.145865 time=299.2s
+ legal_ttt_chunk [1401/1893] bpb=1.145845 time=301.4s
+ legal_ttt_chunk [1411/1893] bpb=1.145961 time=303.5s
+ legal_ttt_chunk [1421/1893] bpb=1.146212 time=305.7s
+ legal_ttt_chunk [1431/1893] bpb=1.145919 time=307.8s
+ legal_ttt_chunk [1441/1893] bpb=1.146443 time=310.0s
+ legal_ttt_chunk [1451/1893] bpb=1.146784 time=312.1s
+ legal_ttt_chunk [1461/1893] bpb=1.146344 time=314.3s
+ legal_ttt_chunk [1471/1893] bpb=1.147391 time=316.4s
+ legal_ttt_chunk [1481/1893] bpb=1.146935 time=318.6s
+ legal_ttt_chunk [1491/1893] bpb=1.146734 time=320.7s
+ legal_ttt_chunk [1501/1893] bpb=1.146688 time=322.9s
+ legal_ttt_chunk [1511/1893] bpb=1.146708 time=325.0s
+ legal_ttt_chunk [1521/1893] bpb=1.146768 time=327.2s
+ legal_ttt_chunk [1531/1893] bpb=1.146258 time=329.3s
+ legal_ttt_chunk [1541/1893] bpb=1.146122 time=331.5s
+ legal_ttt_chunk [1551/1893] bpb=1.146422 time=333.6s
+ legal_ttt_chunk [1561/1893] bpb=1.146412 time=335.8s
+ legal_ttt_chunk [1571/1893] bpb=1.146243 time=337.9s
+ legal_ttt_chunk [1581/1893] bpb=1.146380 time=340.1s
+ legal_ttt_chunk [1591/1893] bpb=1.146223 time=342.2s
+ legal_ttt_chunk [1601/1893] bpb=1.146401 time=344.4s
+ legal_ttt_chunk [1611/1893] bpb=1.146317 time=346.5s
+ legal_ttt_chunk [1621/1893] bpb=1.145942 time=348.7s
+ legal_ttt_chunk [1631/1893] bpb=1.146268 time=350.8s
+ legal_ttt_chunk [1641/1893] bpb=1.146300 time=353.0s
+ legal_ttt_chunk [1651/1893] bpb=1.146248 time=355.1s
+ legal_ttt_chunk [1661/1893] bpb=1.146126 time=357.3s
+ legal_ttt_chunk [1671/1893] bpb=1.146609 time=359.4s
+ legal_ttt_chunk [1681/1893] bpb=1.146763 time=361.6s
+ legal_ttt_chunk [1691/1893] bpb=1.146595 time=363.7s
+ legal_ttt_chunk [1701/1893] bpb=1.146722 time=365.9s
+ legal_ttt_chunk [1711/1893] bpb=1.146691 time=368.0s
+ legal_ttt_chunk [1721/1893] bpb=1.146676 time=370.2s
+ legal_ttt_chunk [1731/1893] bpb=1.146560 time=372.3s
+ legal_ttt_chunk [1741/1893] bpb=1.146340 time=374.5s
+ legal_ttt_chunk [1751/1893] bpb=1.146165 time=376.6s
+ legal_ttt_chunk [1761/1893] bpb=1.146304 time=378.7s
+ legal_ttt_chunk [1771/1893] bpb=1.146207 time=380.9s
+ legal_ttt_chunk [1781/1893] bpb=1.146239 time=383.0s
+ legal_ttt_chunk [1791/1893] bpb=1.145833 time=385.2s
+ legal_ttt_chunk [1801/1893] bpb=1.145704 time=387.3s
+ legal_ttt_chunk [1811/1893] bpb=1.145615 time=389.5s
+ legal_ttt_chunk [1821/1893] bpb=1.145672 time=391.6s
+ legal_ttt_chunk [1831/1893] bpb=1.145070 time=393.8s
+ legal_ttt_chunk [1841/1893] bpb=1.145184 time=395.9s
+ legal_ttt_chunk [1851/1893] bpb=1.144967 time=398.1s
+ legal_ttt_chunk [1861/1893] bpb=1.144593 time=400.2s
+ legal_ttt_chunk [1871/1893] bpb=1.144569 time=402.4s
+ legal_ttt_chunk [1881/1893] bpb=1.144147 time=404.5s
+ legal_ttt_chunk [1891/1893] bpb=1.143920 time=406.7s
+ legal_ttt_chunk [1893/1893] bpb=1.143963 time=407.0s
+legal_ttt:done val_loss=1.927601 val_bpb=1.141637 elapsed=407.0s
+final_legal_ttt val_loss:1.9276 val_bpb:1.1416 eval_time:407430ms