diff --git a/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/README.md b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/README.md new file mode 100644 index 0000000000..379f466222 --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/README.md @@ -0,0 +1,63 @@ +# FreqGPTQ + GatedDeltaNet + Adaptive Quantization + +**val_bpb: TBD** (pending GPU validation) | **~15.8 MB** | 8xH100 SXM + +## Approach + +Built on PR #1698 (GatedDeltaNet + Legal TTT, 1.00995 BPB) with quantization and compression improvements: + +### 1. FreqGPTQ (frequency-weighted GPTQ calibration) +Top-100 most frequent tokens get `sqrt(2)` boosted activations during Hessian accumulation (`H = X^T X`), biasing quantization error minimization toward high-frequency tokens covering ~53% of text. Zero artifact cost. + +### 2. PassthroughQuant +Control tensors (`attn_scale`, `mlp_scale`, `resid_mix`, `skip_weights`) quantized to per-tensor int8 instead of fp16 passthrough. Small 2D matrices also get per-row int8. Saves ~40KB compressed. + +### 3. Sandwich Quantization +Final transformer block quantized to int8 instead of int6 to protect signal quality before the tied LM head. + +### 4. Adaptive Embedding Precision +Int8 for top-100 frequent token embedding rows, int6/int5 for the rest. Higher precision where it matters most (Zipf's law). + +### 5. Configurable Int5 GPTQ +Weight quantization pushed from 6-bit (`clip_range=31`) to 5-bit (`clip_range=15`) with FreqGPTQ, fitting ~38M params vs ~32M at int6 within the same 16MB budget. Late QAT clip range synced to match. + +### 6. LZMA Self-Extracting Code Wrapper +Python source compressed from ~105KB to ~30KB via LZMA + base85 encoding. Frees ~73KB for model weights (~118K more parameters at int5). + +## Architecture + +Same as PR #1698 (Model K: K_KVShare_Wider): +- 10 GatedDeltaNet layers (FLA `fla-core==0.4.2`) +- 544 model dim, 8 heads, 64-dim head keys +- KV sharing stride 2 +- 3x MLP, BigramHash embedding, SmearGate +- Tied embeddings, logit softcap 30.0 + +## Training + +- Muon optimizer + EMA(0.997) + SWA(50) +- Late QAT with STE (threshold 0.15, clip range matches WEIGHT_BITS) +- Score-first TTT: SGD(lr=0.005, momentum=0.9), 3 epochs, 32K chunks, freeze first 2 blocks + +## Run Command + +```bash +ARCH_MODE=K GPTQ_ENABLED=1 TTT_ENABLED=1 WEIGHT_BITS=6 \ +FREQ_GPTQ_BOOST=2.0 ADAPTIVE_EMBED=1 NUM_FREQ_TOKENS=100 \ +TTT_LR=0.005 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=32768 TTT_FREEZE_BLOCKS=2 \ +TTT_MOMENTUM=0.9 TTT_BATCH_SEQS=32 TTT_GRAD_CLIP=1.0 \ +SEED=42 torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Status + +WIP — code complete, pending GPU validation for BPB results and 3-seed statistical significance. + +## Attribution + +- GatedDeltaNet architecture: @resouer (PR #1687), @arsenis-cmd (PR #1698) +- Flash Linear Attention: @sustcsonglin (fla-core 0.4.2) +- Legal TTT framework: @Christopher-Lee-McClendon (PR #461) +- FreqGPTQ concept: PR #1707 +- PassthroughQuant concept: PR #1716 +- AttnOutGate + SmearGate: PR #1693 diff --git a/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/architectures.py b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/architectures.py new file mode 100644 index 0000000000..ad7584b3f8 --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/architectures.py @@ -0,0 +1,712 @@ +"""GDN Hybrid Architecture — modular blocks using FLA native layers. + +Supports 8 model variants (A-H) for the Parameter Golf screening experiments. +Each model is a stack of mixed {GDN, DeltaProduct, RWKV-7, Mamba-2, SWA} blocks +with shared MLP, RMSNorm, and residual connections. + +Key design choices: +- FLA layers handle recurrent attention (GatedDeltaNet, GatedDeltaProduct, RWKV7, Mamba2) +- Sliding Window Attention (SWA) uses flash attention with a causal window mask +- All blocks follow the same pre-norm residual pattern for uniform gradient flow +- Weight sharing for SWA layers in Zamba/Hymba-style models +- Score-first eval: XSA-all only extends attention layers (no future context leakage) +""" +from __future__ import annotations +import math +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +# ─── FLA backend selection ────────────────────────────────────────────────── +# Set FLA_USE_NAIVE=1 to force pure-PyTorch (naive) kernels instead of Triton. +# This is needed when: +# - Running on V100 (sm_70) which doesn't support FLA's Triton kernels well +# - Triton cache is corrupted (FileNotFoundError on .json files) +# - Debugging without Triton dependency +# +# On A100 (sm_80+), the Triton kernels are ~3-10x faster and should be used. +_USE_NAIVE = os.environ.get("FLA_USE_NAIVE", "0") == "1" + +if _USE_NAIVE: + # 1. Patch GatedDeltaNet's chunk op + import fla.ops.gated_delta_rule.chunk as _gdr_chunk + import fla.ops.gated_delta_rule.naive as _gdr_naive + + def _patched_chunk_gated_delta_rule( + q, k, v, g, beta, scale=None, initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False, **kwargs + ): + if use_qk_l2norm_in_kernel: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + return _gdr_naive.naive_chunk_gated_delta_rule( + q, k, v, g, beta, + chunk_size=64, scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + ) + + _gdr_chunk.chunk_gated_delta_rule = _patched_chunk_gated_delta_rule + import fla.layers.gated_deltanet as _gdn_layer + _gdn_layer.chunk_gated_delta_rule = _patched_chunk_gated_delta_rule + + # 2. Patch GatedDeltaProduct's chunk op + import fla.ops.gated_delta_product.chunk as _gdp_chunk + import fla.ops.gated_delta_product.naive as _gdp_naive + + def _patched_chunk_gated_delta_product( + q, k, v, g, beta, num_householder=1, scale=None, initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False, **kwargs + ): + if use_qk_l2norm_in_kernel: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + return _gdp_naive.naive_recurrent_gated_delta_product( + q, k, v, g, beta, + scale=scale, cu_seqlens=None, + initial_state=initial_state, + output_final_state=output_final_state, + num_householder=num_householder, + ) + + _gdp_chunk.chunk_gated_delta_product = _patched_chunk_gated_delta_product + import fla.layers.gated_deltaproduct as _gdp_layer + _gdp_layer.chunk_gated_delta_product = _patched_chunk_gated_delta_product + + print("[FLA] Using NAIVE (pure-PyTorch) kernels — set FLA_USE_NAIVE=0 for Triton", flush=True) + +# FLA imports +from fla.layers import GatedDeltaNet, GatedDeltaProduct, Mamba2 +try: + from fla.layers import RWKV7Attention +except Exception: + RWKV7Attention = None # type: ignore + +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func +except ImportError: + def flash_attn_3_func(q, k, v, causal=False, window_size=(-1, -1)): + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + if k2.size(1) != q2.size(1): + rep = q2.size(1) // k2.size(1) + k2 = k2.repeat_interleave(rep, dim=1) + v2 = v2.repeat_interleave(rep, dim=1) + out = torch.nn.functional.scaled_dot_product_attention(q2, k2, v2, is_causal=causal) + return out.transpose(1, 2) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int | None = None, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + """Linear layer that casts input to weight dtype for mixed precision. + Supports late QAT (intN STE) when _qat_enabled is set. + _qat_clip_range controls the quantization range (15 for int5, 31 for int6).""" + _qat_enabled: bool = False + _qat_clip_range: int = 31 # default int6; set to 15 for int5 + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(dtype=x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + cr = float(CastedLinear._qat_clip_range) + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / cr).clamp_min(1.0 / cr) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -cr, cr) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE: forward uses quantized, backward uses full + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + """RoPE embeddings for sliding window attention.""" + def __init__(self, dim: int, base: float = 10000.0, max_len: int = 4096): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.max_len = max_len + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq.to(device)) + cos = freqs.cos().to(dtype) + sin = freqs.sin().to(dtype) + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + """Apply RoPE to the input tensor.""" + d = x.shape[-1] // 2 + x1, x2 = x[..., :d], x[..., d:] + out1 = x1 * cos[:x.shape[-2]] - x2 * sin[:x.shape[-2]] + out2 = x2 * cos[:x.shape[-2]] + x1 * sin[:x.shape[-2]] + return torch.cat([out1, out2], dim=-1) + + +class MLP(nn.Module): + """Feed-forward MLP with configurable activation.""" + def __init__(self, dim: int, mult: float = 3.0, act: str = "relu_sq", leaky_slope: float = 0.5): + super().__init__() + hidden = int(mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + nn.init.zeros_(self.proj.weight) # zero-init output for residual + self.act = act + self.leaky_slope = leaky_slope + + def forward(self, x: Tensor) -> Tensor: + x = self.fc(x) + if self.act == "leaky_relu_sq": + x = F.leaky_relu(x, negative_slope=self.leaky_slope) + else: + x = F.relu(x) + return self.proj(x.square()) + + +class SlidingWindowAttention(nn.Module): + """Sliding window causal attention for hybrid models. + + Supports XSA (cross-segment attention) at eval time for extending context + across eval chunks. Window is enforced during training but can be relaxed at eval. + KV can be shared across layers (Zamba-style) by reusing the same module. + """ + def __init__( + self, + dim: int, + num_heads: int = 8, + num_kv_heads: int = 4, + window_size: int = 512, + rope_base: float = 10000.0, + qk_gain_init: float = 1.5, + ): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.window_size = window_size + + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + nn.init.zeros_(self.proj.weight) + + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + self.use_xsa = False # enabled at eval time for XSA-all + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """XSA: subtract self-value projection (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + B, T, D = x.shape + q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(B, T, self.num_kv_heads, self.head_dim) + + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(T, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + + if q.is_cuda and q.dtype not in (torch.float16, torch.bfloat16): + q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16) + + # Use window during training, full causal at eval if XSA enabled + y = flash_attn_3_func(q, k, v, causal=True) + + if self.use_xsa: + y = self._xsa_efficient(y, v) + + y = y.reshape(B, T, D) + return self.proj(y) + + +class RecurrentBlock(nn.Module): + """Wraps any FLA recurrent layer (GDN, DeltaProduct, RWKV-7, Mamba-2) with + pre-norm residual connection and MLP.""" + + def __init__( + self, + dim: int, + recurrent_layer: nn.Module, + mlp_mult: float = 3.0, + mlp_act: str = "relu_sq", + layer_idx: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm(dim) + self.mlp_norm = RMSNorm(dim) + self.recurrent = recurrent_layer + self.mlp = MLP(dim, mlp_mult, act=mlp_act) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.layer_idx = layer_idx + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # FLA layers return (output, state) or just output depending on mode + recurrent_out = self.recurrent(self.attn_norm(x_in)) + if isinstance(recurrent_out, tuple): + recurrent_out = recurrent_out[0] + + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * recurrent_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out)) + return x_out + + +class AttentionBlock(nn.Module): + """SWA block with pre-norm residual and MLP.""" + + def __init__( + self, + dim: int, + swa: SlidingWindowAttention, + mlp_mult: float = 3.0, + mlp_act: str = "relu_sq", + layer_idx: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm(dim) + self.mlp_norm = RMSNorm(dim) + self.attn = swa + self.mlp = MLP(dim, mlp_mult, act=mlp_act) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.layer_idx = layer_idx + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in), v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out)) + return x_out + + +class SmearGate(nn.Module): + """Weighted average of current and previous token embeddings.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash-based bigram/trigram embedding for additional context.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int, + trigram: bool = False): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self._trigram = trigram + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def trigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., :2] = mod + out[..., 2:] = (36313 * t[..., 2:] ^ 27191 * t[..., 1:-1] ^ 51497 * t[..., :-2]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self._trigram: + h = h + self.embed(self.trigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +def _parse_layout(layout_str: str) -> list[tuple[str, int]]: + """Parse a layout string into a sequence of (layer_type, count) pairs. + + Examples: + "gdn_only" -> [("gdn", 11)] (count filled in by caller) + "gdn5_swa_gdn5_swa_shared" -> [("gdn", 5), ("swa", 1), ("gdn", 5), ("swa_shared", 1)] + "gdn3_swa_gdn3_swa_shared_gdn3" -> [("gdn", 3), ("swa", 1), ("gdn", 3), ("swa_shared", 1), ("gdn", 3)] + "mamba_only" -> [("mamba", 11)] + "gdn3_mamba2_swa_gdn3_mamba2" -> [("gdn", 3), ("mamba", 2), ("swa", 1), ("gdn", 3), ("mamba", 2)] + """ + if layout_str == "gdn_only": + return [("gdn", -1)] # -1 = use num_gdn_layers + if layout_str == "mamba_only": + return [("mamba", -1)] # -1 = use num_mamba_layers + if layout_str == "swa_only": + return [("swa", -1)] # -1 = use num_swa_layers + + # Parse custom layouts like "gdn5_swa_gdn5_swa_shared" + parts = layout_str.split("_") + result = [] + i = 0 + while i < len(parts): + part = parts[i] + if part.startswith("gdn") and len(part) > 3: + count = int(part[3:]) + result.append(("gdn", count)) + elif part.startswith("mamba") and len(part) > 5: + count = int(part[5:]) + result.append(("mamba", count)) + elif part == "swa": + # Check if next token is "shared" + if i + 1 < len(parts) and parts[i + 1] == "shared": + result.append(("swa_shared", 1)) + i += 1 + else: + result.append(("swa", 1)) + elif part == "shared": + # Already consumed by swa check above + pass + i += 1 + return result + + +class HybridGDN(nn.Module): + """Hybrid GDN architecture supporting mixed recurrent/attention layers. + + Builds a stack of blocks according to the layer_layout specification: + - "gdn" blocks use GatedDeltaNet (or GatedDeltaProduct, or RWKV-7) + - "mamba" blocks use Mamba-2 + - "swa" blocks use SlidingWindowAttention + - "swa_shared" reuses the same SWA module (Zamba-style weight sharing) + + All models share: token embedding, bigram hash, smear gate, final norm, lm_head. + """ + def __init__(self, config: dict, vocab_size: int = 1024): + super().__init__() + dim = config["model_dim"] + num_heads = config["num_heads"] + mlp_mult = config["mlp_mult"] + self.arch_name = config["arch_name"] + self.model_dim = dim + self.vocab_size = vocab_size + self.logit_softcap = 30.0 + + # Embeddings + self.tok_emb = nn.Embedding(vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=0.005) + self.bigram = BigramHashEmbedding( + config.get("bigram_vocab_size", 2048), + config.get("bigram_dim", 128), + dim, + trigram=config.get("trigram", False), + ) + self.smear = SmearGate(dim) + + # Meta tokens (Hymba-style, for Model E) + n_meta = config.get("meta_tokens", 0) + if n_meta > 0: + self.meta_tokens = nn.Parameter(torch.randn(1, n_meta, dim) * 0.02) + self.n_meta = n_meta + else: + self.meta_tokens = None + self.n_meta = 0 + + # Build layer stack + layout = _parse_layout(config["layer_layout"]) + self.blocks = nn.ModuleList() + self._block_types = [] # track type for XSA/diagnostics + self._shared_swa = None # shared SWA module for Zamba/Hymba models + + layer_idx = 0 + for layer_type, count in layout: + if count == -1: + # Fill with the specified layer type + if layer_type == "gdn": + count = config["num_gdn_layers"] + elif layer_type == "mamba": + count = config["num_mamba_layers"] + elif layer_type == "swa": + count = config["num_swa_layers"] + + for _ in range(count): + if layer_type == "gdn": + recurrent = self._make_recurrent_layer(config, layer_idx) + block = RecurrentBlock(dim, recurrent, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("gdn") + + elif layer_type == "mamba": + mamba_expand = config.get("mamba_expand", 2) + mamba_head_dim = config.get("gdn_head_dim", 64) + mamba_num_heads = (dim * mamba_expand) // mamba_head_dim + mamba = Mamba2( + num_heads=mamba_num_heads, + head_dim=mamba_head_dim, + hidden_size=dim, + state_size=config.get("mamba_state_size", 64), + expand=mamba_expand, + layer_idx=layer_idx, + ) + block = RecurrentBlock(dim, mamba, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("mamba") + + elif layer_type in ("swa", "swa_shared"): + if layer_type == "swa_shared" and self._shared_swa is not None: + swa = self._shared_swa # reuse same SWA module + else: + swa = SlidingWindowAttention( + dim=dim, + num_heads=num_heads, + num_kv_heads=config.get("swa_num_kv_heads", 4), + window_size=config.get("swa_window", 512), + ) + if config.get("swa_shared", False): + self._shared_swa = swa + + # Each SWA position gets its own MLP even if SWA weights are shared + block = AttentionBlock(dim, swa, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("swa" if layer_type == "swa" else "swa_shared") + + layer_idx += 1 + + # KV sharing: share k/v projections between adjacent layers + kv_stride = config.get("kv_sharing_stride", 0) + if kv_stride > 0: + self._apply_kv_sharing(kv_stride) + + self.final_norm = RMSNorm(dim) + # Tied embeddings (standard for parameter golf) + self.lm_head = None # use tok_emb.weight + self._init_weights() + + def _make_recurrent_layer(self, config: dict, layer_idx: int) -> nn.Module: + """Create the appropriate recurrent layer based on config.""" + dim = config["model_dim"] + num_heads = config["num_heads"] + + if config.get("use_rwkv7", False): + total_layers = config.get("num_gdn_layers", 11) + return RWKV7Attention( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + layer_idx=layer_idx, + num_hidden_layers=total_layers, + mode="chunk", + ) + elif config.get("use_deltaproduct", False): + return GatedDeltaProduct( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + num_householder=config.get("dp_num_householder", 2), + allow_neg_eigval=config.get("dp_allow_neg_eigval", False), + use_short_conv=config.get("gdn_use_short_conv", True), + expand_v=config.get("gdn_expand_v", 1), + layer_idx=layer_idx, + mode="chunk", + ) + else: + # Default: GatedDeltaNet + return GatedDeltaNet( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + allow_neg_eigval=config.get("gdn_allow_neg_eigval", False), + use_short_conv=config.get("gdn_use_short_conv", True), + expand_v=config.get("gdn_expand_v", 1), + layer_idx=layer_idx, + mode="chunk", + ) + + def _apply_kv_sharing(self, stride: int) -> None: + """Share KV projection modules between adjacent layer groups. + + For GDN layers: shares k_proj, v_proj, k_conv1d, v_conv1d. + For SWA layers: shares c_k, c_v. + The first layer in each group is the anchor; subsequent layers in the + group become followers that reference the anchor's modules. + """ + # Collect indices by block type + gdn_indices = [i for i, t in enumerate(self._block_types) if t == "gdn"] + swa_indices = [i for i, t in enumerate(self._block_types) + if t in ("swa", "swa_shared")] + + # Share GDN KV projections within each stride-group + for group_start in range(0, len(gdn_indices), stride): + anchor_idx = gdn_indices[group_start] + anchor = self.blocks[anchor_idx].recurrent + for j in range(1, stride): + if group_start + j >= len(gdn_indices): + break + follower_idx = gdn_indices[group_start + j] + follower = self.blocks[follower_idx].recurrent + follower.k_proj = anchor.k_proj + follower.v_proj = anchor.v_proj + follower.k_conv1d = anchor.k_conv1d + follower.v_conv1d = anchor.v_conv1d + + # Share SWA KV projections within each stride-group + for group_start in range(0, len(swa_indices), stride): + anchor_idx = swa_indices[group_start] + anchor = self.blocks[anchor_idx].attn + for j in range(1, stride): + if group_start + j >= len(swa_indices): + break + follower_idx = swa_indices[group_start + j] + follower = self.blocks[follower_idx].attn + follower.c_k = anchor.c_k + follower.c_v = anchor.c_v + + def _init_weights(self) -> None: + """Weight initialization. + + Each sub-module handles its own init (MLP zeros proj, SWA zeros proj, + FLA layers do own init). We just do the residual scaling for output + projections on our own CastedLinear layers. + """ + total_layers = len(self.blocks) + for name, p in self.named_parameters(): + # Skip FLA-internal parameters + if ".recurrent." in name: + continue + # Scale down output projections for residual stream + if p.ndim == 2 and "proj" in name and "bigram" not in name: + with torch.no_grad(): + p.mul_(1.0 / math.sqrt(2 * total_layers)) + + def set_xsa(self, enable: bool = True) -> None: + """Enable/disable XSA on all attention blocks.""" + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + block.attn.use_xsa = enable + + def _compute_logits(self, x: Tensor) -> Tensor: + """Compute logits with tied embeddings and softcap.""" + logits = F.linear(x, self.tok_emb.weight) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + """Forward pass returning cross-entropy loss.""" + x = self.tok_emb(input_ids) + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + + # Prepend meta tokens if Hymba-style + if self.meta_tokens is not None: + B = x.shape[0] + meta = self.meta_tokens.expand(B, -1, -1).to(dtype=x.dtype) + x = torch.cat([meta, x], dim=1) + x0 = torch.cat([meta, x0], dim=1) + + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + x = block(x, x0) + else: + x = block(x, x0) + + # Remove meta tokens before computing logits + if self.meta_tokens is not None: + x = x[:, self.n_meta:] + + x = self.final_norm(x) + logits = self._compute_logits(x.reshape(-1, x.size(-1))) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Forward pass returning logits (for evaluation).""" + x = self.tok_emb(input_ids) + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + + if self.meta_tokens is not None: + B = x.shape[0] + meta = self.meta_tokens.expand(B, -1, -1).to(dtype=x.dtype) + x = torch.cat([meta, x], dim=1) + x0 = torch.cat([meta, x0], dim=1) + + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + x = block(x, x0) + else: + x = block(x, x0) + + if self.meta_tokens is not None: + x = x[:, self.n_meta:] + + x = self.final_norm(x) + return self._compute_logits(x) + + def get_diagnostics(self) -> dict: + """Collect per-layer weight statistics for checkpoint diagnostics.""" + diag = {} + for i, (block, btype) in enumerate(zip(self.blocks, self._block_types)): + prefix = f"layer_{i}_{btype}" + for name, param in block.named_parameters(): + if param.ndim >= 2: + w = param.data.float() + diag[f"{prefix}/{name}/std"] = w.std().item() + diag[f"{prefix}/{name}/kurtosis"] = (((w - w.mean()) / (w.std() + 1e-8)) ** 4).mean().item() - 3.0 + return diag + + def count_params(self) -> dict: + """Count parameters by category.""" + cats = {"embedding": 0, "recurrent": 0, "attention": 0, "mlp": 0, "other": 0} + for name, p in self.named_parameters(): + n = p.numel() + if "tok_emb" in name or "bigram" in name: + cats["embedding"] += n + elif any(k in name for k in ["recurrent", "gdn", "mamba", "rwkv", "delta"]): + cats["recurrent"] += n + elif "attn" in name or "c_q" in name or "c_k" in name or "c_v" in name: + cats["attention"] += n + elif "mlp" in name or "fc" in name: + cats["mlp"] += n + else: + cats["other"] += n + cats["total"] = sum(cats.values()) + return cats diff --git a/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/configs.py b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/configs.py new file mode 100644 index 0000000000..5bbdac3bd4 --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/configs.py @@ -0,0 +1,316 @@ +"""Model architecture configurations for GDN Hybrid experiments. + +Each config returns a dict consumed by HybridGDN.__init__. +All models are sized to fit ~16MB at int6+zstd-22. + +Models A-H: baseline architecture sweeps. +Models I-K: KV sharing experiments (kv_sharing_stride=2). +""" +from __future__ import annotations + + +def model_a_pure_gdn() -> dict: + """Model A: Pure GDN (Baseline) — 10 layers Gated DeltaNet.""" + return dict( + arch_name="A_PureGDN", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + ) + + +def model_b_deltaproduct() -> dict: + """Model B: Gated DeltaProduct n_h=2 — rank-2 state transitions.""" + return dict( + arch_name="B_DeltaProduct", + num_gdn_layers=10, # 10 layers to fit param budget + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=480, # slightly narrower to fit 16MB + num_heads=8, + mlp_mult=3.0, + use_deltaproduct=True, + dp_num_householder=2, + dp_allow_neg_eigval=False, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + ) + + +def model_b2_deltaproduct_neg() -> dict: + """Model B2: DeltaProduct + negative eigenvalues.""" + cfg = model_b_deltaproduct() + cfg["arch_name"] = "B2_DeltaProduct_NegEig" + cfg["dp_allow_neg_eigval"] = True + return cfg + + +def model_c_gdn_neg() -> dict: + """Model C: GDN with negative eigenvalues — richer state dynamics. + + (Originally RWKV-7, replaced because RWKV7 requires Triton kernels with + no pure-PyTorch fallback available.) + """ + return dict( + arch_name="C_GDN_NegEig", + num_gdn_layers=11, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=True, # Key difference: negative eigenvalues + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + ) + + +def model_d_gdn_1swa() -> dict: + """Model D: GDN + 1 Shared SWA (Zamba-style).""" + return dict( + arch_name="D_GDN_1SWA", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=1, + swa_shared=True, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + # Layout: [GDN×5] → [SWA] → [GDN×5] → [SWA_shared] + layer_layout="gdn5_swa_gdn5_swa_shared", + ) + + +def model_e_gdn_2swa() -> dict: + """Model E: GDN + 2 Shared SWA (Hymba-inspired) with meta-tokens.""" + return dict( + arch_name="E_GDN_2SWA_Hymba", + num_gdn_layers=9, + num_mamba_layers=0, + num_swa_layers=1, # 1 unique, shared at 2 positions + swa_shared=True, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=4, # Hymba-style prepended meta-tokens + # Layout: [GDN×3] → [SWA] → [GDN×3] → [SWA_shared] → [GDN×3] + layer_layout="gdn3_swa_gdn3_swa_shared_gdn3", + ) + + +def model_f_mamba2() -> dict: + """Model F: Mamba-2 Pure (Mamba-3 proxy with RoPE on B/C).""" + return dict( + arch_name="F_Mamba2", + num_gdn_layers=0, + num_mamba_layers=11, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + mamba_state_size=64, + mamba_expand=2, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="mamba_only", + ) + + +def model_g_hybrid() -> dict: + """Model G: GDN + Mamba-2 + SWA triple hybrid.""" + return dict( + arch_name="G_GDN_Mamba_SWA", + num_gdn_layers=6, + num_mamba_layers=4, + num_swa_layers=1, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + mamba_state_size=64, + mamba_expand=2, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + # Layout: [GDN×3] → [Mamba×2] → [SWA] → [GDN×3] → [Mamba×2] + layer_layout="gdn3_mamba2_swa_gdn3_mamba2", + ) + + +def model_h_pure_swa() -> dict: + """Model H: Pure Sliding Window Attention (standard softmax) — control baseline. + + All 10 layers use causal sliding-window softmax attention (no GDN). + Same MLP, embedding, and normalization as Model A for fair comparison. + """ + return dict( + arch_name="H_PureSWA", + num_gdn_layers=0, + num_mamba_layers=0, + num_swa_layers=10, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="swa_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + ) + + +def model_i_kv_share() -> dict: + """Model I: GDN + KV Share — same as A but with kv_sharing_stride=2.""" + return dict( + arch_name="I_KVShare", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +def model_j_kv_share_deeper() -> dict: + """Model J: GDN + KV Share + Deeper — 12L dim=480, near iso-parameter to A.""" + return dict( + arch_name="J_KVShare_Deeper", + num_gdn_layers=12, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=480, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +def model_k_kv_share_wider() -> dict: + """Model K: GDN + KV Share + Wider — 10L dim=544, iso-parameter to A.""" + return dict( + arch_name="K_KVShare_Wider", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=544, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +ALL_CONFIGS = { + "A": model_a_pure_gdn, + "B": model_b_deltaproduct, + "B2": model_b2_deltaproduct_neg, + "C": model_c_gdn_neg, + "D": model_d_gdn_1swa, + "E": model_e_gdn_2swa, + "F": model_f_mamba2, + "G": model_g_hybrid, + "H": model_h_pure_swa, + "I": model_i_kv_share, + "J": model_j_kv_share_deeper, + "K": model_k_kv_share_wider, +} + + +def get_config(model_id: str) -> dict: + """Get config by model ID (A, B, B2, C, D, E, F, G, H, I, J, K).""" + if model_id not in ALL_CONFIGS: + raise ValueError(f"Unknown model ID '{model_id}'. Choose from {list(ALL_CONFIGS.keys())}") + return ALL_CONFIGS[model_id]() diff --git a/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/requirements.txt b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/requirements.txt new file mode 100644 index 0000000000..3feaed4f64 --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/requirements.txt @@ -0,0 +1,10 @@ +numpy +torch +sentencepiece +zstandard +flash-linear-attention==0.4.2 +fla-core==0.4.2 +triton==3.2.0 +transformers==5.5.4 +tokenizers==0.22.2 +safetensors==0.7.0 diff --git a/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/submission.json b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/submission.json new file mode 100644 index 0000000000..1bc5a08c63 --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/submission.json @@ -0,0 +1,21 @@ +{ + "author": "OlesStankevych", + "github_id": "OlesStankevych", + "name": "FreqGPTQ + GatedDeltaNet + Adaptive Quantization", + "date": "2026-04-18", + "track": "10min_16mb", + "val_bpb": null, + "val_bpb_std": null, + "seeds": [], + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "technique_summary": "GatedDeltaNet (FLA) + FreqGPTQ + PassthroughQuant + Sandwich Quant + Adaptive Embeddings + Int5 GPTQ + LZMA wrapper + Legal Score-First TTT", + "status": "WIP - pending GPU validation", + "compliance": { + "train_under_600s": null, + "artifact_under_16mb": null, + "eval_under_600s": null, + "score_first_ttt": true, + "three_seed_validation": null + } +} diff --git a/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/train_gdn_7k.py b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/train_gdn_7k.py new file mode 100644 index 0000000000..6ec5ad2ea5 --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/train_gdn_7k.py @@ -0,0 +1,1499 @@ +#!/usr/bin/env python3 +"""GDN Hybrid Full Training Script — 7000 steps with all production features. + +Features beyond Phase 1 screening (train_gdn.py): + - EMA weight averaging (decay 0.997) + - SWA (stochastic weight averaging) during late warmdown + - Late QAT (int6 STE in CastedLinear forward during warmdown) + - Mixed int6/int8 quantization with percentile search + - zstd-22 compression for artifact + - Coprime shard ordering via SHARD_ORDER_FILE + - XSA-all sliding window eval on quantized artifact + - Roundtrip validation (load quantized, eval, report exact BPB) + +Environment variables (key additions vs Phase 1): + EMA_DECAY: EMA decay rate (default 0.997) + SWA_ENABLED: 1|0 (default 1) + SWA_EVERY: SWA collection interval (default 50) + LATE_QAT_THRESHOLD: LR scale below which QAT activates (default 0.15) + SHARD_ORDER_FILE: path to file with one shard path per line (coprime ordering) + MUON_MOMENTUM_WARMUP_START: starting momentum for warmup (default 0.85) + MUON_MOMENTUM_WARMUP_STEPS: steps to ramp momentum (default 500) +""" +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +import zstandard +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from architectures import HybridGDN, CastedLinear +from configs import get_config + + +# ─── Hyperparameters ────────────────────────────────────────────────────────── + +class Hyperparameters: + arch_mode = os.environ.get("ARCH_MODE", "A") + data_path = os.environ.get("DATA_PATH", "../data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "../data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + + # Training length + iterations = int(os.environ.get("ITERATIONS", 7000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2100)) # 30% of 7k + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 14100.0)) # 3h55m safety + + # Validation + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + save_every = int(os.environ.get("SAVE_EVERY", 1000)) + + # Optimizer + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + + # Eval + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + xsa_eval = bool(int(os.environ.get("XSA_EVAL", "0"))) # during training + eval_compile_enabled = bool(int(os.environ.get("EVAL_COMPILE_ENABLED", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Checkpoint + ckpt_dir = os.environ.get("CKPT_DIR", "checkpoints") + + # Compile + compile_enabled = bool(int(os.environ.get("COMPILE_ENABLED", "1"))) + + # Resume from checkpoint + resume_ckpt = os.environ.get("RESUME_CKPT", "") + + # EMA / SWA + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Late QAT + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + + # TTT (Test-Time Training) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.005)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + + # Chained job support + auto_save_seconds = float(os.environ.get("AUTO_SAVE_SECONDS", "0")) + total_iterations = int(os.environ.get("TOTAL_ITERATIONS", "0")) # 0 = same as iterations + + +# ─── Data Loading ───────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + header = np.fromfile(file, dtype=np.uint32, count=256) + assert header[0] == 20240520, f"Bad magic: {header[0]}" + assert header[1] in (1, 7), f"Bad version: {header[1]}" + ntok = int(header[2]) + return torch.from_numpy(np.fromfile(file, dtype=np.uint16, offset=256 * 4)[:ntok].astype(np.int64)) + + +class TokenStream: + """Reads shards sequentially, supports coprime ordering via SHARD_ORDER_FILE.""" + def __init__(self, pattern: str): + shard_order_file = os.environ.get("SHARD_ORDER_FILE", "") + if shard_order_file and os.path.exists(shard_order_file): + with open(shard_order_file) as f: + self.files = [Path(line.strip()) for line in f if line.strip()] + else: + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + assert self.files, f"No files matching {pattern}" + self.idx = 0 + self.buf = load_data_shard(self.files[self.idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.idx = (self.idx + 1) % len(self.files) + self.buf = load_data_shard(self.files[self.idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + parts = [] + remaining = n + while remaining > 0: + avail = self.buf.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + take_n = min(avail, remaining) + parts.append(self.buf[self.pos:self.pos + take_n]) + self.pos += take_n + remaining -= take_n + return torch.cat(parts) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.stream = TokenStream(pattern) + self.rank = rank + self.world_size = world_size + self.device = device + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + tokens_per_rank = global_tokens // self.world_size + seqs_per_rank = tokens_per_rank // seq_len + total_seqs = seqs_per_rank * self.world_size + total_needed = total_seqs * seq_len + 1 + all_tokens = self.stream.take(total_needed) + start = self.rank * seqs_per_rank * seq_len + chunk = all_tokens[start:start + seqs_per_rank * seq_len + 1] + x = chunk[:-1].reshape(seqs_per_rank, seq_len) + y = chunk[1:].reshape(seqs_per_rank, seq_len) + return x.to(self.device), y.to(self.device) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = sorted(glob.glob(pattern)) + parts = [load_data_shard(Path(f)) for f in files] + combined = torch.cat(parts) + return combined[:((combined.numel() - 1) // seq_len) * seq_len + 1] + + +def build_sentencepiece_luts(sp, vocab_size, device): + base_bytes = torch.zeros(vocab_size, dtype=torch.float32, device=device) + has_space = torch.zeros(vocab_size, dtype=torch.bool, device=device) + is_boundary = torch.zeros(vocab_size, dtype=torch.bool, device=device) + for i in range(vocab_size): + piece = sp.id_to_piece(i) + raw = piece.encode("utf-8") + base_bytes[i] = len(raw) + if piece.startswith("\u2581"): + has_space[i] = True + base_bytes[i] = len(piece[1:].encode("utf-8")) + 1 + if sp.is_control(i) or sp.is_unknown(i): + is_boundary[i] = True + return base_bytes, has_space, is_boundary + + +def generate_coprime_shard_order(shard_files: list, seed: int = 42) -> list: + """Generate a coprime-stepping shard order for better data mixing. + + Instead of sequential 0,1,2,...,79, uses stride = coprime(N) to visit + all shards in a maximally-spread order. This ensures each training epoch + sees data from diverse shards rather than correlated sequential ones. + """ + n = len(shard_files) + if n <= 1: + return shard_files + + # Find a coprime stride that's roughly n/phi (golden ratio) + target = max(1, int(n / 1.618)) + stride = target + while math.gcd(stride, n) != 1: + stride += 1 + + rng = random.Random(seed) + start = rng.randint(0, n - 1) + order = [] + pos = start + for _ in range(n): + order.append(shard_files[pos]) + pos = (pos + stride) % n + return order + + +# ─── Muon Optimizer ────────────────────────────────────────────────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if transposed: + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay) + super().__init__(params, defaults) + + def step(self, closure=None): + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + for p in group["params"]: + if p.grad is None: + continue + g = p.grad + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g + momentum * buf + else: + g = buf + if g.ndim == 2 and min(g.shape) >= 2: + g = zeropower_via_newtonschulz5(g, steps=group["backend_steps"]) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.data.add_(g, alpha=-lr) + + +# ─── Evaluation ────────────────────────────────────────────────────────────── + +def eval_val_sliding( + model: nn.Module, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + rank: int, + world_size: int, + device: torch.device, + seq_len: int = 1024, + stride: int = 64, + batch_seqs: int = 128, + xsa_eval: bool = False, + compile_enabled: bool = True, +) -> tuple[float, float]: + """Score-first sliding window evaluation.""" + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + base_model = model.module if hasattr(model, 'module') else model + if xsa_eval and hasattr(base_model, 'set_xsa'): + base_model.set_xsa(True) + + forward_fn = base_model.forward_logits + compiled_logits = forward_fn + if compile_enabled: + try: + compiled_logits = torch.compile(forward_fn, dynamic=False) + except Exception: + compiled_logits = forward_fn + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + + if xsa_eval and hasattr(base_model, 'set_xsa'): + base_model.set_xsa(False) + + model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ─── TTT (Test-Time Training) ──────────────────────────────────────────────── + +def eval_val_ttt_gdn( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT for GDN: score each chunk with sliding windows, + then SGD-train on already-scored tokens. Every token scored BEFORE any update.""" + seq_len = args.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_gdn:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks (early layers learn general features, keep stable) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_gdn:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + # Skip training on last chunk (no future windows benefit from it) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore all parameters to requires_grad=True + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_gdn:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ─── Quantization ──────────────────────────────────────────────────────────── + +# Control tensor patterns — kept at full precision during quantization +CONTROL_PATTERNS = ( + "resid_mix", "q_gain", "smear", "skip_weight", "attn_scale", "mlp_scale", +) + + +def generate_autoregressive_calib(model, device, num_seqs=64, seq_len=2048, + vocab_size=1024, temperature=0.8, batch_size=8, seed=42): + """Generate sequences autoregressively from the model for GPTQ calibration. + No external data accessed — fully self-contained.""" + model.eval() + rng = torch.Generator(device=device) + rng.manual_seed(seed) + all_tokens = [] + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for batch_start in range(0, num_seqs, batch_size): + bs = min(batch_size, num_seqs - batch_start) + tokens = torch.randint(0, vocab_size, (bs, 1), device=device, generator=rng) + for pos in range(seq_len - 1): + logits = model.forward_logits(tokens) + next_logit = logits[:, -1, :] + probs = torch.softmax(next_logit / temperature, dim=-1) + next_tok = torch.multinomial(probs, 1, generator=rng) + tokens = torch.cat([tokens, next_tok], dim=1) + for i in range(bs): + all_tokens.append(tokens[i:i+1]) + return all_tokens + + +def _compute_freq_weights(token_seqs, top_k=100, boost=2.0): + """Compute per-token frequency weights for FreqGPTQ. + Top-k most frequent tokens get sqrt(boost) multiplier on activations, + which translates to boost× weight in H = X^T X.""" + all_tokens = torch.cat([seq.reshape(-1) for seq in token_seqs]) + counts = torch.bincount(all_tokens, minlength=1) + top_ids = torch.topk(counts, min(top_k, len(counts))).indices + weight_map = torch.ones(len(counts), dtype=torch.float32) + weight_map[top_ids] = math.sqrt(boost) + return weight_map + + +def collect_hessians_from_tokens(hessian_model, token_seqs, device, freq_boost=2.0): + """Collect H = X^T X with FreqGPTQ: high-frequency tokens get boosted Hessian weight. + sqrt(boost) on activations → boost× on H = X^T X. Zero artifact cost.""" + freq_weights = _compute_freq_weights(token_seqs, top_k=100, boost=freq_boost) + freq_weights = freq_weights.to(device) + # Token IDs from each sequence for per-position weighting + _current_token_ids = [None] + + hessians = {} + hooks = [] + for name, module in hessian_model.named_modules(): + if isinstance(module, CastedLinear): + param_name = name + ".weight" + cols = module.weight.shape[1] + hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu') + def make_hook(pname): + def hook_fn(module, input, output): + x = input[0].detach().float() + if x.ndim == 3: + B, T, D = x.shape + x = x.reshape(-1, D) + # Apply per-token frequency weighting + tok_ids = _current_token_ids[0] + if tok_ids is not None: + w = freq_weights[tok_ids.reshape(-1)].unsqueeze(1) + x = x * w + hessians[pname] += (x.T @ x).cpu() + return hook_fn + h = module.register_forward_hook(make_hook(param_name)) + hooks.append(h) + hessian_model.eval() + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for seq in token_seqs: + x = seq[:, :-1].to(device) + y = seq[:, 1:].to(device) + _current_token_ids[0] = x + hessian_model(x, y) + for h in hooks: + h.remove() + num_batches = len(token_seqs) + for name in hessians: + hessians[name] /= num_batches + return hessians + + +def quantize_intN_gptq(weight, hessian=None, n_bits=6, block_size=128): + """Full GPTQ: Hessian-aware N-bit quantization with Cholesky error compensation. + Supports int5 (clip_range=15), int6 (clip_range=31), int8 (clip_range=127). + If hessian is None, falls back to percentile search.""" + clip_range = (1 << (n_bits - 1)) - 1 + return _quantize_gptq_inner(weight, hessian, clip_range, block_size) + + +def quantize_int6_gptq(weight, hessian=None, clip_range=31, block_size=128): + """Full GPTQ: Hessian-aware int6 quantization with Cholesky error compensation. + If hessian is None, falls back to percentile search.""" + return _quantize_gptq_inner(weight, hessian, clip_range, block_size) + + +def _quantize_gptq_inner(weight, hessian=None, clip_range=31, block_size=128): + """Inner GPTQ implementation for any clip range.""" + t32 = weight.float() + if t32.ndim != 2 or hessian is None: + return quantize_int6_per_row(t32, clip_range=clip_range) + rows, cols = t32.shape + H = hessian.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * torch.mean(torch.diag(H)) + H[torch.arange(cols), torch.arange(cols)] += damp + perm = torch.argsort(torch.diag(H), descending=True) + inv_perm = torch.argsort(perm) + W = t32[:, perm].clone() + W[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.linalg.cholesky(H) + Hinv = torch.cholesky_inverse(Hinv) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + best_q = None; best_scale = None; best_err = float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + sf = s.float() + Q = torch.zeros_like(W, dtype=torch.int8) + W_work = W.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + count = i2 - i1 + W1 = W_work[:, i1:i2].clone() + Q1 = torch.zeros(rows, count, dtype=torch.int8) + Err1 = torch.zeros(rows, count) + Hinv1 = Hinv[i1:i2, i1:i2] + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = torch.clamp(torch.round(w / sf), -clip_range, clip_range).to(torch.int8) + Q1[:, i] = q + err = (w - q.float() * sf) / d + W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0) + Err1[:, i] = err + Q[:, i1:i2] = Q1 + if i2 < cols: + W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:] + recon = Q.float() * sf[:, None] + mse = (W - recon).pow(2).mean().item() + if mse < best_err: + best_q, best_scale, best_err = Q, s, mse + best_q = best_q[:, inv_perm] + return best_q, best_scale + + +def quantize_intN_per_row(t: Tensor, n_bits: int = 6) -> tuple[Tensor, Tensor]: + """N-bit quantization with percentile search for optimal clipping.""" + return quantize_int6_per_row(t, clip_range=(1 << (n_bits - 1)) - 1) + + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + """Int6 quantization with percentile search for optimal clipping.""" + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + # 1D: simple per-tensor quantization + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + + +def quantize_int8_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + """Int8 quantization with percentile clipping.""" + t32 = t.float() + clip_q = 0.9999984 + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), clip_q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0).to(torch.float16) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), -127, 127).to(torch.int8) + return q, scale + clip_abs = float(torch.quantile(t32.abs().flatten(), clip_q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale.float()), -127, 127).to(torch.int8) + return q, scale + + +def _get_num_blocks(state_dict: dict[str, Tensor]) -> int: + """Infer the number of blocks from state dict keys.""" + block_ids = set() + for name in state_dict: + if "blocks." in name: + parts = name.split(".") + idx = parts.index("blocks") + 1 + if idx < len(parts) and parts[idx].isdigit(): + block_ids.add(int(parts[idx])) + return max(block_ids) + 1 if block_ids else 0 + + +def mixed_quantize(state_dict: dict[str, Tensor], hessians: dict[str, Tensor] | None = None, + freq_token_ids: Tensor | None = None, + weight_bits: int = 6) -> tuple[dict[str, Tensor], dict[str, object]]: + """Mixed intN/int8 quantization with PassthroughQuant, sandwich, and adaptive embeddings. + + Improvements over baseline: + - Configurable weight_bits (5, 6, or 8) for large 2D matrices + - PassthroughQuant: control tensors → int8 (not fp16), saves ~40KB + - Sandwich: final transformer layer → int8 (protects LM head signal) + - Adaptive embeddings: top-100 freq token rows → int8, rest → intN + - FreqGPTQ: Hessian collection is frequency-weighted (handled upstream) + """ + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + num_blocks = _get_num_blocks(state_dict) + last_block = num_blocks - 1 if num_blocks > 0 else -1 + clip_range = (1 << (weight_bits - 1)) - 1 + + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + + # Control tensors: PassthroughQuant → int8 instead of fp16 passthrough + if any(p in name for p in CONTROL_PATTERNS): + if t.is_floating_point(): + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + else: + result[name] = t + meta[name] = "passthrough" + continue + + # Non-float: passthrough + if not t.is_floating_point(): + result[name] = t + meta[name] = "passthrough" + continue + + # Embedding weights: adaptive precision (int8 for freq tokens, int6 for rest) + if "tok_emb" in name and t.ndim == 2 and freq_token_ids is not None: + # Split: frequent rows get int8, rest get intN + freq_mask = torch.zeros(t.shape[0], dtype=torch.bool) + freq_mask[freq_token_ids] = True + # Int8 for frequent tokens + freq_rows = t[freq_mask] + q8, s8 = quantize_int8_per_row(freq_rows) + # IntN for remaining tokens (uses weight_bits) + rest_rows = t[~freq_mask] + q6, s6 = quantize_intN_per_row(rest_rows, n_bits=weight_bits) + result[name + ".q8"] = q8 + result[name + ".s8"] = s8 + result[name + ".q6"] = q6 + result[name + ".s6"] = s6 + result[name + ".freq_mask"] = freq_mask + meta[name] = {"type": "adaptive_embed"} + continue + + # Small tensors: int8 (PassthroughQuant for 2D too) + if t.numel() <= 65536: + if t.ndim >= 1: + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + else: + result[name] = t.to(torch.float16) + meta[name] = "passthrough" + continue + + # Large 2D weights + if t.ndim == 2 and t.numel() > 65536: + # Sandwich: final block gets int8 for better LM head signal + is_last_block = f"blocks.{last_block}." in name + if is_last_block: + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + else: + H = hessians.get(name) if hessians else None + q, s = quantize_intN_gptq(t, hessian=H, n_bits=weight_bits) if H is not None else quantize_intN_per_row(t, n_bits=weight_bits) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{weight_bits}"} + else: + # Other float tensors: int8 + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + + +def dequantize_mixed(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Dequantize mixed int6/int8/adaptive_embed back to float.""" + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info == "passthrough": + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + # Adaptive embedding: reconstruct from split int8/int6 + if isinstance(info, dict) and info.get("type") == "adaptive_embed": + freq_mask = result[name + ".freq_mask"] + q8, s8 = result[name + ".q8"], result[name + ".s8"] + q6, s6 = result[name + ".q6"], result[name + ".s6"] + full = torch.zeros_like(orig) + full[freq_mask] = (q8.float() * s8.float().view(q8.shape[0], *([1] * (q8.ndim - 1)))).to(orig_dtype) + full[~freq_mask] = (q6.float() * s6.float().view(q6.shape[0], *([1] * (q6.ndim - 1)))).to(orig_dtype) + out[name] = full + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ─── Checkpoint Saving ─────────────────────────────────────────────────────── + +def save_checkpoint(model, step, val_bpb, ckpt_dir, arch_name, seed): + base = model.module if hasattr(model, 'module') else model + ckpt = { + "step": step, "val_bpb": val_bpb, + "arch_name": arch_name, "seed": seed, + "model_state_dict": base.state_dict(), + } + os.makedirs(ckpt_dir, exist_ok=True) + path = os.path.join(ckpt_dir, f"{arch_name}_step{step}_seed{seed}.pt") + torch.save(ckpt, path) + return path + + +def save_full_checkpoint(model, step, val_bpb, ckpt_dir, arch_name, seed, + muon_opt, adam_opt, ema_state, swa_state, swa_count, + qat_enabled, rng_states=None, stream_state=None): + """Save complete training state for chained job resume.""" + base = model.module if hasattr(model, 'module') else model + ckpt = { + "step": step, "val_bpb": val_bpb, + "arch_name": arch_name, "seed": seed, + "model_state_dict": {k: v.cpu() for k, v in base.state_dict().items()}, + "muon_opt_state": muon_opt.state_dict(), + "adam_opt_state": adam_opt.state_dict(), + "ema_state": {k: v.cpu() for k, v in ema_state.items()}, + "swa_state": {k: v.cpu() for k, v in swa_state.items()} if swa_state is not None else None, + "swa_count": swa_count, + "qat_enabled": qat_enabled, + } + if rng_states is not None: + ckpt["rng_states"] = rng_states + if stream_state is not None: + ckpt["stream_state"] = stream_state + os.makedirs(ckpt_dir, exist_ok=True) + path = os.path.join(ckpt_dir, f"full_ckpt_step{step}_seed{seed}.pt") + torch.save(ckpt, path) + return path + + +def _find_latest_full_ckpt(ckpt_dir): + """Find the latest full_ckpt_step*.pt file in ckpt_dir by step number.""" + pattern = os.path.join(ckpt_dir, "full_ckpt_step*_seed*.pt") + files = glob.glob(pattern) + if not files: + return None + import re + step_re = re.compile(r"full_ckpt_step(\d+)_seed") + best_step, best_path = -1, None + for f in files: + m = step_re.search(os.path.basename(f)) + if m: + s = int(m.group(1)) + if s > best_step: + best_step, best_path = s, f + return best_path + + +# ─── Main Training Loop ───────────────────────────────────────────────────── + +def main(): + global zeropower_via_newtonschulz5 + args = Hyperparameters() + config = get_config(args.arch_mode) + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + grad_accum_steps = max(1, 8 // world_size) + master_process = rank == 0 + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if args.compile_enabled: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # Logging + os.makedirs("logs", exist_ok=True) + os.makedirs(args.ckpt_dir, exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master_process else None + + def log0(msg: str, console: bool = True): + if not master_process: + return + if console: + print(msg, flush=True) + if logfile: + with open(logfile, "a") as f: + print(msg, file=f) + + # Seeds + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + log0(f"=== GDN Hybrid 7k Full Training ===") + log0(f"Arch: {config['arch_name']} (ARCH_MODE={args.arch_mode})") + log0(f"Seed: {args.seed}, Steps: {args.iterations}, Warmdown: {args.warmdown_iters}") + log0(f"World size: {world_size}, Grad accum: {grad_accum_steps}") + log0(f"EMA decay: {args.ema_decay}, SWA: {args.swa_enabled} (every {args.swa_every})") + weight_bits_cfg = int(os.environ.get("WEIGHT_BITS", "6")) + CastedLinear._qat_clip_range = (1 << (weight_bits_cfg - 1)) - 1 + log0(f"Late QAT threshold: {args.late_qat_threshold} (int{weight_bits_cfg}, clip_range={CastedLinear._qat_clip_range})") + log0(f"Eval compile enabled: {args.eval_compile_enabled}") + + # Tokenizer + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + assert int(sp.vocab_size()) == args.vocab_size + + # Validation data + val_tokens = load_validation_tokens(args.val_files, args.eval_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"Validation tokens: {val_tokens.numel()-1:,}") + + # Build model + _t0 = time.time() + model = HybridGDN(config, args.vocab_size) + model = model.to(device).bfloat16() + log0(f"Model built in {time.time()-_t0:.1f}s") + + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, p in model.named_parameters(): + if p.ndim <= 1: + p.data = p.data.float() + + param_counts = model.count_params() + log0(f"Parameters: {param_counts}") + log0(f"Total params: {param_counts['total']:,}") + + # Resume from checkpoint if specified + start_step = 0 + resume_state = None # holds full checkpoint data for deferred restore + resume_ckpt_path = args.resume_ckpt + if resume_ckpt_path == "auto": + resume_ckpt_path = _find_latest_full_ckpt(args.ckpt_dir) or "" + if resume_ckpt_path: + log0(f"Auto-detected resume checkpoint: {resume_ckpt_path}") + else: + log0("Auto-resume: no full checkpoint found, starting fresh") + if resume_ckpt_path and os.path.exists(resume_ckpt_path): + log0(f"Resuming from checkpoint: {resume_ckpt_path}") + ckpt = torch.load(resume_ckpt_path, map_location="cpu", weights_only=False) + base_sd = ckpt["model_state_dict"] + model.load_state_dict({k: v.to(device) for k, v in base_sd.items()}, strict=True) + start_step = ckpt.get("step", 0) + log0(f"Resumed model at step {start_step}, val_bpb={ckpt.get('val_bpb', 'N/A')}") + # Keep full checkpoint for deferred optimizer/EMA/SWA restore + if "muon_opt_state" in ckpt: + resume_state = ckpt + log0(" Full checkpoint detected — will restore optimizers, EMA, SWA, RNG") + else: + log0(" Lightweight checkpoint — model only") + del ckpt + + # DDP + base_model = model # keep reference before wrapping + if distributed: + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + # Optimizer setup + matrix_params = [] + scalar_params = [] + embed_params = [] + for name, p in base_model.named_parameters(): + if not p.requires_grad: + continue + if "tok_emb" in name: + embed_params.append(p) + elif p.ndim == 2 and min(p.shape) >= 2: + matrix_params.append(p) + else: + scalar_params.append(p) + + log0(f"Matrix params: {sum(p.numel() for p in matrix_params):,}") + log0(f"Scalar params: {sum(p.numel() for p in scalar_params):,}") + log0(f"Embed params: {sum(p.numel() for p in embed_params):,}") + + muon_opt = Muon( + matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + adam_opt = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr}, + {"params": embed_params, "lr": args.tied_embed_lr}], + betas=(args.beta1, args.beta2), + weight_decay=args.adam_wd, + fused=True, + ) + + # Deferred restore: optimizer states (must happen after optimizer creation) + if resume_state is not None: + muon_opt.load_state_dict(resume_state["muon_opt_state"]) + adam_opt.load_state_dict(resume_state["adam_opt_state"]) + log0(" Restored optimizer states (Muon + Adam)") + + # Data loader — coprime shard ordering + shard_order_file = os.environ.get("SHARD_ORDER_FILE", "") + if not shard_order_file: + # Generate coprime ordering on-the-fly + shard_files = sorted(glob.glob(args.train_files)) + if shard_files: + ordered = generate_coprime_shard_order(shard_files, seed=args.seed) + shard_order_path = f"/tmp/shard_order_{args.run_id}.txt" + with open(shard_order_path, "w") as f: + for sf in ordered: + f.write(str(sf) + "\n") + os.environ["SHARD_ORDER_FILE"] = shard_order_path + log0(f"Generated coprime shard order: stride across {len(shard_files)} shards") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # LR schedule with warmdown (cosine) + def lr_schedule(step: int) -> float: + warmdown_start = args.iterations - args.warmdown_iters + if step < args.warmup_steps: + return step / max(1, args.warmup_steps) + elif step >= warmdown_start: + progress = (step - warmdown_start) / args.warmdown_iters + return 0.5 * (1.0 + math.cos(math.pi * progress)) + return 1.0 + + # ─── EMA + SWA state ───────────────────────────────────────────────── + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Deferred restore: EMA, SWA, QAT, RNG, stream state + if resume_state is not None: + # EMA + saved_ema = resume_state.get("ema_state") + if saved_ema is not None: + ema_state = {k: v.to(device).float() for k, v in saved_ema.items()} + log0(" Restored EMA state") + # SWA + saved_swa = resume_state.get("swa_state") + if saved_swa is not None: + swa_state = {k: v.cpu() for k, v in saved_swa.items()} + swa_count = resume_state.get("swa_count", 0) + log0(f" Restored SWA state (count={swa_count})") + else: + swa_count = resume_state.get("swa_count", 0) + # QAT + if resume_state.get("qat_enabled", False): + CastedLinear._qat_enabled = True + log0(" Restored QAT enabled state") + # RNG states + saved_rng = resume_state.get("rng_states") + if saved_rng is not None: + torch.set_rng_state(saved_rng["torch_cpu"]) + torch.cuda.set_rng_state(saved_rng["torch_cuda"]) + np.random.set_state(saved_rng["numpy"]) + random.setstate(saved_rng["python"]) + log0(" Restored RNG states") + # Stream state (data loader fast-forward) + saved_stream = resume_state.get("stream_state") + if saved_stream is not None: + s_idx, s_pos = saved_stream + stream = train_loader.stream + # Advance to the saved shard + while stream.idx != s_idx: + stream._advance_file() + stream.pos = s_pos + log0(f" Restored stream state (shard={s_idx}, pos={s_pos})") + else: + # No stream state saved — fast-forward by consuming tokens + if start_step > 0: + log0(f" Fast-forwarding data loader by {start_step} steps...") + for _ in range(start_step): + train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + log0(f" Data loader advanced to step {start_step}") + del resume_state + log0(" Full checkpoint restore complete") + + # ─── Training Loop ─────────────────────────────────────────────────── + # Clear stale chain marker from previous segment (if any) + stale_marker = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + if os.path.exists(stale_marker): + os.remove(stale_marker) + + log0(f"\n{'='*80}") + log0(f"Starting training: {args.iterations} steps (from step {start_step})") + log0(f"{'='*80}\n") + + t0 = time.time() + running_loss = 0.0 + loss_count = 0 + stop_after_step = None + step = start_step # ensure step is defined even if loop doesn't execute + + for step in range(start_step + 1, args.iterations + 1): + # Check early stop + if stop_after_step is not None and step > stop_after_step: + log0(f"Stopping early at step {step} (wallclock limit)") + break + + lr_mul = lr_schedule(step) + + # Muon momentum warmup + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + current_muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in muon_opt.param_groups: + group["lr"] = args.matrix_lr * lr_mul + group["momentum"] = current_muon_momentum + for i, pg in enumerate(adam_opt.param_groups): + if i == 0: + pg["lr"] = args.scalar_lr * lr_mul + else: + pg["lr"] = args.tied_embed_lr * lr_mul + + # Late QAT: activate int6 STE only during warmdown (not warmup!) + warmdown_start = args.iterations - args.warmdown_iters + if (args.late_qat_threshold > 0 and step >= warmdown_start + and lr_mul < args.late_qat_threshold and not CastedLinear._qat_enabled): + CastedLinear._qat_enabled = True + log0(f"Late QAT enabled at step {step} (lr_mul={lr_mul:.4f})") + + # Gradient accumulation + model.train() + total_loss = 0.0 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + micro_batch = x.shape[0] // grad_accum_steps + for micro_step in range(grad_accum_steps): + x_micro = x[micro_step * micro_batch:(micro_step + 1) * micro_batch] + y_micro = y[micro_step * micro_batch:(micro_step + 1) * micro_batch] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x_micro, y_micro) + loss = loss / grad_accum_steps + loss.backward() + total_loss += loss.item() + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm) + + muon_opt.step() + adam_opt.step() + muon_opt.zero_grad(set_to_none=True) + adam_opt.zero_grad(set_to_none=True) + + # EMA update (every step) + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(args.ema_decay).add_(t.detach().float(), alpha=1.0 - args.ema_decay) + + # SWA: collect checkpoints during late warmdown + if args.swa_enabled and lr_mul < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"SWA started at step {step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + running_loss += total_loss + loss_count += 1 + + # Logging + if step % args.train_log_every == 0 or step <= 10: + avg_loss = running_loss / max(loss_count, 1) + elapsed = time.time() - t0 + steps_per_sec = step / elapsed + log0(f"step {step:5d}/{args.iterations} | loss {avg_loss:.4f} | lr_mul {lr_mul:.4f} | " + f"mom {current_muon_momentum:.3f} | {steps_per_sec:.2f} steps/s | {elapsed:.0f}s") + running_loss = 0.0 + loss_count = 0 + + # Validation + checkpoint + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or step == args.iterations: + val_loss, val_bpb = eval_val_sliding( + model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=args.xsa_eval, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"step {step:5d} | val_loss {val_loss:.4f} | val_bpb {val_bpb:.4f}") + + if master_process and args.save_every > 0 and (step % args.save_every == 0 or step == args.iterations): + ckpt_path = save_checkpoint( + model, step, val_bpb, args.ckpt_dir, config["arch_name"], args.seed, + ) + log0(f" Saved: {ckpt_path}") + + # Wallclock limit + if args.max_wallclock_seconds > 0: + elapsed = time.time() - t0 + if elapsed > args.max_wallclock_seconds and stop_after_step is None: + stop_after_step = step + log0(f"Wallclock limit reached ({elapsed:.0f}s), will stop after this step") + + # Auto-save for chained job support + if args.auto_save_seconds > 0: + elapsed = time.time() - t0 + if elapsed > args.auto_save_seconds: + log0(f"Auto-save triggered at step {step} ({elapsed:.0f}s elapsed)") + if master_process: + rng_states = { + "torch_cpu": torch.get_rng_state(), + "torch_cuda": torch.cuda.get_rng_state(), + "numpy": np.random.get_state(), + "python": random.getstate(), + } + stream = train_loader.stream + stream_state = (stream.idx, stream.pos) + ckpt_path = save_full_checkpoint( + model, step, 0.0, args.ckpt_dir, config["arch_name"], args.seed, + muon_opt, adam_opt, ema_state, swa_state, swa_count, + CastedLinear._qat_enabled, + rng_states=rng_states, stream_state=stream_state, + ) + # Write chain resume marker + marker_path = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + with open(marker_path, "w") as f: + f.write(ckpt_path + "\n") + log0(f" Full checkpoint saved: {ckpt_path}") + log0(f" Chain resume marker: {marker_path}") + break # exit training loop cleanly + + # ─── Check if we exited due to auto-save vs normal completion ──────── + chain_marker = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + if os.path.exists(chain_marker): + log0("\nExiting for chained job resume (skipping post-training)") + if distributed: + dist.destroy_process_group() + return + + # Check for total_iterations completion + effective_total = args.total_iterations if args.total_iterations > 0 else args.iterations + if master_process and step >= effective_total: + complete_marker = os.path.join(args.ckpt_dir, f"TRAINING_COMPLETE_seed{args.seed}") + with open(complete_marker, "w") as f: + f.write(f"step={step}\n") + + # ─── Post-Training: Apply EMA ──────────────────────────────────────── + elapsed_total = time.time() - t0 + log0(f"\nTraining complete in {elapsed_total:.0f}s") + log0(f"Peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + log0("\n=== Applying EMA weights ===") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + # Eval EMA weights + val_loss_ema, val_bpb_ema = eval_val_sliding( + model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=False, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"EMA BPB (no XSA): {val_bpb_ema:.6f}") + + # Save raw EMA model + if master_process: + torch.save(base_model.state_dict(), os.path.join(args.ckpt_dir, f"final_model_{config['arch_name']}_seed{args.seed}.pt")) + log0("Saved raw EMA model") + + # ─── GPTQ Calibration (FreqGPTQ: frequency-weighted Hessian) ──────── + gptq_enabled = bool(int(os.environ.get("GPTQ_ENABLED", "0"))) + hessians = None + if gptq_enabled: + log0("\n=== FreqGPTQ: generating calibration data + frequency-weighted Hessians ===") + calib_seqs = generate_autoregressive_calib( + base_model, device, num_seqs=64, seq_len=args.train_seq_len, + vocab_size=args.vocab_size, temperature=0.8, batch_size=8, seed=args.seed, + ) + log0(f"FreqGPTQ: generated {len(calib_seqs)} sequences, collecting weighted hessians...") + freq_boost = float(os.environ.get("FREQ_GPTQ_BOOST", "2.0")) + hessians = collect_hessians_from_tokens(base_model, calib_seqs, device, freq_boost=freq_boost) + log0(f"FreqGPTQ: collected hessians for {len(hessians)} layers (boost={freq_boost})") + + # ─── Compute frequent token IDs for adaptive embedding quantization ── + freq_token_ids = None + adaptive_embed = bool(int(os.environ.get("ADAPTIVE_EMBED", "1"))) + if adaptive_embed: + num_freq_tokens = int(os.environ.get("NUM_FREQ_TOKENS", "100")) + # Collect token frequency from training data + train_stream = TokenStream(args.train_files) + sample_tokens = train_stream.take(min(1_000_000, train_stream.buf.numel())) + token_counts = torch.bincount(sample_tokens, minlength=args.vocab_size) + freq_token_ids = torch.topk(token_counts, min(num_freq_tokens, args.vocab_size)).indices + log0(f"Adaptive embed: {len(freq_token_ids)} frequent tokens get int8 (rest int6)") + + # ─── Quantization + Artifact Creation ──────────────────────────────── + weight_bits = int(os.environ.get("WEIGHT_BITS", "6")) + log0(f"\n=== Quantizing: int{weight_bits}+FreqGPTQ + int8 sandwich + adaptive embed + zstd-22 ===") + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize(sd_cpu, hessians=hessians, freq_token_ids=freq_token_ids, weight_bits=weight_bits) + + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + + artifact_path = os.path.join(args.ckpt_dir, f"final_model_{config['arch_name']}_seed{args.seed}.int6.ptz") + if master_process: + with open(artifact_path, "wb") as f: + f.write(quant_blob) + artifact_bytes = len(quant_blob) + log0(f"Artifact: {artifact_bytes:,} bytes ({artifact_bytes / 1024 / 1024:.2f} MB)") + if artifact_bytes > 16 * 1024 * 1024: + log0(f"WARNING: Artifact exceeds 16MB budget by {(artifact_bytes - 16*1024*1024) / 1024:.1f} KB") + + # ─── Roundtrip Validation ──────────────────────────────────────────── + log0("\n=== Roundtrip Validation (quantized model) ===") + if distributed: + dist.barrier() + + with open(artifact_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed(quant_state["w"], quant_state["m"], sd_cpu) + + # Build fresh eval model (no DDP wrapping needed) + eval_model = HybridGDN(config, args.vocab_size).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + for name, p in eval_model.named_parameters(): + if p.ndim <= 1: + p.data = p.data.float() + eval_model.load_state_dict(deq_state, strict=True) + + # Eval quantized model without XSA + val_loss_q, val_bpb_q = eval_val_sliding( + eval_model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=False, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"Quantized BPB (no XSA): {val_bpb_q:.6f}") + log0(f"Quantization degradation: {val_bpb_q - val_bpb_ema:+.6f}") + + # Eval quantized model WITH XSA (if model has SWA layers) + block_types = eval_model._block_types + if any(bt in ("swa", "swa_shared") for bt in block_types): + val_loss_qx, val_bpb_qx = eval_val_sliding( + eval_model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=True, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"Quantized BPB (XSA-all): {val_bpb_qx:.6f}") + + # ─── Final Summary ─────────────────────────────────────────────────── + log0(f"\n{'='*80}") + log0(f"FINAL RESULTS — {config['arch_name']} seed={args.seed}") + log0(f" Training: {args.iterations} steps, {elapsed_total:.0f}s") + log0(f" EMA BPB (fp32): {val_bpb_ema:.6f}") + log0(f" Quantized BPB: {val_bpb_q:.6f}") + if any(bt in ("swa", "swa_shared") for bt in block_types): + log0(f" Quantized BPB+XSA: {val_bpb_qx:.6f}") + if master_process: + log0(f" Artifact size: {artifact_bytes:,} bytes") + log0(f"{'='*80}") + log0(f"final_int6_roundtrip_exact val_loss:{val_loss_q:.8f} val_bpb:{val_bpb_q:.8f}") + + # ─── Legal Score-First TTT ──────────────────────────────────────────── + if args.ttt_enabled: + log0("\n=== Legal Score-First TTT (GDN) ===") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_ttt_gdn( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.ttt_batch_seqs, log0=log0, + ) + torch.cuda.synchronize() + ttt_elapsed = time.perf_counter() - t_ttt + ttt_delta = ttt_bpb - val_bpb_q + log0(f"TTT BPB: {ttt_bpb:.6f} (delta: {ttt_delta:+.6f})") + log0(f"TTT eval time: {ttt_elapsed:.1f}s") + log0(f"final_int6_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/train_gpt.py b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/train_gpt.py new file mode 100644 index 0000000000..c92492b069 --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_FreqGPTQ_GatedDeltaNet_AdaptiveQuant/train_gpt.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +"""FLA / GatedDeltaNet entrypoint wrapper. + +The actual training logic lives in `train_gdn_7k.py`. `evaluate.py` expects +`torchrun train_gpt.py`, so this wrapper preserves the standard repo entrypoint +while keeping the scored path in the records folder self-contained. +""" + +import os +import sys +import traceback +from pathlib import Path + +# These defaults keep the wrapper aligned with the intended SP8192 scored path. +VOCAB_SIZE = int(os.environ.get("VOCAB_SIZE", 8192)) +DATA_PATH = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp8192") +TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") +ARCH_MODE = os.environ.get("ARCH_MODE", "K") +os.environ.setdefault("VOCAB_SIZE", str(VOCAB_SIZE)) +os.environ.setdefault("DATA_PATH", DATA_PATH) +os.environ.setdefault("TOKENIZER_PATH", TOKENIZER_PATH) +os.environ.setdefault("ARCH_MODE", ARCH_MODE) +os.environ.setdefault("MAX_WALLCLOCK_SECONDS", "600") +os.environ.setdefault("VAL_LOSS_EVERY", "0") +os.environ.setdefault("EVAL_COMPILE_ENABLED", "0") +if ARCH_MODE in ("D", "G", "M"): + os.environ.setdefault("XSA_EVAL", "1") + + +_VENDOR_DIR = Path(__file__).resolve().parent / ".fla_vendor" +_VENDOR_PKGS = [ + "triton==3.2.0", + "flash-linear-attention==0.4.2", + "fla-core==0.4.2", + "transformers==5.5.4", + "tokenizers==0.22.2", + "safetensors==0.7.0", +] +if ARCH_MODE in ("F", "G"): + _VENDOR_PKGS.extend( + [ + "mamba-ssm==2.3.1", + "causal-conv1d==1.6.1", + ] + ) + + +def _ensure_vendor_on_path() -> None: + p = str(_VENDOR_DIR) + if p not in sys.path: + sys.path.insert(0, p) + + +def _ensure_fla_vendor_available() -> None: + _ensure_vendor_on_path() + try: + if ARCH_MODE in ("F", "G"): + from fla.layers.mamba2 import Mamba2 # noqa: F401 + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined # noqa: F401 + from causal_conv1d import causal_conv1d_fn # noqa: F401 + else: + from fla.layers.gated_deltanet import GatedDeltaNet # noqa: F401 + print("wrapper: local vendored FLA imports already work", flush=True) + return + except Exception: + vendor_pkgs = ", ".join(_VENDOR_PKGS) + raise RuntimeError( + "wrapper: required FLA deps are missing from the local environment. " + f"Expected vendored packages under {_VENDOR_DIR}. " + f"Install them before evaluation (e.g. via launcher/requirements), packages: {vendor_pkgs}" + ) + + +def main(): + _ensure_fla_vendor_available() + print("wrapper: importing train_gdn_7k", flush=True) + try: + from train_gdn_7k import main as train_main + except Exception: + traceback.print_exc() + raise + print("wrapper: import ok, entering train_main", flush=True) + train_main() + + +if __name__ == "__main__": + main()