diff --git a/records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/README.md b/records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/README.md new file mode 100644 index 0000000000..c33ce04a20 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/README.md @@ -0,0 +1,124 @@ +# Compressor-Aware Training (CAT) + +Non-record submission, 1xH100. Submitting for the technique, not the score. + +**val_bpb:** 1.4465 (int8+zlib roundtrip) | **Artifact:** 11.48 MB | **Hardware:** 1xH100 80GB, 600s + +## Why I did this + +I'm a data scientist. I work mostly with Bayesian stats and causal inference, not language models. I entered Parameter Golf because something about the setup bothered me: everyone compresses their artifact, but training is completely indifferent to compression. The model doesn't know or care that its weights are about to be quantized and compressed. It just optimizes for prediction quality, and compression is an afterthought. + +That felt like a missed opportunity. The competition is really a two-level compression problem. Your model compresses text (the BPB score). Then the model itself gets compressed (int8 + zlib, has to fit in 16MB). These two things interact, but everyone treats them as separate steps. In Bayesian model selection, this is the Minimum Description Length principle: minimize the total cost of describing the model plus the data given the model. The 16MB cap is the first term. BPB is the second. Nobody's jointly optimizing them. + +So I spent a few days (when I had free time) figuring out whether you could actually train a model to produce weights that compress better. Not for any specific compressor, but for the general family of compressors that most tools use (zlib, zstd, brotli all share the same basic structure). Turns out nobody's tried it. + +## How compression works (and why training ignores half of it) + +Most compressors (zlib, zstd, brotli) use some variant of the same two-stage pipeline: + +1. **Dictionary matching (LZ77):** scan through the bytes looking for repeated sequences. When you find one, replace it with a pointer back to the first occurrence. More and longer repeats = smaller output. + +2. **Entropy coding (Huffman / FSE):** take whatever's left and assign shorter codes to byte values that appear more often. Concentrated value distributions = fewer bits. + +Every paper I found on compression-aware training (Wiedemann 2018, CERWU 2025, Deep Compression 2015) uses Shannon entropy as the proxy for "how compressible are these weights." Shannon entropy is a good proxy for the entropy coding stage. It tells you nothing about dictionary matching. + +Here's what convinced me this matters: I made two arrays with the exact same value distribution, identical histogram, identical Shannon entropy. One was a smooth wave, the other was the same values shuffled randomly. zlib compressed the smooth one 3x smaller. Same entropy, totally different compressed size. Dictionary matching was doing most of the work, and nobody was accounting for it. + +## What I built + +Two differentiable loss terms that approximate what a typical two-stage compressor does: + +``` +L_total = L_language_model + lambda_lz * L_dictionary_match + lambda_h * L_entropy +``` + +**The dictionary matching proxy** measures how similar nearby bytes are in the serialized weight stream. I compute a soft match score at power-of-2 lag distances (1, 2, 4, ... 512 bytes apart): + +```python +for lag in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]: + diff_sq = (byte_stream[lag:] - byte_stream[:-lag]).square() + match_score += torch.exp(-diff_sq / temperature).mean() +``` + +`exp(-diff^2/T)` is ~1 when two bytes are the same, ~0 when they're different. It's a smooth, differentiable version of "do these bytes match?" The gradient tells each weight which direction to move to create more repeated patterns. This works for any LZ-family compressor, not just zlib. + +**The entropy proxy** builds a soft histogram of byte values using a Gaussian kernel, then computes Shannon entropy of that histogram. This is a standard technique. I included it because it covers the entropy coding half. + +Both losses backpropagate through the quantization step using the straight-through estimator (STE). The STE is a trick where the forward pass does real int8 rounding, but the backward pass pretends the rounding didn't happen so gradients can flow through. + +## What happened + +### Debugging the dictionary matching proxy + +The first version didn't work at all. I normalized byte values to [0, 1] before computing soft matches, which made all differences tiny. With temperature=1.0, even random bytes scored 0.89 similarity. The proxy couldn't tell structured data from noise. + +Fixed it by using raw byte values [0-255] and temperature=50. At T=50, structured data scores 0.99 while random data scores 0.05. 21x discrimination, enough to give a useful gradient signal. + +### Another bug on the GPU + +First 8xH100 run crashed silently. torchrun swallowed the traceback. Had to run it with plain `python3` to see the actual error: `torch.Generator` on CUDA can't be used with `torch.randint`. Cost about $14 in wasted GPU time before I figured it out. + +### 1xH100 results (5 runs, 600s each) + +| Run | Config | BPB | Artifact | vs Control | +|-----|--------|-----|----------|-----------| +| Control | no CAT | 1.4374 | 12.32 MB | -- | +| Dict. match only | lz=0.01 | 1.4463 | 12.15 MB | -173 KB | +| Entropy only | h=0.1 | 1.4465 | 11.52 MB | -808 KB | +| Combined | lz=0.01, h=0.1 | 1.4465 | 11.48 MB | -842 KB | +| Entropy strong | h=1.0 | 1.5044 | 9.81 MB | -2.52 MB | + +The entropy proxy does most of the work. At lambda=0.1 it saves 808 KB (6.6%) for only +0.009 BPB. The dictionary matching proxy adds 173 KB on its own and 34 KB on top of entropy. Smaller effect, but real. + +At lambda=1.0 the entropy proxy saves 2.52 MB (20%) but BPB takes a bigger hit (+0.067). There's a tradeoff you can dial wherever you want. + +## What this means + +The control run produced a 12.32 MB artifact. CAT combined brought that down to 11.48 MB. That's 842 KB freed up. + +842 KB is roughly 842K extra parameters in int8. For a dim=896 model, that's about one extra attention layer's worth of capacity. Whether that extra capacity offsets the 0.009 BPB cost is the open question. I didn't have enough compute budget to test the "reinvest the saved bytes into a wider model" experiment. + +The entropy-strong run is more dramatic: 9.81 MB leaves 6.2 MB of headroom under the 16MB limit. That's a lot of room for a bigger model. + +## What's new here + +I searched for prior work on training neural network weights to be friendly to compression. Specifically for dictionary matching (the spatial pattern part, not just value distributions). + +I didn't find anything. The closest things: + +- Wiedemann 2018, CERWU 2025: use Shannon entropy (covers entropy coding only, misses dictionary matching) +- Deep Compression 2015: applies Huffman after training, not during +- Sandwiched Compression (Google 2024): differentiable proxy of a fixed codec, but for images going through JPEG, not neural net weights +- NuMuon 2026: nuclear norm constraints happen to help zstd (low rank = repeated patterns), but it's not designed for compression + +The dictionary matching proxy via multi-lag autocorrelation is the new part. The entropy proxy is established. The combination, applied to neural network weight compression during training and targeting the LZ-family compressor structure that most tools share, is what I haven't seen before. + +## Architecture + +4 physical transformer layers looped 3 times (12 effective layers), with per-loop LoRA adapters (rank 16). Based on the Relaxed Recursive Transformers paper (Bae et al. 2024). dim=896, 14 attention heads, 2 KV heads (GQA). QAT fused with LR cooldown. + +## Compute + +Total spend: ~$18 across two sessions. + +- March 18-19: Built the depth-recurrent architecture, ran 28 experiments. ~$15. +- April 4-5: Designed CAT, ran 20+ local experiments on MLX (free) and 6 H100 runs. ~$3. + +## What I'd do with more compute + +1. Train a wider model (dim=1024+) using the bytes saved by CAT, check if net BPB improves +2. Test CAT with int6 quantization (what the leaders use) +3. Run on 8xH100 for 3000+ steps, the compression effect compounds during training +4. Sweep dictionary matching temperature, maybe sharper matching helps + +## References + +- RFC 1951: DEFLATE specification +- Wiedemann et al. 2018. Entropy-Constrained Training. arXiv:1812.07520 +- Conzelmann & Bamler 2025. CERWU. arXiv:2505.18758 +- Han, Mao, Dally 2015. Deep Compression. arXiv:1510.00149 +- Ullrich, Meeds, Welling 2017. Soft Weight-Sharing. arXiv:1702.04008 +- Google 2024. Sandwiched Compression. arXiv:2402.05887 +- Deletang et al. 2024. Language Modeling Is Compression. arXiv:2309.10668 +- Bae et al. 2024. Relaxed Recursive Transformers. arXiv:2410.20672 +- Bengio et al. 2013. Estimating Gradients Through Stochastic Neurons. arXiv:1308.3432 diff --git a/records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/submission.json b/records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/submission.json new file mode 100644 index 0000000000..7adad855d4 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/submission.json @@ -0,0 +1,23 @@ +{ + "author": "Tomas Korenblit", + "github_id": "korentomas", + "val_bpb": 1.4465, + "pre_quant_val_bpb": 1.4461, + "artifact_bytes": 11480222, + "hardware": "1xH100 80GB HBM3", + "training_time_seconds": 600, + "steps": 590, + "technique": "Compressor-Aware Training (CAT): differentiable LZ77 autocorrelation + entropy proxy regularizer for zlib-friendly weights", + "architecture": { + "num_physical_layers": 4, + "num_loops": 3, + "model_dim": 896, + "num_heads": 14, + "num_kv_heads": 2, + "lora_rank": 16, + "qat_mode": 2, + "cat_lambda_lz": 0.01, + "cat_lambda_h": 0.1, + "cat_temperature": 50.0 + } +} diff --git a/records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/train_golf.py b/records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/train_golf.py new file mode 100644 index 0000000000..413743ac97 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_CompressorAwareTraining_CAT/train_golf.py @@ -0,0 +1,1485 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_hidden_override = int(os.environ.get("MLP_HIDDEN", 0)) # 0 = auto from dim*mlp_mult + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + # Depth recurrence + num_physical_layers = int(os.environ.get("NUM_PHYSICAL_LAYERS", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 1)) + lora_rank = int(os.environ.get("LORA_RANK", 0)) + # QAT: 0=off, 1=always, 2=fused with cooldown + qat_mode = int(os.environ.get("QAT_MODE", 0)) + + # Compressor-Aware Training (CAT) + cat_lambda_lz = float(os.environ.get("CAT_LAMBDA_LZ", 0.0)) + cat_lambda_h = float(os.environ.get("CAT_LAMBDA_H", 0.0)) + cat_temperature = float(os.environ.get("CAT_TEMPERATURE", 50.0)) + cat_bandwidth = float(os.environ.get("CAT_BANDWIDTH", 1.0)) + cat_sample_size = int(os.environ.get("CAT_SAMPLE_SIZE", 100_000)) + cat_start_step = int(os.environ.get("CAT_START_STEP", 0)) + cat_log_every = int(os.environ.get("CAT_LOG_EVERY", 50)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +BOS_ID = 1 +STRIDE_CHUNK = int(os.environ.get("STRIDE_CHUNK", 0)) + + +def strided_val( + args, model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + chunk_size=256, window_size=1024, +): + """Document-isolated strided validation (PR #77 approach, no TTT). + Each token scored with max context via overlapping windows.""" + + bos_pos = (val_tokens == BOS_ID).nonzero(as_tuple=True)[0].tolist() + docs = [] + for i in range(len(bos_pos)): + s = bos_pos[i] + e = bos_pos[i + 1] if i + 1 < len(bos_pos) else val_tokens.numel() + if e - s >= 2: + docs.append((s, e - s)) + + my_docs = docs[(len(docs) * rank) // world_size: (len(docs) * (rank + 1)) // world_size] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() # noqa: model.eval is not Python's eval() + with torch.inference_mode(): + for doc_start, doc_len in my_docs: + pred_len = doc_len - 1 + n_chunks = (pred_len + chunk_size - 1) // chunk_size + + for ci in range(n_chunks): + cs = ci * chunk_size + ce = pred_len if ci == n_chunks - 1 else (ci + 1) * chunk_size + cl = ce - cs + ws = max(0, ce - window_size) + wl = ce - ws + co = cs - ws + + raw = val_tokens[doc_start + ws: doc_start + ws + wl + 1] + raw = raw.to(device=device, dtype=torch.int64) + x = raw[:-1].unsqueeze(0) + y = raw[1:].unsqueeze(0) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ptl = model.per_token_loss(x, y) + + chunk_losses = ptl[0, co:co + cl].to(torch.float64) + cx = x[0, co:co + cl] + cy = y[0, co:co + cl] + + tb = base_bytes_lut[cy].to(torch.float64) + tb += (has_leading_space_lut[cy] & ~is_boundary_token_lut[cx]).to(torch.float64) + + loss_sum += chunk_losses.sum() + byte_sum += tb.sum() + tok_count += cl + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(tok_count, op=dist.ReduceOp.SUM) + + vl = float((loss_sum / tok_count).item()) + vb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + model.train() # noqa + return vl, vb + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + qmax = (1 << (QUANT_BITS - 1)) - 1 # 31 for int6, 127 for int8 + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + # Optionally keep tok_emb in fp16 (set FP16_EMBED=1). Helps baseline but + # costs ~4MB; not worth it when QAT already handles the quant gap. + fp16_embed = bool(int(os.environ.get("FP16_EMBED", "0"))) + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL or (fp16_embed and "tok_emb" in name): + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +_QAT_ENABLED = False + +QUANT_BITS = int(os.environ.get("QUANT_BITS", 8)) # 6 or 8 + +def fake_quantize_per_row(w: Tensor) -> Tensor: + """Simulate per-row quantization with STE. Supports int6 or int8.""" + qmax = (1 << (QUANT_BITS - 1)) - 1 # 31 for int6, 127 for int8 + w_f32 = w.float() + row_max = w_f32.abs().amax(dim=1, keepdim=True) + scale = torch.clamp(row_max / qmax, min=1.0 / qmax) + q = torch.clamp(torch.round(w_f32 / scale), -qmax, qmax) + w_dq = q * scale + return (w_dq - w_f32).detach() + w_f32 # STE + + +def serialize_quantized_weights_torch(model: nn.Module) -> torch.Tensor: + """Serialize fake-quantized weights to flat float tensor of byte values. + + Uses STE (straight-through estimator) so gradients flow through round(). + Without STE, torch.round() has zero gradient and CAT loss cannot update weights. + Respects QUANT_BITS so the proxy matches the actual quantization grid. + Output range is [1, 2*qmax+1] (e.g. [1, 255] for int8, [1, 63] for int6). + Zero never appears because quantized values are in [-qmax, qmax] shifted by qmax+1. + """ + qmax = (1 << (QUANT_BITS - 1)) - 1 # 31 for int6, 127 for int8 + byte_chunks = [] + for name, param in model.named_parameters(): + if param.ndim < 2 or param.numel() <= 65536: + continue + w = param.float() + row_max = w.abs().amax(dim=-1, keepdim=True) + scale = torch.clamp(row_max / float(qmax), min=1.0 / float(qmax)) + w_clamped = torch.clamp(w / scale, -qmax, qmax) + w_rounded = torch.round(w_clamped) + # STE: forward uses rounded values, backward flows through w_clamped + w_q = (w_rounded - w_clamped).detach() + w_clamped + w_bytes = w_q + float(qmax + 1) # shift to unsigned range + byte_chunks.append(w_bytes.reshape(-1)) + return torch.cat(byte_chunks) + + +def lz77_proxy_loss_torch(byte_stream: torch.Tensor, temperature: float = 50.0, + sample_size: int = 100_000, seed: int = 0) -> torch.Tensor: + """Differentiable LZ77 proxy for PyTorch. + + Uses RAW byte values [0-255] (no normalization). Temperature should be + in range 10-500. At T=50, structured data scores ~0.99, random ~0.05. + Seed ensures deterministic subsampling across DDP ranks. + """ + if byte_stream.shape[0] > sample_size: + g = torch.Generator(device="cpu").manual_seed(seed) + start = torch.randint(0, byte_stream.shape[0] - sample_size, (1,), generator=g).item() + x = byte_stream[start:start + sample_size] + else: + x = byte_stream + # No normalization — raw byte values [0-255] with appropriately scaled temperature + lags = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + match_score = torch.tensor(0.0, device=x.device) + num_lags = 0 + for lag in lags: + if lag >= x.shape[0]: + break + diff_sq = (x[lag:] - x[:-lag]).square() + match_score = match_score + torch.exp(-diff_sq / temperature).mean() + num_lags += 1 + if num_lags == 0: + return match_score + return -match_score / float(num_lags) + + +def entropy_proxy_loss_torch(byte_stream: torch.Tensor, bandwidth: float = 1.0, + sample_size: int = 100_000, seed: int = 0) -> torch.Tensor: + """Differentiable entropy proxy for PyTorch. + Seed ensures deterministic subsampling across DDP ranks. + """ + if byte_stream.shape[0] > sample_size: + g = torch.Generator(device="cpu").manual_seed(seed) + indices = torch.randint(0, byte_stream.shape[0], (sample_size,), generator=g).to(byte_stream.device) + x = byte_stream[indices] + else: + x = byte_stream + num_bins = 64 + bin_size = 256.0 / num_bins + centers = torch.arange(num_bins, dtype=torch.float32, device=x.device) * bin_size + bin_size / 2.0 + logits = -(x[:, None] - centers[None, :]).square() / (bandwidth * bin_size) + soft_hist = logits.softmax(dim=-1).mean(dim=0) + entropy = -(soft_hist * (soft_hist + 1e-10).log()).sum() + return entropy + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if _QAT_ENABLED and w.ndim == 2 and w.numel() > 65536: + w = fake_quantize_per_row(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class LoRAAdapter(nn.Module): + """Low-rank adapter for per-loop specialization.""" + def __init__(self, in_dim: int, out_dim: int, rank: int): + super().__init__() + self.A = nn.Parameter(torch.randn(in_dim, rank) * (1.0 / math.sqrt(in_dim))) + self.B = nn.Parameter(torch.zeros(rank, out_dim)) + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.to(x.dtype)) @ self.B.to(x.dtype) + + +class LoRASet(nn.Module): + """Per-loop LoRA adapters for Q, K, V projections.""" + def __init__(self, dim: int, head_dim: int, num_heads: int, num_kv_heads: int, rank: int): + super().__init__() + self.lora_q = LoRAAdapter(dim, num_heads * head_dim, rank) + self.lora_k = LoRAAdapter(dim, num_kv_heads * head_dim, rank) + self.lora_v = LoRAAdapter(dim, num_kv_heads * head_dim, rank) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, lora: LoRASet | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + if lora is not None: + q = q + lora.lora_q(x) + k = k + lora.lora_k(x) + v = v + lora.lora_v(x) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Expand KV heads for GQA compatibility (torch <2.5 lacks enable_gqa) + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(repeats, dim=1) + v = v.repeat_interleave(repeats, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # SwiGLU MLP: gate + up projections, gated by silu, then down projection. + def __init__(self, dim: int, mlp_mult: int, hidden_override: int = 0): + super().__init__() + if hidden_override > 0: + hidden = hidden_override + else: + hidden = int(2 * dim * mlp_mult / 3) + hidden = ((hidden + 63) // 64) * 64 # round to 64 for efficiency + self.gate = CastedLinear(dim, hidden, bias=False) + self.up = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.silu(self.gate(x)) * self.up(x)) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + mlp_hidden_override: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, hidden_override=mlp_hidden_override) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, lora: LoRASet | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), lora=lora) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_physical_layers: int = 0, + num_loops: int = 1, + mlp_hidden_override: int = 0, + lora_rank: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.use_recurrence = num_physical_layers > 0 and num_loops > 1 + self.tok_emb = nn.Embedding(vocab_size, model_dim) + + if self.use_recurrence: + self.num_physical = num_physical_layers + self.num_loops = num_loops + effective_layers = num_physical_layers * num_loops + self.num_encoder_layers = effective_layers // 2 + self.num_decoder_layers = effective_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, mlp_hidden_override=mlp_hidden_override) + for _ in range(num_physical_layers) + ]) + head_dim = model_dim // num_heads + if lora_rank > 0: + self.lora_sets = nn.ModuleList([ + nn.ModuleList([ + LoRASet(model_dim, head_dim, num_heads, num_kv_heads, lora_rank) + for _ in range(num_physical_layers) + ]) + for _ in range(num_loops - 1) + ]) + else: + self.lora_sets = nn.ModuleList() + self.loop_embeds = nn.Parameter(torch.zeros(num_loops, model_dim, dtype=torch.float32)) + else: + self.num_physical = num_layers + self.num_loops = 1 + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, mlp_hidden_override=mlp_hidden_override) + for i in range(num_layers) + ]) + self.lora_sets = nn.ModuleList() + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + if self.use_recurrence: + effective_idx = 0 + for loop_i in range(self.num_loops): + x = x + self.loop_embeds[loop_i].to(dtype=x.dtype)[None, None, :] + for phys_i in range(self.num_physical): + lora = None + if loop_i > 0 and len(self.lora_sets) > 0: + lora = self.lora_sets[loop_i - 1][phys_i] + if effective_idx < self.num_encoder_layers: + x = self.blocks[phys_i](x, x0, lora=lora) + skips.append(x) + else: + dec_i = effective_idx - self.num_encoder_layers + if skips: + x = x + self.skip_weights[dec_i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[phys_i](x, x0, lora=lora) + effective_idx += 1 + else: + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + bsz, seqlen, dim = x.shape + x = self.final_norm(x).reshape(-1, dim) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + @torch.no_grad() + def per_token_loss(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + """Per-token loss for strided eval. NOT compiled — called only at final eval.""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + if self.use_recurrence: + effective_idx = 0 + for loop_i in range(self.num_loops): + x = x + self.loop_embeds[loop_i].to(dtype=x.dtype)[None, None, :] + for phys_i in range(self.num_physical): + lora = None + if loop_i > 0 and len(self.lora_sets) > 0: + lora = self.lora_sets[loop_i - 1][phys_i] + if effective_idx < self.num_encoder_layers: + x = self.blocks[phys_i](x, x0, lora=lora) + skips.append(x) + else: + dec_i = effective_idx - self.num_encoder_layers + if skips: + x = x + self.skip_weights[dec_i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[phys_i](x, x0, lora=lora) + effective_idx += 1 + else: + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + bsz, seqlen, dim = x.shape + x = self.final_norm(x).reshape(-1, dim) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="none").reshape(bsz, seqlen) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_physical_layers=args.num_physical_layers, + num_loops=args.num_loops, + lora_rank=args.lora_rank, + mlp_hidden_override=args.mlp_hidden_override, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # QAT control + global _QAT_ENABLED + if args.qat_mode == 1: + _QAT_ENABLED = True + elif args.qat_mode == 2: + _QAT_ENABLED = scale < 1.0 + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + # CAT regularizer (separate backward pass, outside torch.compile region). + # Uses base_model (not DDP wrapper) — bypasses DDP allreduce hooks. + # This is correct because: (1) DDP keeps weights synchronized across ranks, + # (2) seed=step ensures identical subsampling, so all ranks compute identical + # CAT gradients independently. If either invariant breaks, ranks will diverge. + cat_loss_val = 0.0 + if args.cat_lambda_lz > 0 or args.cat_lambda_h > 0: + if step >= args.cat_start_step: + cat_device = next(base_model.parameters()).device + byte_stream = serialize_quantized_weights_torch(base_model) + cat_loss = torch.tensor(0.0, device=cat_device) + if args.cat_lambda_lz > 0: + cat_loss = cat_loss + args.cat_lambda_lz * lz77_proxy_loss_torch( + byte_stream, args.cat_temperature, args.cat_sample_size, seed=step) + if args.cat_lambda_h > 0: + cat_loss = cat_loss + args.cat_lambda_h * entropy_proxy_loss_torch( + byte_stream, args.cat_bandwidth, args.cat_sample_size, seed=step) + cat_loss.backward() + cat_loss_val = cat_loss.item() + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + cat_log_str = f" cat_loss:{cat_loss_val:.6f}" if cat_loss_val != 0.0 else "" + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f}{cat_log_str} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # CAT black-box compressed size logging (rank 0 only) + if (args.cat_lambda_lz > 0 or args.cat_lambda_h > 0) and args.cat_log_every > 0: + if step > 0 and step % args.cat_log_every == 0 and local_rank == 0: + with torch.no_grad(): + byte_stream_log = serialize_quantized_weights_torch(base_model) + lz_val = lz77_proxy_loss_torch(byte_stream_log.detach(), args.cat_temperature, args.cat_sample_size, seed=step).item() + ent_val = entropy_proxy_loss_torch(byte_stream_log.detach(), args.cat_bandwidth, args.cat_sample_size, seed=step).item() + # Actual zlib size + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + zlib_size = len(zlib.compress(quant_buf.getvalue(), level=9)) + log0(f"step:{step} | zlib:{zlib_size:,}B lz77_proxy:{lz_val:.4f} entropy_proxy:{ent_val:.4f}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Strided eval (document-isolated, overlapping windows) — PR #77 approach + if STRIDE_CHUNK > 0: + t_strided = time.perf_counter() + s_val_loss, s_val_bpb = strided_val( + args, model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + chunk_size=STRIDE_CHUNK, window_size=args.train_seq_len, + ) + torch.cuda.synchronize() + log0( + f"strided_eval val_loss:{s_val_loss:.4f} val_bpb:{s_val_bpb:.4f} " + f"chunk:{STRIDE_CHUNK} eval_time:{1000.0 * (time.perf_counter() - t_strided):.0f}ms" + ) + log0(f"strided_eval_exact val_loss:{s_val_loss:.8f} val_bpb:{s_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()