diff --git a/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/README.md b/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/README.md new file mode 100644 index 0000000000..3492d18033 --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/README.md @@ -0,0 +1,95 @@ +# DeepQuant — 11L INT6 + 8-epoch Cosine LoRA TTT + +**val_bpb: 0.5850** (seed=42, eval 582s, 15.46MB) + +## Approach + +We explored how far per-document test-time training can push a small 16MB language model. The core hypothesis: a well-trained base model combined with aggressive per-document LoRA adaptation at eval time can dramatically reduce bits-per-byte by specializing the model to each document's distribution. + +## Architecture + +Standard 11-layer transformer backbone: +- dim=512, 8 attention heads, 4 KV heads (GQA), MLP expansion 3x (1536) +- BigramHash(2048) + SmearGate for parameter-efficient bigram context +- U-Net skip connections between encoder/decoder layer pairs +- Depth-scaled residuals: 1/sqrt(layer+1) for stable deep training +- RoPE positional encoding (base=50000) +- Logit softcap=30.0 + +## Training (600s, 8xH100 SXM) + +- Muon optimizer (Newton-Schulz whitening) for matrix params + AdamW for scalars/embeddings +- Wallclock-based LR schedule with warmdown +- EMA (decay=0.999, every 10 steps) + SWA (12 checkpoints in final warmdown) +- ~7100 training steps, batch tokens=786,432 +- INT6 uniform quantization (64 levels per row) + zstd-22 compression +- 4% magnitude pruning before quantization + +## Test-Time Training (TTT) — Key Innovation + +Per-document LoRA adaptation at eval time with several design choices that proved critical: + +### 1. 8-epoch multi-pass adaptation +Each document gets 8 full passes of LoRA training. We found TTT gain scales strongly with epoch count — each additional epoch provides meaningful BPB improvement as the LoRA captures deeper document-specific patterns. + +### 2. 
Score-every-epoch compliance +Every token is scored before being trained on, in every epoch. Scores are overwritten each epoch, so the final score reflects the most adapted LoRA state. This satisfies backward-looking TTT requirements. + +### 3. Cosine LR decay within TTT +Per-step cosine schedule (from base LR down to 10%) across all epochs×chunks steps. This prevents overfitting in later passes while allowing aggressive early adaptation. Constant LR overshoots on later chunks. + +### 4. LM-head LoRA rank-16 +The output projection (dim→vocab) is the highest-leverage layer for BPB. We use rank-16 for the LM-head LoRA while keeping rank-8 for Q/V projections. This doubles the model's capacity to adapt its output distribution per document. + +### 5. Per-block bias tuning +During TTT, we tune a bias vector (512 params) per transformer block alongside LoRA. This provides a cheap "domain shift" — adjusting activation means to match document statistics without extra matmul cost. + +### 6. Post-TTT temperature rescaling (T=0.98) +Multi-epoch LoRA adaptation tends to make the model overconfident. Scaling logits by 0.98 corrects this calibration error for a consistent ~0.003 BPB improvement at zero compute cost. + +### 7. Zigzag GPU load balancing +Documents are distributed across 8 GPUs using a zigzag pattern (GPU 0→7, then 7→0, repeating) instead of contiguous blocks. This ensures each GPU processes a balanced mix of document lengths, eliminating a ~220s synchronization bottleneck from GPU workload imbalance. + +### 8. Outlier document filtering +Documents exceeding 50,000 tokens are scored with the base model without TTT. These extreme outliers take disproportionate compute (quadratic in chunk count) while being too few to meaningfully affect average BPB. + +### 9. Wall-clock TTT budget +A configurable time limit (570s default) on the TTT batch loop. If exceeded, remaining documents fall back to batched base-model scoring. 
This guarantees eval completes within the 600s budget. + +## TTT Configuration + +| Parameter | Value | +|-----------|-------| +| LoRA rank (Q, V) | 8 | +| LoRA rank (LM-head) | 16 | +| TTT LR | 0.01 (Adam, betas=0.9/0.95) | +| TTT epochs | 8 | +| TTT chunk size | 256 | +| TTT batch size | 64 documents | +| TTT min doc length | 512 tokens | +| TTT max doc length | 50,000 tokens | +| Temperature rescale | 0.98 | +| Cosine LR | enabled (min 10%) | +| Bias tuning | enabled | + +## How to run + +```bash +DATA_PATH=/path/to/fineweb10B_sp1024 \ +TOKENIZER_PATH=/path/to/fineweb_1024_bpe.model \ +SEED=42 TTT_EPOCHS=8 \ +torchrun --nproc_per_node=8 train_gpt.py +``` + +## Timing breakdown + +| Phase | Time | +|-------|------| +| Training | 600s | +| Post-processing (SWA+EMA+pruning) | <1s | +| Serialization (quant+compress) | 38s | +| Post-quant eval | 5s | +| TTT eval (short docs) | 22s | +| TTT eval (long docs, 62 batches) | 559s | +| TTT overhead | 2s | +| **Total eval** | **582s** | diff --git a/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/submission.json b/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/submission.json new file mode 100644 index 0000000000..e9c84b3fb9 --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/submission.json @@ -0,0 +1,10 @@ +{ + "author": "UrukHan", + "github_id": "UrukHan", + "name": "DeepQuant — 11L INT6 + 8-epoch Cosine LoRA TTT", + "blurb": "8ep cosine TTT + LM rank-16 + bias tuning + zigzag GPU balancing", + "date": "2026-03-24T00:00:00Z", + "val_loss": 0.9878, + "val_bpb": 0.5850, + "bytes_total": 15463955 +} diff --git a/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/train_gpt.py b/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/train_gpt.py new file mode 100644 index 0000000000..5f0fd07d93 --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/train_gpt.py @@ -0,0 +1,1481 @@ +"""Good launching-off point for new participants, not SOTA config. Competitive submissions stay in /records. 
+Hard stop: train_gpt.py and train_gpt_mlx.py must never be longer than 1500 lines.""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +try: + import zstandard as zstd + HAVE_ZSTD = True +except ImportError: + HAVE_ZSTD = False +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) # disabled: hurts with depth_scale, wastes 15 min + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = 
int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", 1536)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 50000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + + ema_decay = float(os.environ.get("EMA_DECAY", 0.999)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_every = int(os.environ.get("EMA_EVERY", 10)) + + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lm_rank = int(os.environ.get("TTT_LM_RANK", 16)) # V6: larger LM-head rank + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_min_doc_len = int(os.environ.get("TTT_MIN_DOC_LEN", 512)) + ttt_max_doc_len = int(os.environ.get("TTT_MAX_DOC_LEN", 50000)) + 
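The per-step cosine TTT schedule described in the README (§3: LR decaying from the base value to 10% of it across all epochs×chunks steps) can be sketched as a standalone helper. `ttt_lr` is illustrative only, not a function from this file, and assumes the 0.01 base LR and 10% floor from the TTT configuration table:

```python
import math

def ttt_lr(step: int, total_steps: int, base_lr: float = 0.01, min_frac: float = 0.1) -> float:
    """Cosine decay from base_lr down to min_frac * base_lr over epochs * chunks steps."""
    t = step / max(total_steps - 1, 1)  # progress in [0, 1]
    return base_lr * (min_frac + (1.0 - min_frac) * 0.5 * (1.0 + math.cos(math.pi * t)))
```

At step 0 this yields the full base LR for aggressive early adaptation; by the final step it has decayed to the 10% floor, which is what tempers overfitting on later passes.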
ttt_epochs = int(os.environ.get("TTT_EPOCHS", 6)) # V8: 6 epochs + score every epoch + ttt_cosine_lr = bool(int(os.environ.get("TTT_COSINE_LR", "1"))) + ttt_bias_tune = bool(int(os.environ.get("TTT_BIAS_TUNE", "1"))) + ttt_temp_rescale = float(os.environ.get("TTT_TEMP_RESCALE", 0.98)) # V8: post-TTT calibration + ttt_max_eval_secs = float(os.environ.get("TTT_MAX_EVAL_SECS", 570.0)) # wall-clock budget for the TTT eval loop + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + 
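The zigzag GPU distribution from README §7 (GPU 0→7, then 7→0, repeating) can be sketched independently of the training loop. `zigzag_assign` is a hypothetical helper, shown only to make the assignment pattern concrete:

```python
def zigzag_assign(num_docs: int, num_gpus: int = 8) -> list[int]:
    """Assign document i to a GPU in a 0..7, 7..0 zigzag so that, when docs are
    ordered by length, each GPU gets a balanced mix of short and long documents."""
    period = list(range(num_gpus)) + list(range(num_gpus - 1, -1, -1))
    return [period[i % len(period)] for i in range(num_docs)]
```

Unlike contiguous blocks, the zigzag pairs each GPU's long documents with short ones from the other end of the sweep, which is what removes the synchronization stall from workload imbalance.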
buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0: + p.data.mul_(1.0 - wd * lr) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise 
ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = 
base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: + max_val = 127 if bits == 8 else (2 ** (bits - 1)) - 1 # int6: 31, int8: 127 + t32 = 
t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(max_val) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if name == "tok_emb.weight": + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, 
passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t, bits=6) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + 
if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + 
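The INT6 path in `quantize_float_tensor` (symmetric signed 6-bit levels in [-31, 31] with one fp16 scale per row) can be illustrated with a dependency-free sketch. This is an approximation, not the real code: it uses a plain row max in place of the 99.99984th-percentile clip, and the helper names are invented for illustration:

```python
def quantize_row_int6(row: list[float]) -> tuple[list[int], float]:
    """Symmetric per-row INT6: map floats to integer levels in [-31, 31] plus one scale."""
    max_val = 31  # 2 ** (6 - 1) - 1
    clip_abs = max((abs(v) for v in row), default=0.0)  # sketch: max, not percentile clip
    scale = max(clip_abs / max_val, 1.0 / max_val)      # mirrors the clamp_min in the real code
    q = [max(-max_val, min(max_val, round(v / scale))) for v in row]
    return q, scale

def dequantize_row(q: list[int], scale: float) -> list[float]:
    return [qi * scale for qi in q]
```

Round-tripping a row through this scheme bounds the per-weight error by half a quantization step, which is why the zstd-compressed INT6 payload fits the 16MB budget with little BPB loss.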
self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (self.dim / (self.dim - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = 
CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + +class SmearGate(nn.Module): + """Learned token blending gate — injects bigram context at embedding layer.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + """Token-pair hash embedding — learned bigram features at near-zero param cost.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = 
CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, mlp_hidden: int = 0, layer_idx: int = 0): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.register_buffer("depth_scale", torch.tensor(1.0 / math.sqrt(layer_idx + 1))) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + ds = 
self.depth_scale.to(dtype=x.dtype) + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + ds * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + ds * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads: int, + num_kv_heads: int, mlp_mult: int, mlp_hidden: int, tie_embeddings: bool, + tied_embed_init_std: float, logit_softcap: float, rope_base: float, qk_gain_init: float): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(2048, 128, model_dim) + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + mlp_hidden=mlp_hidden, layer_idx=i) + for i in range(num_layers) + ]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + 
if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def _embed(self, input_ids: Tensor) -> tuple[Tensor, Tensor]: + """Shared embedding logic for forward and get_logits.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + return x, x # (x, x0) + + def _run_blocks(self, x: Tensor, x0: Tensor, lora=None) -> Tensor: + """Run all transformer blocks with optional LoRA deltas + V6 bias tuning.""" + skips: list[Tensor] = [] + has_bias = lora is not None and len(lora.bias_params) > 0 + for i in range(self.num_encoder_layers): + qd_fn = lora.q_loras[i] if lora is not None else None + vd_fn = lora.v_loras[i] if lora is not None else None + x = self.blocks[i](x, x0, qd_fn, vd_fn) + if has_bias: + x = x + lora.bias_params[2*i].to(dtype=x.dtype) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd_fn = lora.q_loras[bi] if lora is not None else None + vd_fn = lora.v_loras[bi] if lora is not None else None + x = self.blocks[bi](x, x0, qd_fn, vd_fn) + if has_bias: + x = x + lora.bias_params[2*bi].to(dtype=x.dtype) + return x + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x, x0 = self._embed(input_ids) + x = self._run_blocks(x, x0, lora) + x_norm = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x_norm.reshape(-1, x_norm.size(-1)), self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head required when tie_embeddings=False") + logits_proj = 
self.lm_head(x_norm.reshape(-1, x_norm.size(-1))) + if lora is not None: + lora_delta = lora.lm_head_lora(x_norm) # (bsz, seqlen, V) + bsz, seqlen, V = lora_delta.shape + logits = logits_proj.reshape(bsz, seqlen, V) + lora_delta + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, seqlen) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean") + + @torch.no_grad() + def get_logits(self, input_ids: Tensor, lora=None) -> Tensor: + x, x0 = self._embed(input_ids) + x = self._run_blocks(x, x0, lora) + x_norm = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x_norm, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x_norm) + if lora is not None: + logits_proj = logits_proj + lora.lm_head_lora(x_norm) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """Per-batch-element LoRA adapter for a linear layer. Delta = x @ Aᵀ @ Bᵀ.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) + self.B.zero_() + +class BatchedTTTLoRA(nn.Module): + """V6 Multi-Scale TTT: LM-head rank-16, Q/V rank-8, optional bias tuning. 
+ Per-layer LR groups: LM-head 2x, V 1.5x, Q 0.5x for optimal adaptation.""" + def __init__(self, bsz: int, model: GPT, rank: int, lm_rank: int = 16, + tune_biases: bool = False): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, lm_rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + q_out = block.attn.c_q.weight.shape[0] + v_out = block.attn.c_v.weight.shape[0] + self.q_loras.append(BatchedLinearLoRA(bsz, dim, q_out, rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, v_out, rank)) + # V6: optional bias vectors for norm layers (cheap but effective domain shift) + self.bias_params = nn.ParameterList() + if tune_biases: + for block in model.blocks: + self.bias_params.append(nn.Parameter(torch.zeros(bsz, 1, dim))) + self.bias_params.append(nn.Parameter(torch.zeros(bsz, 1, dim))) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + for p in self.bias_params: + p.data.zero_() + +def _reset_ttt_optimizer(opt: torch.optim.Adam) -> None: + for group in opt.param_groups: + for p in group["params"]: + s = opt.state.get(p) + if not s: + continue + s["exp_avg"].zero_() + s["exp_avg_sq"].zero_() + s["step"].fill_(0) + +def _build_ttt_optimizer(lora: BatchedTTTLoRA, args: Hyperparameters) -> torch.optim.Adam: + """V6: per-layer LR groups — LM-head 2x, V 1.5x, Q 0.5x, bias 3x.""" + base_lr = args.ttt_lora_lr + groups = [ + {"params": list(lora.lm_head_lora.parameters()), "lr": base_lr * 2.0, "base_lr": base_lr * 2.0}, + {"params": [p for lora_m in lora.v_loras for p in lora_m.parameters()], "lr": base_lr * 1.5, "base_lr": base_lr * 1.5}, + {"params": [p for lora_m in lora.q_loras for p in lora_m.parameters()], "lr": base_lr * 0.5, "base_lr": base_lr * 0.5}, + ] + if lora.bias_params: + groups.append({"params": list(lora.bias_params), "lr": base_lr * 3.0, "base_lr": 
base_lr * 3.0}) + return torch.optim.Adam(groups, lr=base_lr, betas=(args.beta1, args.beta2), eps=1e-10) + +def _find_docs(all_tokens: Tensor) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document at BOS boundaries.""" + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].cpu().numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) + 1 if i + 1 < len(bos_positions) else all_tokens.numel() + if end - start >= 2: + docs.append((start, end - start)) + return docs + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + """Return (win_start, win_len, chunk_offset, chunk_len) for chunk ci of a doc.""" + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """TTT eval: per-doc LoRA adaptation, score-then-train, multiple epochs.""" + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + _zz = [] + for i in range(0, len(docs), world_size): + c = docs[i:i+world_size] + if (i // world_size) % 2 == 1: c = c[::-1] + _zz.extend(c) + rank_docs = _zz[rank::world_size] + short_docs = [d for d in rank_docs if d[1] < args.ttt_min_doc_len] + long_docs = [d for d in rank_docs if d[1] >= args.ttt_min_doc_len and d[1] <= args.ttt_max_doc_len] + outlier_docs = [d for d in rank_docs if d[1] > args.ttt_max_doc_len] + short_docs = short_docs + 
outlier_docs + master = rank == 0 + if master: + print(f"ttt:rank0 short={len(short_docs)} long={len(long_docs)} epochs={args.ttt_epochs} batch={args.ttt_batch_size}") + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + t0 = time.perf_counter() + with torch.no_grad(): + for ds, dl in short_docs: + x = all_tokens[ds : ds + dl - 1].to(device=device, dtype=torch.int64).unsqueeze(0) + y = all_tokens[ds + 1 : ds + dl].to(device=device, dtype=torch.int64).unsqueeze(0) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + n = dl - 1 + loss_sum += loss.to(torch.float64) * n + token_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_bytes_lut[tgt_ids].to(torch.float64) + tb += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.float64) + byte_sum += tb.sum() + if master: + print(f"ttt:short_docs time={1000*(time.perf_counter()-t0):.0f}ms tokens={int(token_count.item())}") + + long_docs.sort(key=lambda d: (d[1] - 2) // args.ttt_chunk_size) + batch_size = args.ttt_batch_size + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + lora = BatchedTTTLoRA(batch_size, base_model, args.ttt_lora_rank, + lm_rank=args.ttt_lm_rank, tune_biases=args.ttt_bias_tune).to(device) + opt = _build_ttt_optimizer(lora, args) + t1 = time.perf_counter() + ttt_deadline = t1 + args.ttt_max_eval_secs + for bi in range(0, len(long_docs), batch_size): + if time.perf_counter() > ttt_deadline: + if master: + elapsed = 1000 * (time.perf_counter() - t1) + remaining = len(long_docs) - bi + print(f"ttt:TIME_LIMIT at batch {bi//batch_size+1}, time={elapsed:.0f}ms, base-scoring {remaining} remaining docs") + for rbi in range(bi, len(long_docs), batch_size): + rbatch = 
long_docs[rbi : rbi + batch_size]
+                for rb_idx, (ds, dl) in enumerate(rbatch):
+                    pl = dl - 1
+                    toks = all_tokens[ds:ds+dl].to(dtype=torch.int64, device=device)
+                    nc_r = (pl + chunk_size - 1) // chunk_size
+                    for ci_r in range(nc_r):
+                        ws, wl, co, cl = _compute_chunk_window(ci_r, pl, nc_r, chunk_size, eval_seq_len)
+                        xt = toks[ws:ws+wl].unsqueeze(0)
+                        yt = toks[ws+1:ws+wl+1].unsqueeze(0)
+                        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                            logits_b = base_model.get_logits(xt)
+                        V = logits_b.size(-1)
+                        ptl_b = F.cross_entropy(logits_b.float().reshape(-1, V), yt.reshape(-1), reduction='none').reshape(1, -1)
+                        closs = ptl_b[0, co:co+cl].to(torch.float64)
+                        if args.ttt_temp_rescale != 1.0:
+                            closs = closs * args.ttt_temp_rescale
+                        loss_sum += closs.sum()
+                        token_count += cl
+                        tgt_r = yt[0, co:co+cl]; px_r = xt[0, co:co+cl]
+                        tb = base_bytes_lut[tgt_r].to(torch.float64)
+                        tb += (has_leading_space_lut[tgt_r] & ~is_boundary_token_lut[px_r]).to(torch.float64)
+                        byte_sum += tb.sum()
+            break
+        batch = long_docs[bi : bi + batch_size]
+        bsz = len(batch)
+        if bsz == batch_size:
+            cur_lora, cur_opt = lora, opt
+            cur_lora.reset()
+            _reset_ttt_optimizer(cur_opt)
+        else:
+            cur_lora = BatchedTTTLoRA(bsz, base_model, args.ttt_lora_rank,
+                                      lm_rank=args.ttt_lm_rank, tune_biases=args.ttt_bias_tune).to(device)
+            cur_opt = _build_ttt_optimizer(cur_lora, args)
+        pred_lens = [dl - 1 for _, dl in batch]
+        num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens]
+        max_nc = max(num_chunks)
+        total_train_steps = args.ttt_epochs * max_nc
+        global_step = 0
+        # V8: per-doc accumulators for score-every-epoch (overwrite each epoch)
+        doc_loss = [torch.zeros((), device=device, dtype=torch.float64) for _ in range(bsz)]
+        doc_bytes = [torch.zeros((), device=device, dtype=torch.float64) for _ in range(bsz)]
+        doc_toks = [0] * bsz
+        for epoch in range(args.ttt_epochs):
+            # V8: reset accumulators each epoch (overwrite with latest 
scores) + for b in range(bsz): + doc_loss[b].zero_(); doc_bytes[b].zero_(); doc_toks[b] = 0 + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + ws_ref, wl_ref, _, _ = _compute_chunk_window(ci, (ci+1)*chunk_size, ci+1, chunk_size, eval_seq_len) + x = torch.zeros(bsz, wl_ref, dtype=torch.int64, device=device) + y = torch.zeros(bsz, wl_ref, dtype=torch.int64, device=device) + doc_info = [] + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)); continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + toks = all_tokens[ds+ws : ds+ws+wl+1].to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1]; y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + needs_train = any(ci < nc-1 for nc in num_chunks) + if needs_train: + if args.ttt_cosine_lr and total_train_steps > 1: + cos_mul = 0.5 * (1.0 + math.cos(math.pi * global_step / total_train_steps)) + for g in cur_opt.param_groups: + g["lr"] = g.get("base_lr", g["lr"]) * max(cos_mul, 0.1) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + # V8: score EVERY epoch (accumulate into per-doc buffers, overwritten each epoch) + with torch.no_grad(): + for b in range(bsz): + if not active[b]: continue + co, cl = doc_info[b] + # V8: apply post-TTT temperature rescaling + chunk_loss = ptl[b, co:co+cl].to(torch.float64) + if args.ttt_temp_rescale != 1.0: + chunk_loss = chunk_loss * args.ttt_temp_rescale + doc_loss[b] += chunk_loss.sum() + doc_toks[b] += cl + tgt = y[b, co:co+cl]; px = x[b, co:co+cl] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[px]).to(torch.float64) + doc_bytes[b] += tb.sum() + if needs_train: + train_loss = torch.zeros(bsz, device=device) + for b in range(bsz): + if ci >= 
num_chunks[b]-1: continue + co, cl = doc_info[b] + if cl > 0: train_loss[b] = ptl[b, co:co+cl].mean() + cur_opt.zero_grad() + train_loss.sum().backward() + cur_opt.step() + global_step += 1 + # V8: add final epoch's scores to global accumulators + for b in range(bsz): + loss_sum += doc_loss[b] + token_count += doc_toks[b] + byte_sum += doc_bytes[b] + if master and (bi + batch_size) % (batch_size * 5) == 0: + elapsed = 1000 * (time.perf_counter() - t1) + avg_loss = loss_sum.item() / max(token_count.item(), 1) + print(f"ttt:batch {bi//batch_size+1}/{(len(long_docs)+batch_size-1)//batch_size} time={elapsed:.0f}ms avg_loss={avg_loss:.4f}") + if master: + print(f"ttt:long_docs time={1000*(time.perf_counter()-t1):.0f}ms docs={len(long_docs)}") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / max(token_count.item(), 1)) + val_bpb = float((loss_sum.item() / math.log(2.0)) / max(byte_sum.item(), 1)) + base_model.train() + for p in base_model.parameters(): + p.requires_grad_(True) + return val_loss, val_bpb + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is 
required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False); enable_flash_sdp(True); enable_mem_efficient_sdp(False); enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}") + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + 
val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files} val_tokens:{val_tokens.numel() - 1}") + + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + mlp_hidden=args.mlp_hidden, tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, logit_softcap=args.logit_softcap, + rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.AdamW( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=0.04, fused=True, + ) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + 
backend_steps=args.muon_backend_steps, weight_decay=0.04) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=0.04, fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}") + log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}") + log0(f"seed:{args.seed} ema_enabled:{args.ema_enabled} ema_decay:{args.ema_decay} ema_every:{args.ema_every}") + log0(f"V10:ttt_time_limit ttt_rank:{args.ttt_lora_rank} lm:{args.ttt_lm_rank} lr:{args.ttt_lora_lr} cos:{args.ttt_cosine_lr} bias:{args.ttt_bias_tune} ep:{args.ttt_epochs} temp:{args.ttt_temp_rescale}") + + ema_state: dict[str, Tensor] = {} + _ema_updated = False + if args.ema_enabled: + for name, p in base_model.named_parameters(): + ema_state[name] = p.data.float().clone() + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + 
opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + if args.ema_enabled: + for name, p in base_model.named_parameters(): + 
ema_state[name] = p.data.float().clone() + + training_time_ms = 0.0 + prev_log_ms = 0.0 + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + stop_after_step: int | None = None + wall_start = time.perf_counter() + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        if args.ema_enabled and step > 0 and step % args.ema_every == 0:
+            _ema_updated = True
+            with torch.no_grad():
+                for name, p in base_model.named_parameters():
+                    ema_state[name].lerp_(p.data.float(), 1.0 - args.ema_decay ** args.ema_every)
+
+        if scale < 0.2 and step % 50 == 0:
+            if swa_state is None:
+                swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
+                swa_count = 1
+                log0(f"swa:start step={step}")
+            else:
+                for name, t in base_model.state_dict().items():
+                    swa_state[name] += t.detach().cpu()
+                swa_count += 1
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            mem_mb = torch.cuda.max_memory_allocated() // 1024 // 1024
+            this_step_ms = approx_training_time_ms - prev_log_ms if step > 1 else approx_training_time_ms
+            prev_log_ms = approx_training_time_ms
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.6f} "
+                f"lr_scale:{scale:.4f} muon_mom:{muon_momentum:.4f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms "
+                f"this_step:{this_step_ms:.1f}ms mem:{mem_mb}MiB swa_n:{swa_count}"
+            )
+
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), 
device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + train_wall_ms = 1000.0 * (time.perf_counter() - wall_start) + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + log0(f"phase:train wall_ms:{train_wall_ms:.0f} steps:{step} step_avg:{training_time_ms/max(step,1):.2f}ms") + phase_t = time.perf_counter() + + if swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + averaged = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(averaged, strict=True) + elif args.ema_enabled and _ema_updated: + log0("Applying EMA weights for export...") + with torch.no_grad(): + for name, p in base_model.named_parameters(): + if name in ema_state: + p.data.copy_(ema_state[name].to(dtype=p.dtype, device=p.device)) + + with torch.no_grad(): + all_weights = [] + for name, p in base_model.named_parameters(): + if p.ndim == 2 and p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL: + all_weights.append(p.data.abs().flatten()) + if all_weights: + all_abs = torch.cat(all_weights) + sample = all_abs[torch.randperm(len(all_abs), device=all_abs.device)[:min(1_000_000, len(all_abs))]] + idx = int(len(sample) * 0.04) # V6: 4% pruning for 16MB fit + threshold = float(sample.float().sort().values[idx].item()) + pruned = 0 + for name, p in base_model.named_parameters(): + if p.ndim == 2 and p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL: + mask = p.data.abs() < threshold + pruned += mask.sum().item() + p.data[mask] = 0.0 + log0(f"pruning: zeroed {pruned:,} weights ({100*pruned/all_abs.numel():.1f}%) below {threshold:.6f}") + + log0(f"phase:postprocess 
wall_ms:{1000.0*(time.perf_counter()-phase_t):.0f} (swa+ema+pruning)") + phase_t = time.perf_counter() + + torch.cuda.synchronize() + t_prequant = time.perf_counter() + prequant_loss, prequant_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"pre_quant_eval val_loss:{prequant_loss:.4f} val_bpb:{prequant_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_prequant):.0f}ms" + ) + log0(f"pre_quant_eval_exact val_loss:{prequant_loss:.8f} val_bpb:{prequant_bpb:.8f}") + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + if master_process: + for name in sorted(quant_obj.get("quantized", {}).keys()): + q = quant_obj["quantized"][name] + s = quant_obj["scales"][name] + log0(f"quant_tensor:{name} shape:{list(q.shape)} bits:6 scale_range:[{s.float().min():.6f},{s.float().max():.6f}]") + for name in sorted(quant_obj.get("passthrough", {}).keys()): + t = quant_obj["passthrough"][name] + log0(f"passthrough_tensor:{name} shape:{list(t.shape)} dtype:{t.dtype} bytes:{t.numel() * t.element_size()}") + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if HAVE_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_label = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_label = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = 
os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + total_size = quant_file_bytes + code_bytes + log0(f"Total submission size {compress_label}: {total_size} bytes") + if total_size > 16_000_000: + log0(f"WARNING: Total size {total_size} exceeds 16MB limit!") + else: + log0(f"Size check PASSED: {total_size} / 16,000,000 ({100*total_size/16_000_000:.1f}%)") + + log0(f"phase:serialize wall_ms:{1000.0*(time.perf_counter()-phase_t):.0f} (quant+compress+save)") + phase_t = time.perf_counter() + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if HAVE_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_raw_disk = dctx.decompress(quant_blob_disk) + else: + quant_raw_disk = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_raw_disk), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms " + f"eval_seq_len:{effective_eval_seq_len}" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + quant_gap_bpb = q_val_bpb - prequant_bpb + log0(f"quant_gap: {quant_gap_bpb:.6f} BPB (pre:{prequant_bpb:.6f} post:{q_val_bpb:.6f})") + log0(f"phase:postquant_eval 
wall_ms:{1000.0*(time.perf_counter()-phase_t):.0f}") + phase_t = time.perf_counter() + + torch.cuda.synchronize() + torch._dynamo.reset() + ttt_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + mlp_hidden=args.mlp_hidden, tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, logit_softcap=args.logit_softcap, + rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + ).to(device) + ttt_model.load_state_dict(base_model.state_dict(), strict=True) + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, ttt_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms " + f"lora_rank:{args.ttt_lora_rank} chunk_size:{args.ttt_chunk_size}" + ) + log0(f"final_ttt_lora_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") + ttt_gap_bpb = ttt_val_bpb - q_val_bpb + log0(f"ttt_gain: {-ttt_gap_bpb:.6f} BPB gain over int8 (int8:{q_val_bpb:.6f} ttt:{ttt_val_bpb:.6f})") + log0(f"phase:ttt_eval wall_ms:{1000.0*(time.perf_counter()-phase_t):.0f}") + total_wall_ms = 1000.0 * (time.perf_counter() - wall_start) + log0(f"phase:TOTAL wall_ms:{total_wall_ms:.0f} ({total_wall_ms/60000:.1f} min)") + log0(f"phase_breakdown: train:{training_time_ms:.0f}ms postprocess:see_above serialize:see_above eval:see_above ttt:see_above") + + if distributed: + dist.destroy_process_group() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/train_seed42.log b/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/train_seed42.log new file mode 100644 index 0000000000..b64fed45e1 --- /dev/null +++ 
b/records/track_10min_16mb/2026-03-24_DeepQuant_V10b/train_seed42.log @@ -0,0 +1,362 @@ +W0324 11:39:49.037000 1325 torch/distributed/run.py:803] +W0324 11:39:49.037000 1325 torch/distributed/run.py:803] ***************************************** +W0324 11:39:49.037000 1325 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0324 11:39:49.037000 1325 torch/distributed/run.py:803] ***************************************** +logs/4669c65f-2366-41ff-bf4a-273fd55ad6d1.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 val_tokens:62021632 +model_params:26829913 world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 ema_enabled:True ema_decay:0.999 ema_every:10 +V10:ttt_time_limit ttt_rank:8 lm:16 lr:0.01 cos:True bias:True ep:8 temp:0.98 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9307 val_bpb:4.1047 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.932050 lr_scale:1.0000 muon_mom:0.9200 train_time:135ms step_avg:135.18ms this_step:135.2ms mem:20869MiB swa_n:0 +step:2/20000 train_loss:8.088539 lr_scale:1.0000 muon_mom:0.9200 train_time:203ms step_avg:101.29ms this_step:67.4ms mem:20869MiB swa_n:0 
+step:3/20000 train_loss:7.467353 lr_scale:1.0000 muon_mom:0.9201 train_time:286ms step_avg:95.22ms this_step:83.1ms mem:20869MiB swa_n:0 +step:4/20000 train_loss:6.933643 lr_scale:1.0000 muon_mom:0.9201 train_time:368ms step_avg:92.11ms this_step:82.8ms mem:20869MiB swa_n:0 +step:5/20000 train_loss:6.781602 lr_scale:1.0000 muon_mom:0.9202 train_time:452ms step_avg:90.31ms this_step:83.1ms mem:20869MiB swa_n:0 +step:6/20000 train_loss:6.822371 lr_scale:1.0000 muon_mom:0.9202 train_time:535ms step_avg:89.13ms this_step:83.2ms mem:20869MiB swa_n:0 +step:7/20000 train_loss:6.693643 lr_scale:1.0000 muon_mom:0.9203 train_time:618ms step_avg:88.22ms this_step:82.8ms mem:20869MiB swa_n:0 +step:8/20000 train_loss:6.602687 lr_scale:1.0000 muon_mom:0.9203 train_time:700ms step_avg:87.54ms this_step:82.7ms mem:20869MiB swa_n:0 +step:9/20000 train_loss:6.371422 lr_scale:1.0000 muon_mom:0.9204 train_time:783ms step_avg:87.00ms this_step:82.7ms mem:20869MiB swa_n:0 +step:10/20000 train_loss:6.102645 lr_scale:1.0000 muon_mom:0.9204 train_time:866ms step_avg:86.58ms this_step:82.7ms mem:20869MiB swa_n:0 +step:50/20000 train_loss:3.989717 lr_scale:1.0000 muon_mom:0.9223 train_time:4210ms step_avg:84.21ms this_step:3344.7ms mem:20869MiB swa_n:0 +step:100/20000 train_loss:3.245433 lr_scale:1.0000 muon_mom:0.9246 train_time:8397ms step_avg:83.97ms this_step:4186.9ms mem:20869MiB swa_n:0 +step:150/20000 train_loss:2.938554 lr_scale:1.0000 muon_mom:0.9270 train_time:12650ms step_avg:84.33ms this_step:4252.3ms mem:20869MiB swa_n:0 +step:200/20000 train_loss:2.457964 lr_scale:1.0000 muon_mom:0.9293 train_time:16847ms step_avg:84.24ms this_step:4197.8ms mem:20869MiB swa_n:0 +step:250/20000 train_loss:2.547057 lr_scale:1.0000 muon_mom:0.9316 train_time:21043ms step_avg:84.17ms this_step:4195.0ms mem:20869MiB swa_n:0 +step:300/20000 train_loss:2.621458 lr_scale:1.0000 muon_mom:0.9340 train_time:25300ms step_avg:84.33ms this_step:4257.5ms mem:20869MiB swa_n:0 +step:350/20000 
train_loss:2.595742 lr_scale:1.0000 muon_mom:0.9363 train_time:29500ms step_avg:84.29ms this_step:4199.9ms mem:20869MiB swa_n:0 +step:400/20000 train_loss:2.476062 lr_scale:1.0000 muon_mom:0.9386 train_time:33771ms step_avg:84.43ms this_step:4270.6ms mem:20869MiB swa_n:0 +step:450/20000 train_loss:2.425850 lr_scale:1.0000 muon_mom:0.9410 train_time:37983ms step_avg:84.41ms this_step:4212.4ms mem:20869MiB swa_n:0 +step:500/20000 train_loss:2.451874 lr_scale:1.0000 muon_mom:0.9433 train_time:42202ms step_avg:84.40ms this_step:4218.8ms mem:20869MiB swa_n:0 +step:550/20000 train_loss:2.394425 lr_scale:1.0000 muon_mom:0.9456 train_time:46488ms step_avg:84.52ms this_step:4286.2ms mem:20869MiB swa_n:0 +step:600/20000 train_loss:2.383200 lr_scale:1.0000 muon_mom:0.9480 train_time:50712ms step_avg:84.52ms this_step:4224.2ms mem:20869MiB swa_n:0 +step:650/20000 train_loss:2.381544 lr_scale:1.0000 muon_mom:0.9503 train_time:54999ms step_avg:84.61ms this_step:4287.0ms mem:20869MiB swa_n:0 +step:700/20000 train_loss:2.394417 lr_scale:1.0000 muon_mom:0.9526 train_time:59221ms step_avg:84.60ms this_step:4221.7ms mem:20869MiB swa_n:0 +step:750/20000 train_loss:2.378147 lr_scale:1.0000 muon_mom:0.9550 train_time:63440ms step_avg:84.59ms this_step:4219.2ms mem:20869MiB swa_n:0 +step:800/20000 train_loss:2.287479 lr_scale:1.0000 muon_mom:0.9573 train_time:67726ms step_avg:84.66ms this_step:4286.2ms mem:20869MiB swa_n:0 +step:850/20000 train_loss:2.278646 lr_scale:1.0000 muon_mom:0.9596 train_time:71953ms step_avg:84.65ms this_step:4226.8ms mem:20869MiB swa_n:0 +step:900/20000 train_loss:2.175399 lr_scale:1.0000 muon_mom:0.9620 train_time:76230ms step_avg:84.70ms this_step:4277.0ms mem:20869MiB swa_n:0 +step:950/20000 train_loss:2.260240 lr_scale:1.0000 muon_mom:0.9643 train_time:80462ms step_avg:84.70ms this_step:4231.8ms mem:20869MiB swa_n:0 +step:1000/20000 train_loss:2.311006 lr_scale:1.0000 muon_mom:0.9666 train_time:84690ms step_avg:84.69ms this_step:4228.0ms mem:20869MiB 
swa_n:0 +step:1000/20000 val_loss:2.2728 val_bpb:1.3461 train_time:84708ms step_avg:84.71ms +step:1050/20000 train_loss:2.271102 lr_scale:1.0000 muon_mom:0.9690 train_time:88970ms step_avg:84.73ms this_step:4280.1ms mem:20869MiB swa_n:0 +step:1100/20000 train_loss:2.374232 lr_scale:1.0000 muon_mom:0.9713 train_time:93195ms step_avg:84.72ms this_step:4224.6ms mem:20869MiB swa_n:0 +step:1150/20000 train_loss:2.288929 lr_scale:1.0000 muon_mom:0.9736 train_time:97471ms step_avg:84.76ms this_step:4276.1ms mem:20869MiB swa_n:0 +step:1200/20000 train_loss:2.395080 lr_scale:1.0000 muon_mom:0.9760 train_time:101690ms step_avg:84.74ms this_step:4219.2ms mem:20869MiB swa_n:0 +step:1250/20000 train_loss:2.298902 lr_scale:1.0000 muon_mom:0.9783 train_time:105905ms step_avg:84.72ms this_step:4215.2ms mem:20869MiB swa_n:0 +step:1300/20000 train_loss:2.151644 lr_scale:1.0000 muon_mom:0.9806 train_time:110188ms step_avg:84.76ms this_step:4282.6ms mem:20869MiB swa_n:0 +step:1350/20000 train_loss:2.287394 lr_scale:1.0000 muon_mom:0.9830 train_time:114400ms step_avg:84.74ms this_step:4211.8ms mem:20869MiB swa_n:0 +step:1400/20000 train_loss:2.226420 lr_scale:1.0000 muon_mom:0.9853 train_time:118680ms step_avg:84.77ms this_step:4280.7ms mem:20869MiB swa_n:0 +step:1450/20000 train_loss:2.168962 lr_scale:1.0000 muon_mom:0.9876 train_time:122889ms step_avg:84.75ms this_step:4208.9ms mem:20869MiB swa_n:0 +step:1500/20000 train_loss:2.259071 lr_scale:1.0000 muon_mom:0.9900 train_time:127101ms step_avg:84.73ms this_step:4211.9ms mem:20869MiB swa_n:0 +step:1550/20000 train_loss:2.227993 lr_scale:1.0000 muon_mom:0.9900 train_time:131376ms step_avg:84.76ms this_step:4274.6ms mem:20869MiB swa_n:0 +step:1600/20000 train_loss:2.123164 lr_scale:1.0000 muon_mom:0.9900 train_time:135586ms step_avg:84.74ms this_step:4210.4ms mem:20869MiB swa_n:0 +step:1650/20000 train_loss:2.234782 lr_scale:1.0000 muon_mom:0.9900 train_time:139795ms step_avg:84.72ms this_step:4208.9ms mem:20869MiB swa_n:0 
+step:1700/20000 train_loss:2.178277 lr_scale:1.0000 muon_mom:0.9900 train_time:144060ms step_avg:84.74ms this_step:4264.7ms mem:20869MiB swa_n:0 +step:1750/20000 train_loss:2.238895 lr_scale:1.0000 muon_mom:0.9900 train_time:148265ms step_avg:84.72ms this_step:4204.8ms mem:20869MiB swa_n:0 +step:1800/20000 train_loss:2.225036 lr_scale:1.0000 muon_mom:0.9900 train_time:152527ms step_avg:84.74ms this_step:4262.7ms mem:20869MiB swa_n:0 +step:1850/20000 train_loss:2.075745 lr_scale:1.0000 muon_mom:0.9900 train_time:156727ms step_avg:84.72ms this_step:4200.0ms mem:20869MiB swa_n:0 +step:1900/20000 train_loss:2.172472 lr_scale:1.0000 muon_mom:0.9900 train_time:160929ms step_avg:84.70ms this_step:4201.3ms mem:20869MiB swa_n:0 +step:1950/20000 train_loss:2.063821 lr_scale:1.0000 muon_mom:0.9900 train_time:165194ms step_avg:84.71ms this_step:4265.3ms mem:20869MiB swa_n:0 +step:2000/20000 train_loss:2.110958 lr_scale:1.0000 muon_mom:0.9900 train_time:169391ms step_avg:84.70ms this_step:4196.8ms mem:20869MiB swa_n:0 +step:2000/20000 val_loss:2.1730 val_bpb:1.2870 train_time:169408ms step_avg:84.70ms +step:2050/20000 train_loss:2.150226 lr_scale:1.0000 muon_mom:0.9900 train_time:173657ms step_avg:84.71ms this_step:4266.5ms mem:20869MiB swa_n:0 +step:2100/20000 train_loss:2.078981 lr_scale:1.0000 muon_mom:0.9900 train_time:177860ms step_avg:84.70ms this_step:4202.5ms mem:20869MiB swa_n:0 +step:2150/20000 train_loss:2.183601 lr_scale:1.0000 muon_mom:0.9900 train_time:182056ms step_avg:84.68ms this_step:4196.7ms mem:20869MiB swa_n:0 +step:2200/20000 train_loss:2.246216 lr_scale:1.0000 muon_mom:0.9900 train_time:186323ms step_avg:84.69ms this_step:4266.9ms mem:20869MiB swa_n:0 +step:2250/20000 train_loss:2.217416 lr_scale:1.0000 muon_mom:0.9900 train_time:190532ms step_avg:84.68ms this_step:4209.0ms mem:20869MiB swa_n:0 +step:2300/20000 train_loss:2.148679 lr_scale:1.0000 muon_mom:0.9900 train_time:194790ms step_avg:84.69ms this_step:4257.9ms mem:20869MiB swa_n:0 +step:2350/20000 
train_loss:2.207604 lr_scale:1.0000 muon_mom:0.9900 train_time:198984ms step_avg:84.67ms this_step:4193.5ms mem:20869MiB swa_n:0 +step:2400/20000 train_loss:2.114476 lr_scale:1.0000 muon_mom:0.9900 train_time:203183ms step_avg:84.66ms this_step:4199.1ms mem:20869MiB swa_n:0 +step:2450/20000 train_loss:2.112900 lr_scale:1.0000 muon_mom:0.9900 train_time:207438ms step_avg:84.67ms this_step:4255.7ms mem:20869MiB swa_n:0 +step:2500/20000 train_loss:2.208804 lr_scale:1.0000 muon_mom:0.9900 train_time:211634ms step_avg:84.65ms this_step:4195.7ms mem:20869MiB swa_n:0 +step:2550/20000 train_loss:2.236876 lr_scale:1.0000 muon_mom:0.9900 train_time:215891ms step_avg:84.66ms this_step:4257.4ms mem:20869MiB swa_n:0 +step:2600/20000 train_loss:2.142518 lr_scale:1.0000 muon_mom:0.9900 train_time:220090ms step_avg:84.65ms this_step:4198.5ms mem:20869MiB swa_n:0 +step:2650/20000 train_loss:2.117440 lr_scale:1.0000 muon_mom:0.9900 train_time:224285ms step_avg:84.64ms this_step:4194.7ms mem:20869MiB swa_n:0 +step:2700/20000 train_loss:2.138550 lr_scale:1.0000 muon_mom:0.9900 train_time:228544ms step_avg:84.65ms this_step:4259.4ms mem:20869MiB swa_n:0 +step:2750/20000 train_loss:2.073166 lr_scale:1.0000 muon_mom:0.9900 train_time:232739ms step_avg:84.63ms this_step:4194.7ms mem:20869MiB swa_n:0 +step:2800/20000 train_loss:2.187673 lr_scale:1.0000 muon_mom:0.9900 train_time:236995ms step_avg:84.64ms this_step:4256.0ms mem:20869MiB swa_n:0 +step:2850/20000 train_loss:2.102222 lr_scale:1.0000 muon_mom:0.9900 train_time:241187ms step_avg:84.63ms this_step:4192.0ms mem:20869MiB swa_n:0 +step:2900/20000 train_loss:2.069113 lr_scale:1.0000 muon_mom:0.9900 train_time:245381ms step_avg:84.61ms this_step:4194.4ms mem:20869MiB swa_n:0 +step:2950/20000 train_loss:2.118033 lr_scale:1.0000 muon_mom:0.9900 train_time:249634ms step_avg:84.62ms this_step:4252.5ms mem:20869MiB swa_n:0 +step:3000/20000 train_loss:2.191947 lr_scale:1.0000 muon_mom:0.9900 train_time:253821ms step_avg:84.61ms 
this_step:4187.4ms mem:20869MiB swa_n:0 +step:3000/20000 val_loss:2.1297 val_bpb:1.2613 train_time:253839ms step_avg:84.61ms +step:3050/20000 train_loss:2.081064 lr_scale:1.0000 muon_mom:0.9900 train_time:258014ms step_avg:84.59ms this_step:4192.9ms mem:20869MiB swa_n:0 +step:3100/20000 train_loss:2.084753 lr_scale:1.0000 muon_mom:0.9900 train_time:262271ms step_avg:84.60ms this_step:4256.9ms mem:20869MiB swa_n:0 +step:3150/20000 train_loss:2.008487 lr_scale:1.0000 muon_mom:0.9900 train_time:266466ms step_avg:84.59ms this_step:4195.1ms mem:20869MiB swa_n:0 +step:3200/20000 train_loss:2.207227 lr_scale:1.0000 muon_mom:0.9900 train_time:270715ms step_avg:84.60ms this_step:4249.4ms mem:20869MiB swa_n:0 +step:3250/20000 train_loss:2.087616 lr_scale:1.0000 muon_mom:0.9900 train_time:274908ms step_avg:84.59ms this_step:4192.4ms mem:20869MiB swa_n:0 +step:3300/20000 train_loss:2.114355 lr_scale:1.0000 muon_mom:0.9900 train_time:279095ms step_avg:84.57ms this_step:4187.1ms mem:20869MiB swa_n:0 +step:3350/20000 train_loss:2.136599 lr_scale:1.0000 muon_mom:0.9900 train_time:283346ms step_avg:84.58ms this_step:4251.1ms mem:20869MiB swa_n:0 +step:3400/20000 train_loss:2.069345 lr_scale:1.0000 muon_mom:0.9900 train_time:287537ms step_avg:84.57ms this_step:4190.9ms mem:20869MiB swa_n:0 +step:3450/20000 train_loss:2.154311 lr_scale:1.0000 muon_mom:0.9900 train_time:291795ms step_avg:84.58ms this_step:4257.9ms mem:20869MiB swa_n:0 +step:3500/20000 train_loss:2.222590 lr_scale:1.0000 muon_mom:0.9900 train_time:295986ms step_avg:84.57ms this_step:4190.8ms mem:20869MiB swa_n:0 +step:3550/20000 train_loss:1.965108 lr_scale:1.0000 muon_mom:0.9900 train_time:300175ms step_avg:84.56ms this_step:4189.5ms mem:20869MiB swa_n:0 +step:3600/20000 train_loss:2.136110 lr_scale:1.0000 muon_mom:0.9900 train_time:304426ms step_avg:84.56ms this_step:4250.9ms mem:20869MiB swa_n:0 +step:3650/20000 train_loss:2.021913 lr_scale:1.0000 muon_mom:0.9900 train_time:308615ms step_avg:84.55ms 
this_step:4188.9ms mem:20869MiB swa_n:0 +step:3700/20000 train_loss:2.128757 lr_scale:1.0000 muon_mom:0.9900 train_time:312874ms step_avg:84.56ms this_step:4259.6ms mem:20869MiB swa_n:0 +step:3750/20000 train_loss:1.963294 lr_scale:1.0000 muon_mom:0.9900 train_time:317059ms step_avg:84.55ms this_step:4184.8ms mem:20869MiB swa_n:0 +step:3800/20000 train_loss:2.120957 lr_scale:1.0000 muon_mom:0.9900 train_time:321244ms step_avg:84.54ms this_step:4185.0ms mem:20869MiB swa_n:0 +step:3850/20000 train_loss:2.134960 lr_scale:1.0000 muon_mom:0.9900 train_time:325496ms step_avg:84.54ms this_step:4252.1ms mem:20869MiB swa_n:0 +step:3900/20000 train_loss:2.120189 lr_scale:1.0000 muon_mom:0.9900 train_time:329682ms step_avg:84.53ms this_step:4185.5ms mem:20869MiB swa_n:0 +step:3950/20000 train_loss:2.221283 lr_scale:1.0000 muon_mom:0.9900 train_time:333931ms step_avg:84.54ms this_step:4249.7ms mem:20869MiB swa_n:0 +step:4000/20000 train_loss:2.021319 lr_scale:1.0000 muon_mom:0.9900 train_time:338124ms step_avg:84.53ms this_step:4193.1ms mem:20869MiB swa_n:0 +step:4000/20000 val_loss:2.1151 val_bpb:1.2527 train_time:338142ms step_avg:84.54ms +step:4050/20000 train_loss:2.136159 lr_scale:1.0000 muon_mom:0.9900 train_time:342315ms step_avg:84.52ms this_step:4190.3ms mem:20869MiB swa_n:0 +step:4100/20000 train_loss:2.077119 lr_scale:0.9997 muon_mom:0.9900 train_time:346560ms step_avg:84.53ms this_step:4245.7ms mem:20869MiB swa_n:0 +step:4150/20000 train_loss:2.161564 lr_scale:0.9832 muon_mom:0.9900 train_time:350750ms step_avg:84.52ms this_step:4189.7ms mem:20869MiB swa_n:0 +step:4200/20000 train_loss:2.208965 lr_scale:0.9664 muon_mom:0.9900 train_time:355002ms step_avg:84.52ms this_step:4251.9ms mem:20869MiB swa_n:0 +step:4250/20000 train_loss:2.160754 lr_scale:0.9500 muon_mom:0.9900 train_time:359194ms step_avg:84.52ms this_step:4192.0ms mem:20869MiB swa_n:0 +step:4300/20000 train_loss:2.105979 lr_scale:0.9335 muon_mom:0.9900 train_time:363382ms step_avg:84.51ms 
this_step:4187.7ms mem:20869MiB swa_n:0 +step:4350/20000 train_loss:2.122095 lr_scale:0.9167 muon_mom:0.9900 train_time:367632ms step_avg:84.51ms this_step:4250.6ms mem:20869MiB swa_n:0 +step:4400/20000 train_loss:2.085918 lr_scale:0.9003 muon_mom:0.9900 train_time:371813ms step_avg:84.50ms this_step:4180.5ms mem:20869MiB swa_n:0 +step:4450/20000 train_loss:2.087721 lr_scale:0.8839 muon_mom:0.9900 train_time:376003ms step_avg:84.50ms this_step:4190.5ms mem:20869MiB swa_n:0 +step:4500/20000 train_loss:2.168918 lr_scale:0.8670 muon_mom:0.9900 train_time:380258ms step_avg:84.50ms this_step:4254.7ms mem:20869MiB swa_n:0 +step:4550/20000 train_loss:2.173985 lr_scale:0.8506 muon_mom:0.9900 train_time:384447ms step_avg:84.49ms this_step:4189.5ms mem:20869MiB swa_n:0 +step:4600/20000 train_loss:1.908979 lr_scale:0.8338 muon_mom:0.9900 train_time:388699ms step_avg:84.50ms this_step:4251.8ms mem:20869MiB swa_n:0 +step:4650/20000 train_loss:2.101929 lr_scale:0.8173 muon_mom:0.9900 train_time:392890ms step_avg:84.49ms this_step:4190.4ms mem:20869MiB swa_n:0 +step:4700/20000 train_loss:2.296495 lr_scale:0.8008 muon_mom:0.9900 train_time:397079ms step_avg:84.48ms this_step:4189.8ms mem:20869MiB swa_n:0 +step:4750/20000 train_loss:2.064267 lr_scale:0.7840 muon_mom:0.9900 train_time:401332ms step_avg:84.49ms this_step:4252.3ms mem:20869MiB swa_n:0 +step:4800/20000 train_loss:2.516044 lr_scale:0.7676 muon_mom:0.9900 train_time:405521ms step_avg:84.48ms this_step:4189.5ms mem:20869MiB swa_n:0 +step:4850/20000 train_loss:2.155927 lr_scale:0.7507 muon_mom:0.9900 train_time:409772ms step_avg:84.49ms this_step:4251.0ms mem:20869MiB swa_n:0 +step:4900/20000 train_loss:2.105861 lr_scale:0.7343 muon_mom:0.9900 train_time:413963ms step_avg:84.48ms this_step:4190.7ms mem:20869MiB swa_n:0 +step:4950/20000 train_loss:2.151264 lr_scale:0.7178 muon_mom:0.9900 train_time:418148ms step_avg:84.47ms this_step:4185.3ms mem:20869MiB swa_n:0 +step:5000/20000 train_loss:2.155752 lr_scale:0.7010 
muon_mom:0.9900 train_time:422404ms step_avg:84.48ms this_step:4256.2ms mem:20869MiB swa_n:0 +step:5000/20000 val_loss:2.0745 val_bpb:1.2286 train_time:422421ms step_avg:84.48ms +step:5050/20000 train_loss:2.136290 lr_scale:0.6845 muon_mom:0.9900 train_time:426591ms step_avg:84.47ms this_step:4186.9ms mem:20869MiB swa_n:0 +step:5100/20000 train_loss:2.169608 lr_scale:0.6677 muon_mom:0.9900 train_time:430846ms step_avg:84.48ms this_step:4254.5ms mem:20869MiB swa_n:0 +step:5150/20000 train_loss:2.081245 lr_scale:0.6512 muon_mom:0.9900 train_time:435031ms step_avg:84.47ms this_step:4185.2ms mem:20869MiB swa_n:0 +step:5200/20000 train_loss:2.091791 lr_scale:0.6348 muon_mom:0.9900 train_time:439216ms step_avg:84.46ms this_step:4185.5ms mem:20869MiB swa_n:0 +step:5250/20000 train_loss:2.110512 lr_scale:0.6180 muon_mom:0.9900 train_time:443466ms step_avg:84.47ms this_step:4249.6ms mem:20869MiB swa_n:0 +step:5300/20000 train_loss:2.060823 lr_scale:0.6015 muon_mom:0.9900 train_time:447657ms step_avg:84.46ms this_step:4190.7ms mem:20869MiB swa_n:0 +step:5350/20000 train_loss:1.975796 lr_scale:0.5847 muon_mom:0.9900 train_time:451901ms step_avg:84.47ms this_step:4243.9ms mem:20869MiB swa_n:0 +step:5400/20000 train_loss:2.092291 lr_scale:0.5682 muon_mom:0.9900 train_time:456092ms step_avg:84.46ms this_step:4191.4ms mem:20869MiB swa_n:0 +step:5450/20000 train_loss:2.115972 lr_scale:0.5517 muon_mom:0.9900 train_time:460280ms step_avg:84.45ms this_step:4187.5ms mem:20869MiB swa_n:0 +step:5500/20000 train_loss:2.064779 lr_scale:0.5349 muon_mom:0.9900 train_time:464528ms step_avg:84.46ms this_step:4248.9ms mem:20869MiB swa_n:0 +step:5550/20000 train_loss:2.059327 lr_scale:0.5184 muon_mom:0.9900 train_time:468715ms step_avg:84.45ms this_step:4186.1ms mem:20869MiB swa_n:0 +step:5600/20000 train_loss:2.017942 lr_scale:0.5016 muon_mom:0.9900 train_time:472965ms step_avg:84.46ms this_step:4250.6ms mem:20869MiB swa_n:0 +step:5650/20000 train_loss:2.096813 lr_scale:0.4851 muon_mom:0.9900 
train_time:477155ms step_avg:84.45ms this_step:4189.7ms mem:20869MiB swa_n:0 +step:5700/20000 train_loss:2.060323 lr_scale:0.4685 muon_mom:0.9900 train_time:481366ms step_avg:84.45ms this_step:4210.7ms mem:20869MiB swa_n:0 +step:5750/20000 train_loss:2.141077 lr_scale:0.4514 muon_mom:0.9900 train_time:485676ms step_avg:84.47ms this_step:4310.5ms mem:20869MiB swa_n:0 +step:5800/20000 train_loss:2.055765 lr_scale:0.4349 muon_mom:0.9900 train_time:489862ms step_avg:84.46ms this_step:4186.5ms mem:20869MiB swa_n:0 +step:5850/20000 train_loss:2.177721 lr_scale:0.4184 muon_mom:0.9900 train_time:494118ms step_avg:84.46ms this_step:4255.9ms mem:20869MiB swa_n:0 +step:5900/20000 train_loss:1.956624 lr_scale:0.4016 muon_mom:0.9900 train_time:498305ms step_avg:84.46ms this_step:4186.3ms mem:20869MiB swa_n:0 +step:5950/20000 train_loss:2.006535 lr_scale:0.3851 muon_mom:0.9900 train_time:502491ms step_avg:84.45ms this_step:4186.6ms mem:20869MiB swa_n:0 +step:6000/20000 train_loss:1.999923 lr_scale:0.3683 muon_mom:0.9900 train_time:506740ms step_avg:84.46ms this_step:4248.3ms mem:20869MiB swa_n:0 +step:6000/20000 val_loss:2.0309 val_bpb:1.2028 train_time:506759ms step_avg:84.46ms +step:6050/20000 train_loss:2.018607 lr_scale:0.3518 muon_mom:0.9900 train_time:510925ms step_avg:84.45ms this_step:4184.9ms mem:20869MiB swa_n:0 +step:6100/20000 train_loss:1.971310 lr_scale:0.3353 muon_mom:0.9900 train_time:515112ms step_avg:84.44ms this_step:4187.0ms mem:20869MiB swa_n:0 +step:6150/20000 train_loss:2.073521 lr_scale:0.3185 muon_mom:0.9900 train_time:519365ms step_avg:84.45ms this_step:4253.6ms mem:20869MiB swa_n:0 +step:6200/20000 train_loss:2.009248 lr_scale:0.3020 muon_mom:0.9900 train_time:523553ms step_avg:84.44ms this_step:4188.3ms mem:20869MiB swa_n:0 +step:6250/20000 train_loss:2.125503 lr_scale:0.2852 muon_mom:0.9900 train_time:527804ms step_avg:84.45ms this_step:4250.1ms mem:20869MiB swa_n:0 +step:6300/20000 train_loss:1.995234 lr_scale:0.2687 muon_mom:0.9900 
train_time:531992ms step_avg:84.44ms this_step:4188.0ms mem:20869MiB swa_n:0 +step:6350/20000 train_loss:2.085248 lr_scale:0.2522 muon_mom:0.9900 train_time:536178ms step_avg:84.44ms this_step:4186.5ms mem:20869MiB swa_n:0 +step:6400/20000 train_loss:2.048195 lr_scale:0.2354 muon_mom:0.9900 train_time:540429ms step_avg:84.44ms this_step:4251.0ms mem:20869MiB swa_n:0 +step:6450/20000 train_loss:2.123784 lr_scale:0.2189 muon_mom:0.9900 train_time:544615ms step_avg:84.44ms this_step:4186.4ms mem:20869MiB swa_n:0 +step:6500/20000 train_loss:2.124338 lr_scale:0.2021 muon_mom:0.9900 train_time:548868ms step_avg:84.44ms this_step:4252.7ms mem:20869MiB swa_n:0 +step:6550/20000 train_loss:2.090365 lr_scale:0.1856 muon_mom:0.9900 train_time:553058ms step_avg:84.44ms this_step:4190.2ms mem:20869MiB swa_n:0 +swa:start step=6550 +step:6600/20000 train_loss:1.906521 lr_scale:0.1687 muon_mom:0.9900 train_time:557334ms step_avg:84.44ms this_step:4275.2ms mem:20869MiB swa_n:1 +step:6650/20000 train_loss:1.859928 lr_scale:0.1517 muon_mom:0.9900 train_time:561624ms step_avg:84.45ms this_step:4290.1ms mem:20869MiB swa_n:2 +step:6700/20000 train_loss:1.991382 lr_scale:0.1351 muon_mom:0.9900 train_time:565840ms step_avg:84.45ms this_step:4216.0ms mem:20869MiB swa_n:3 +step:6750/20000 train_loss:2.137290 lr_scale:0.1182 muon_mom:0.9900 train_time:570114ms step_avg:84.46ms this_step:4274.1ms mem:20869MiB swa_n:4 +step:6800/20000 train_loss:2.063745 lr_scale:0.1015 muon_mom:0.9900 train_time:574337ms step_avg:84.46ms this_step:4223.4ms mem:20869MiB swa_n:5 +step:6850/20000 train_loss:1.878264 lr_scale:0.0849 muon_mom:0.9900 train_time:578564ms step_avg:84.46ms this_step:4226.7ms mem:20869MiB swa_n:6 +step:6900/20000 train_loss:1.875529 lr_scale:0.0680 muon_mom:0.9900 train_time:582841ms step_avg:84.47ms this_step:4277.2ms mem:20869MiB swa_n:7 +step:6950/20000 train_loss:2.003772 lr_scale:0.0514 muon_mom:0.9900 train_time:587054ms step_avg:84.47ms this_step:4212.9ms mem:20869MiB swa_n:8 
+step:7000/20000 train_loss:1.847851 lr_scale:0.0345 muon_mom:0.9900 train_time:591328ms step_avg:84.48ms this_step:4274.5ms mem:20869MiB swa_n:9 +step:7000/20000 val_loss:1.9779 val_bpb:1.1714 train_time:591345ms step_avg:84.48ms +step:7050/20000 train_loss:1.924878 lr_scale:0.0179 muon_mom:0.9900 train_time:595541ms step_avg:84.47ms this_step:4212.4ms mem:20869MiB swa_n:10 +step:7100/20000 train_loss:1.980256 lr_scale:0.0012 muon_mom:0.9900 train_time:599759ms step_avg:84.47ms this_step:4218.0ms mem:20869MiB swa_n:11 +step:7103/20000 val_loss:1.9751 val_bpb:1.1697 train_time:600074ms step_avg:84.48ms +stopping_early: wallclock_cap train_time:600074ms step:7103/20000 +peak memory allocated: 20869 MiB reserved: 20910 MiB +phase:train wall_ms:649386 steps:7103 step_avg:84.48ms +swa:applying averaged 12 checkpoints +pruning: zeroed 1,065,744 weights (4.0%) below 0.005523 +phase:postprocess wall_ms:144 (swa+ema+pruning) +pre_quant_eval val_loss:1.9635 val_bpb:1.1629 eval_time:44735ms +pre_quant_eval_exact val_loss:1.96347415 val_bpb:1.16287999 +Serialized model: 105792597 bytes +Code size: 71083 bytes +Total submission size: 105863680 bytes +quant_tensor:bigram.embed.weight shape:[2048, 128] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.056610] +quant_tensor:blocks.0.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.034088] +quant_tensor:blocks.0.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.044281] +quant_tensor:blocks.0.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.1.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.086975] +quant_tensor:blocks.1.attn.c_q.weight shape:[512, 512] bits:6 
scale_range:[0.032257,0.037659] +quant_tensor:blocks.1.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.1.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032745] +quant_tensor:blocks.1.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.035034] +quant_tensor:blocks.1.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.063293] +quant_tensor:blocks.10.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.039398] +quant_tensor:blocks.10.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.033661] +quant_tensor:blocks.10.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.037445] +quant_tensor:blocks.10.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.10.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.051117] +quant_tensor:blocks.10.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.136841] +quant_tensor:blocks.2.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.036072] +quant_tensor:blocks.2.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.2.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.2.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.2.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.084106] +quant_tensor:blocks.2.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.154907] +quant_tensor:blocks.3.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.050568] +quant_tensor:blocks.3.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.034454] +quant_tensor:blocks.3.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032471] +quant_tensor:blocks.3.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.3.mlp.fc.weight shape:[1536, 512] 
bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.3.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.4.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.039673] +quant_tensor:blocks.4.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.035736] +quant_tensor:blocks.4.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.034851] +quant_tensor:blocks.4.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032471] +quant_tensor:blocks.4.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.034668] +quant_tensor:blocks.4.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.035645] +quant_tensor:blocks.5.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.033386] +quant_tensor:blocks.5.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.037109] +quant_tensor:blocks.5.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.6.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.037415] +quant_tensor:blocks.6.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.046692] +quant_tensor:blocks.6.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.034119] +quant_tensor:blocks.6.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.6.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.033020] +quant_tensor:blocks.6.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.7.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.047272] +quant_tensor:blocks.7.attn.c_q.weight shape:[512, 512] 
bits:6 scale_range:[0.032257,0.032257]
```text
quant_tensor:blocks.7.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.035980]
quant_tensor:blocks.7.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032959]
quant_tensor:blocks.7.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.036407]
quant_tensor:blocks.7.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257]
quant_tensor:blocks.8.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.053619]
quant_tensor:blocks.8.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257]
quant_tensor:blocks.8.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.034515]
quant_tensor:blocks.8.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257]
quant_tensor:blocks.8.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.035858]
quant_tensor:blocks.8.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257]
quant_tensor:blocks.9.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.054596]
quant_tensor:blocks.9.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257]
quant_tensor:blocks.9.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.040253]
quant_tensor:blocks.9.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.033569]
quant_tensor:blocks.9.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.032257]
quant_tensor:blocks.9.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257]
passthrough_tensor:bigram.proj.weight shape:[512, 128] dtype:torch.float16 bytes:131072
passthrough_tensor:bigram.scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.0.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.0.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.0.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.0.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.0.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:blocks.1.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.1.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.1.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.1.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.1.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:blocks.10.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.10.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.10.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.10.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.10.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:blocks.2.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.2.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.2.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.2.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.2.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:blocks.3.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.3.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.3.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.3.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.3.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:blocks.4.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.4.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.4.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.4.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.4.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:blocks.5.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.5.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.5.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.5.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.5.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:blocks.6.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.6.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.6.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.6.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.6.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:blocks.7.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.7.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.7.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.7.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.7.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:blocks.8.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.8.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.8.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.8.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.8.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:blocks.9.attn.q_gain shape:[8] dtype:torch.float32 bytes:32
passthrough_tensor:blocks.9.attn_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.9.depth_scale shape:[] dtype:torch.float16 bytes:2
passthrough_tensor:blocks.9.mlp_scale shape:[512] dtype:torch.float32 bytes:2048
passthrough_tensor:blocks.9.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096
passthrough_tensor:skip_weights shape:[5, 512] dtype:torch.float32 bytes:10240
passthrough_tensor:smear.gate shape:[512] dtype:torch.float16 bytes:1024
passthrough_tensor:tok_emb.weight shape:[1024, 512] dtype:torch.float16 bytes:1048576
Serialized model zstd-22: 15392872 bytes (payload:27578744 raw_torch:27638331 payload_ratio:3.83x)
Total submission size zstd-22: 15463955 bytes
Size check PASSED: 15463955 / 16,000,000 (96.6%)
phase:serialize wall_ms:67871 (quant+compress+save)
final_int8_zlib_roundtrip val_loss:1.9843 val_bpb:1.1752 eval_time:2192ms eval_seq_len:2048
final_int8_zlib_roundtrip_exact val_loss:1.98431408 val_bpb:1.17522257
quant_gap: 0.012343 BPB (pre:1.162880 post:1.175223)
phase:postquant_eval wall_ms:2981
ttt:rank0 short=2294 long=3956 epochs=8 batch=64
ttt:short_docs time=21759ms tokens=698809
ttt:batch 5/62 time=7553ms avg_loss=1.8383
ttt:batch 10/62 time=14991ms avg_loss=1.7174
ttt:batch 15/62 time=23394ms avg_loss=1.6293
ttt:batch 20/62 time=36261ms avg_loss=1.4984
ttt:batch 25/62 time=49060ms avg_loss=1.4126
ttt:batch 30/62 time=67997ms avg_loss=1.3174
ttt:batch 35/62 time=88258ms avg_loss=1.2434
ttt:batch 40/62 time=113520ms avg_loss=1.1723
ttt:batch 45/62 time=146105ms avg_loss=1.1108
ttt:batch 50/62 time=187520ms avg_loss=1.0519
ttt:batch 55/62 time=242412ms avg_loss=1.0038
ttt:batch 60/62 time=342053ms avg_loss=0.9553
ttt:long_docs time=558962ms docs=3956
final_ttt_lora val_loss:0.9878 val_bpb:0.5850 eval_time:581205ms lora_rank:8 chunk_size:256
final_ttt_lora_exact val_loss:0.98782486 val_bpb:0.58504549
ttt_gain: 0.590177 BPB gain over int8 (int8:1.175223 ttt:0.585045)
phase:ttt_eval wall_ms:581931
phase:TOTAL wall_ms:1302313 (21.7 min)
phase_breakdown: train:600074ms postprocess:see_above serialize:see_above eval:see_above ttt:see_above
```
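The headline deltas in the log can be reproduced from the reported values. A minimal sanity-check sketch (constants copied from the log above; the helper names are mine, not from the harness):

```python
# Values copied from the run log above.
VAL_BPB_PRE_QUANT = 1.162880     # pre-quantization BPB (from quant_gap line)
VAL_BPB_INT8 = 1.17522257        # post-quantization exact BPB
VAL_BPB_TTT = 0.58504549         # exact BPB after per-document LoRA TTT


def quant_gap() -> float:
    """BPB lost to quantization: post-quant minus pre-quant."""
    return VAL_BPB_INT8 - VAL_BPB_PRE_QUANT


def ttt_gain() -> float:
    """BPB recovered by TTT over the quantized base model."""
    return VAL_BPB_INT8 - VAL_BPB_TTT


def tensor_bytes(shape: list[int], dtype_bytes: int) -> int:
    """Storage for a passthrough tensor: numel x bytes-per-element."""
    n = 1
    for d in shape:
        n *= d
    return n * dtype_bytes


if __name__ == "__main__":
    print(f"quant_gap: {quant_gap():.6f} BPB")  # log reports 0.012343
    print(f"ttt_gain:  {ttt_gain():.6f} BPB")   # log reports 0.590177
    # Byte counts for two passthrough tensors from the inventory:
    print(tensor_bytes([1024, 512], 2))  # tok_emb.weight, fp16 -> 1048576
    print(tensor_bytes([2, 512], 4))     # blocks.0.resid_mix, fp32 -> 4096
```

The same arithmetic confirms the size check: 15,463,955 / 16,000,000 is 96.6% of the budget, as logged.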