diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/README.md b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/README.md new file mode 100644 index 0000000000..93f6bb98f3 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/README.md @@ -0,0 +1,122 @@ +# Non-record: #1610 reproduction (Δ=+1.9e-5 BPB), n-gram posterior corrector negative result, quantized-eval-only path fix + +This folder contains a non-record evidence package built around three separable contributions. It is not a leaderboard claim. + +## Prior context + +Previous submissions in this line: [#1101](https://github.com/openai/parameter-golf/pull/1101) (pre-TTT anchor, 1.1290 BPB), [#1307](https://github.com/openai/parameter-golf/pull/1307) (07c1 strict base proof vs merged #1019), [#1598](https://github.com/openai/parameter-golf/pull/1598) (SP8192-D 5-seed evidence package). + +## Contributions + +1. **Reproduction of PR #1610 on independent infrastructure.** Seed-0 BPB is 1.07218477, which differs from #1610's published seed-0 number 1.07216564 by +1.913×10⁻⁵ BPB. Training was run on 8× NVIDIA H100 80GB HBM3 SXM5 on RunPod at this branch's commit `1765afc`, which pins PR #1610 at its upstream commit `ca19195`. +2. **Bounded negative result for a score-first n-gram posterior corrector** layered on PR #1610's phased LoRA TTT eval path. All three tested `(alpha, orders)` configurations degrade BPB relative to the reproduced baseline. Damage scales monotonically with the blend weight `alpha`. Multi-order backoff provides no measurable benefit over single-order at the same `alpha` in this grid. +3. **Bug fix in the quantized-eval-only branch of `train_gpt.py`.** The pre-quantization diagnostic eval previously ran unconditionally, dereferencing `compiled_model = None` in `eval_only_quantized_path` mode. The fix wraps the diagnostic in `if not quantized_eval_only:` (line 3204) and extends the `del eval_model, compiled_model` cleanup guard to cover the same branch (line 3259). Without these two guards, `EVAL_ONLY_QUANTIZED_PATH` ablations could not run. They produced the measurements in Contribution 2. + +The reproduction is a credibility prerequisite for the negative-result claim, not a contribution in itself. The corrector formulation and its Section-III-compliance engineering are the only novel content in this package. The bug fix is incidental — surfaced while running Contribution 2. + +## Reproduction result + +| | Value | +|---|---| +| Our seed-0 BPB | 1.07218477 | +| PR #1610 published seed-0 BPB | 1.07216564 | +| Δ vs published seed-0 | **+1.913×10⁻⁵** | +| Eval wall-clock (s) | 455.9 | +| Artifact bytes (model + code) | 15,999,394 | + +Training stopped at step 4,879 of 20,000 because of `MAX_WALLCLOCK_SECONDS=600` (`minus GPTQ_RESERVE_SECONDS=13`). This is the by-design behavior of #1610. The artifact at 15,999,394 bytes leaves 606 bytes of competition headroom; our internal pipeline's stricter 15,997,520-byte threshold (intended to absorb code-size drift between sessions) is the source of the `GATE_A: FAIL` line in the log tail. The full run log is in `train_seed0.log`. Machine-readable summary is in `reproduction_summary.json`. + +## Corrector ablation + +The corrector is a backward-looking posterior over the scored prefix. 
At position `t` it maintains a Laplace-smoothed unigram distribution `q_uni(v)` and, for the requested n-gram orders, conditional counts `q_ngram(v | ctx)`. The blend adds `alpha * log(q_t(v))` to the neural logits before scoring, where `q_t(v) = q_ngram(v | ctx_t)` when an n-gram hit exists at order `n` and `q_t(v) = q_uni(v)` otherwise (Laplace guarantees `q_uni(v) > 0` for all `v`). The corrector integrates into the phased TTT loop at `forward_ttt(..., logit_bias=...)` (`train_gpt.py:1048`); the actual blend is the single line `logits = logits + logit_bias` (`train_gpt.py:1122`). + +All three ablations were run in eval-only mode against the seed-0 checkpoint from the reproduction above (no retraining). + +| Run | alpha | orders | BPB | Δ BPB (run − baseline; positive = worse) | Eval (s) | +|---|---:|---|---:|---:|---:| +| Baseline (no corrector) | 0.0 | — | 1.07218477 | 0 | 455.9 | +| Ablation 1a | 0.3 | [8] | 1.08876294 | **+0.01658** | 462.8 | +| Ablation 1b | 0.3 | [5, 8, 12] | 1.08891256 | **+0.01673** | 472.4 | +| Ablation 1c | 0.1 | [5, 8, 12] | 1.07430360 | **+0.00212** | 465.8 | + +**Interpretation.** The corrector's effect at `alpha = 0.1` is approximately 1/8 of its effect at `alpha = 0.3`, close to the ~1/9 that a quadratic dependence on `alpha` would predict (a purely linear response would predict ~1/3) and inconsistent with any threshold-activated improvement at lower blend weights. Multi-order backoff at the same `alpha` made a negligible difference (`+0.01658` for `[8]` vs `+0.01673` for `[5, 8, 12]`). + +Structurally, TTT-LoRA adapts the base model's output distribution using the same scored prefix `x_{1..t-1}` that feeds the corrector's n-gram tables. Both signals are therefore deterministic functions of that prefix. Adding `alpha * log(q_prefix_ngram(v))` on top of logits that already encode `P(x_t | x_{1..t-1})` under TTT adaptation over-counts the prefix evidence; the corrector's positive coefficient systematically pushes probability mass toward tokens the base model has already concluded are likely. This predicts the result is monotonic in `alpha` and that a corrector layered on a non-adaptive (non-TTT) eval pipeline would behave differently. The latter was not tested in this package. + +**This PR rules out one tested posterior-corrector path on a reproduced #1610-class phased-TTT stack; it does not claim that all n-gram or posterior correctors are ineffective.** + +Raw eval logs are in `ablation_1a.log`, `ablation_1b.log`, `ablation_1c.log`. Machine-readable config + results in `ablation_summary.json`. + +## Eval-only bug fix + +In `EVAL_ONLY_QUANTIZED_PATH` mode, `train_gpt.py` sets `base_model = None; compiled_model = None; compiled_forward_logits = None` (line 3188) and never binds `eval_model`, but two downstream eval code paths still assumed those names were usable: + +- The pre-quantization diagnostic `timed_eval("diagnostic pre-quantization post-ema", eval_val, ..., compiled_model, ...)` dereferenced `compiled_model.forward_logits` and crashed all ranks with `AttributeError: 'NoneType' ...`. +- The subsequent `del eval_model, compiled_model` cleanup in the TTT branch referenced `eval_model` which was never bound in this mode, raising `UnboundLocalError`.
+ +The fix adds two guards: + +``` +# train_gpt.py:3204 +if not quantized_eval_only: + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, h, device, val_data, + compiled_model, compiled_forward_logits, + ) + +# train_gpt.py:3259 +if h.ttt_enabled: + if not ttt_only_eval and not quantized_eval_only: + del eval_model, compiled_model +``` + +The post-quantization diagnostic still runs in this branch because it calls `deserialize(h, device)` directly and does not touch the `None` locals. + +## Compliance with Issue #1017 Section III + +Each of the four conditions walked against the as-shipped corrector code in `train_gpt.py`. + +**Condition 1 (strict causal dependence).** State is the `PrefixNgramCorrector` instance defined at `train_gpt.py:15-58`. Its `get_logit_bias()` reads only `self.hist` and `self.uni`, which are populated exclusively inside `update()`. `update(x_t)` is called in the eval loop *after* `F.cross_entropy` scores position `t`. Nothing else writes into `self.hist` or `self.uni`. No future tokens, no external data. + +**Condition 2 (full normalized distribution).** The blend is `final_logits[v] = neural_logits[v] + alpha * log(q_t(v))` over the full `V = vocab_size` alphabet (`train_gpt.py:1122`, `[V]` tensor add — not a gathered single-index update). Laplace smoothing at init (`self.uni = torch.ones(V, dtype=torch.int32)` at line 23) guarantees `q_uni(v) > 0` for every `v ∈ V`, so `log(q_t(v))` is finite everywhere and `softmax(final_logits)` is a valid distribution. The `[V]` bias vector is `alpha * (self._lu - self._lz)` at line 34, with the n-gram delta sparsely added over tokens that have n-gram hits; the final effective bias is `alpha * log(q_ngram)` where n-gram hits exist and `alpha * log(q_uni)` otherwise. No dense `[batch, seq, vocab]` allocation exists on the production path; the `[B, 1, V]` bias is broadcast at `train_gpt.py:2565`. + +**Condition 3 (score-before-update).** In the phased TTT eval loop (`train_gpt.py:2562-2598`): the bias is collected from `[correctors[_b].get_logit_bias() ...]` (line 2564) *before* the scoring forward pass `forward_ttt_train(..., logit_bias=_logit_bias)` (line 2567). Scores are accumulated in `_accumulate_bpb(...)` (lines 2568-2582). Only then does the update path run: `correctors[_b].update(_tok)` (line 2591), inside the block introduced by the explicit comment `# Corrector: update state with scored tokens (score-before-update)` (line 2583). `PrefixNgramCorrector`'s docstring (lines 17-18) encodes the same contract and the call sites honor it. + +Note on the chunk-static approximation. The bias is computed once per TTT chunk (chunk size = `h.ttt_chunk_size = 32`) and broadcast as `[B, 1, V]` to every position inside that chunk. This is a deliberate engineering choice: a per-position bias would require either 32× as many GPU forward passes across the eval set (blowing past the 600 s budget by an unrecoverable margin) or a dense `[B, S, V]` correction tensor with `B=64`, `S=2048`, `V=8192` in bf16 ≈ 2.1 GB per batch per rank — unusable inside the #1610 memory envelope once stacked with activations and optimizer state. The trade-off is accepted because the bias at any position inside chunk `c` is still a function only of tokens from chunks `[0, c)`, which preserves score-before-update at chunk granularity. Within a chunk, the bias is constant — it does not use tokens from the current chunk for its own scoring. 
This satisfies score-before-update at chunk granularity rather than per-position, and the choice is explicit in the corrector's docstring. + +**Condition 4 (single left-to-right pass).** `eval_val_ttt_phased` is one forward pass over the validation token stream. No re-scoring, no second pass, no min-over-runs selection. The phased TTT loop performs interleaved global SGD steps on the base model between chunks, but those SGD steps do not re-score previously scored positions. After each global SGD step, the corrector state is reset (`correctors[_b].reset()`) to avoid stale counts against the updated base model. + +**Warmup uses synthetic tokens only.** The TTT compile warmup (`train_gpt.py:3324-3365`) is bracketed by `# BEGIN warmup synthetic tokens` / `# END warmup synthetic tokens` comments and uses a device-local generator (`_warmup_gen = torch.Generator(device=device).manual_seed(0)`) that does not mutate global RNG state. Tokens are drawn via `torch.randint(0, h.vocab_size, ..., generator=_warmup_gen)`. When `corrector_alpha > 0`, a second warmup pass with `dummy_bias = torch.zeros(bsz, 1, h.vocab_size, ...)` precompiles the `logit_bias` branch so Dynamo does not recompile inside the eval timer. The timer starts at `torch.cuda.synchronize(); t_ttt = time.perf_counter()` (`train_gpt.py:3370-3371`) *after* the warmup block closes. + +## Out of scope / open questions + +- `alpha < 0.1` not tested: the trend from `alpha = 0.1 → 0.3` suggests negligible effect at lower blend weights, but this was not measured. +- Orders greater than 12 not tested: longer contexts could catch different co-occurrence structure; compute scaling of the C++ hash-table path at higher orders was not characterized in this package. +- Logistic-domain (log-odds) blend alternatives to the probability-domain blend here were not tested. +- Non-TTT eval pipelines were not tested; the negative result is conditional on the phased-LoRA-TTT stack. + +## Single-seed scope + +This package reports a faithful seed-0 reproduction plus eval-only ablations; it is a non-record evidence package and not a leaderboard claim. This submission uses seed 0 only, both for the reproduction and for the three corrector ablations. The reproduction is compared against #1610's published seed-0 number (1.07216564), not against their 3-seed mean. Multi-seed validation was descoped: with a +1.9×10⁻⁵ BPB delta against the matched seed and a monotonic +0.002 to +0.017 degradation across the corrector grid, additional seeds would refine the variance estimate but are unlikely to flip either direction. The negative-result claim is therefore bounded to seed 0 of the reproduced #1610 checkpoint. + +## Artifacts and reproducibility + +| File | What it is | +|---|---| +| `train_gpt.py` | This PR's training script; pinned to #1610 upstream with the eval-only-quantized guards applied. | +| `train_seed0.log` | Raw training log for the seed-0 reproduction (script-level timing + per-step metrics; training script writes compact output by design). | +| `ablation_1a.log`, `ablation_1b.log`, `ablation_1c.log` | Raw eval logs for the three corrector configurations (same logging convention). | +| `reproduction_summary.json` | Machine-readable reproduction metrics. | +| `ablation_summary.json` | Machine-readable corrector ablation results (all three configs). | +| `submission.json` | Non-record submission metadata. | +| `requirements.txt` | Python dependencies; pins `torch==2.9.1+cu128`; FA3 notes inline. | +| `provenance/commit_sha.txt` | This branch's commit SHA. 
| +| `provenance/env_fingerprint.txt` | Torch / CUDA / Python versions at run time. | +| `provenance/hardware_info.txt` | `nvidia-smi` output captured at Gate A. | + +**Commit SHA.** `1765afc7d62ce03a1219ca81cc92eea4fabdf343` (pins PR #1610 at upstream `ca1919539dc6e328ea890cb03ad3ca1c5a84da55`; plus eval-only-quantized guards at `e99f18e`). + +**Hardware.** 8× NVIDIA H100 80GB HBM3, SXM5, CUDA 12.8, driver 570.211.01. + +Supplementary external artifact archive for reproducibility: . Contains the preserved full run tarball (141 MB, MD5 `caf8adf63d8c80965f6671beba95d7aa`): pre-quantization checkpoint, quantized checkpoint, full ablation intermediate artifacts. Not required to reproduce the headline number — `train_gpt.py` and the logs in this folder are self-sufficient. diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_1a.log b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_1a.log new file mode 100644 index 0000000000..911bc94dce --- /dev/null +++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_1a.log @@ -0,0 +1,441 @@ +NCCL version 2.27.5+cuda12.9 +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: /workspace/parameter-golf/runs/ablation_1a + beta1: 0.9 + beta2: 0.95 + compressor: brotli + corrector_alpha: 0.3 + corrector_orders: 8 + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_only_path: + eval_only_quantized_path: /workspace/checkpoints/seed0/final_model.int6.ptz + eval_seq_len: 2048 + eval_stride: 64 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_doc_limit: 0 + global_ttt_epochs: 3 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.005 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 13.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/runs/ablation_1a/a59d6775-e943-4ab3-a964-6a763c9f7a96.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/runs/ablation_1a/final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_enabled: True + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: /workspace/checkpoints/seed0/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: a59d6775-e943-4ab3-a964-6a763c9f7a96 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + 
train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 32 + ttt_doc_limit: 0 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_heartbeat_seconds: 15.0 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_output_dir: + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +eval_only:using quantized checkpoint from /workspace/checkpoints/seed0/final_model.int6.ptz +eval_only: skipping serialize (already have quantized model) +diagnostic quantized val_loss:2.79886366 val_bpb:1.08349242 eval_time:12149ms +ttt_lora:warming up compile +ttt_lora:compile warmup done (182.9s) + +beginning TTT eval timer +corrector: alpha=0.3 orders=[8] +ttt_phased: total_docs:50000 prefix_docs:2000 suffix_docs:48000 +ttp: b777/782 bl:2.7381 bb:1.0952 rl:2.7381 rb:1.0952 dl:7190-7938 gd:0 +ttp: b772/782 bl:2.7864 bb:1.1145 rl:2.7575 rb:1.1029 dl:4937-5193 gd:0 +ttp: b767/782 bl:2.7818 bb:1.1107 rl:2.7634 rb:1.1048 dl:3963-4123 gd:0 +ttp: b761/782 bl:2.7797 bb:1.0753 rl:2.7662 rb:1.0997 dl:3336-3430 gd:0 +ttp: b755/782 bl:2.7240 bb:1.0554 rl:2.7608 rb:1.0939 dl:2899-2972 gd:0 +ttp: b749/782 bl:2.8658 bb:1.1026 rl:2.7715 rb:1.0948 dl:2580-2638 gd:0 +ttpp: pd:2448 gd:2000 t:208.0s +tttg: c1/213 lr:0.005000 t:0.6s +tttg: c2/213 lr:0.005000 t:0.9s +tttg: c3/213 lr:0.004999 t:1.2s +tttg: c4/213 lr:0.004998 t:1.4s +tttg: c5/213 lr:0.004996 t:1.7s +tttg: c6/213 lr:0.004993 t:2.0s +tttg: c7/213 lr:0.004990 t:2.3s +tttg: c8/213 lr:0.004987 t:2.6s +tttg: c9/213 lr:0.004982 t:2.9s +tttg: c10/213 lr:0.004978 t:3.2s +tttg: c11/213 lr:0.004973 t:3.5s +tttg: c12/213 lr:0.004967 t:3.8s +tttg: c13/213 lr:0.004961 t:4.0s +tttg: c14/213 lr:0.004954 t:4.3s +tttg: c15/213 lr:0.004946 t:4.6s +tttg: c16/213 lr:0.004938 t:4.9s +tttg: c17/213 lr:0.004930 t:5.1s +tttg: c18/213 lr:0.004921 t:5.4s +tttg: c19/213 lr:0.004912 t:5.7s +tttg: c20/213 lr:0.004902 t:6.0s +tttg: c21/213 lr:0.004891 t:6.3s +tttg: c22/213 lr:0.004880 t:6.6s +tttg: c23/213 lr:0.004868 t:6.8s +tttg: c24/213 lr:0.004856 t:7.1s +tttg: c25/213 lr:0.004844 t:7.4s +tttg: c26/213 lr:0.004830 t:7.7s +tttg: c27/213 lr:0.004817 t:8.0s +tttg: c28/213 lr:0.004803 t:8.3s +tttg: c29/213 lr:0.004788 t:8.5s +tttg: c30/213 lr:0.004773 t:8.8s +tttg: c31/213 lr:0.004757 t:9.1s +tttg: c32/213 lr:0.004741 t:9.4s +tttg: c33/213 lr:0.004724 t:9.6s +tttg: c34/213 lr:0.004707 t:9.9s +tttg: c35/213 lr:0.004689 t:10.2s +tttg: c36/213 lr:0.004671 t:10.5s +tttg: c37/213 lr:0.004653 t:10.7s +tttg: c38/213 lr:0.004634 t:14.2s +tttg: c39/213 lr:0.004614 t:14.5s +tttg: c40/213 lr:0.004594 t:14.8s +tttg: c41/213 lr:0.004574 t:15.1s +tttg: c42/213 lr:0.004553 t:15.4s +tttg: c43/213 lr:0.004531 t:15.7s +tttg: c44/213 lr:0.004509 t:15.9s +tttg: c45/213 lr:0.004487 t:16.2s +tttg: c46/213 lr:0.004464 t:16.5s +tttg: c47/213 lr:0.004441 t:16.8s +tttg: c48/213 lr:0.004418 t:17.0s +tttg: c49/213 lr:0.004394 t:17.3s +tttg: c50/213 lr:0.004369 t:17.6s +tttg: c51/213 lr:0.004345 t:17.9s +tttg: c52/213 lr:0.004319 t:18.1s +tttg: c53/213 lr:0.004294 t:18.4s +tttg: c54/213 lr:0.004268 t:18.7s +tttg: c55/213 lr:0.004241 t:19.0s +tttg: c56/213 lr:0.004215 t:19.2s +tttg: c57/213 lr:0.004187 t:19.5s +tttg: c58/213 lr:0.004160 t:19.8s +tttg: c59/213 
lr:0.004132 t:20.1s +tttg: c60/213 lr:0.004104 t:20.4s +tttg: c61/213 lr:0.004075 t:20.6s +tttg: c62/213 lr:0.004046 t:20.9s +tttg: c63/213 lr:0.004017 t:21.2s +tttg: c64/213 lr:0.003987 t:21.5s +tttg: c65/213 lr:0.003957 t:21.7s +tttg: c66/213 lr:0.003927 t:22.0s +tttg: c67/213 lr:0.003897 t:22.2s +tttg: c68/213 lr:0.003866 t:22.5s +tttg: c69/213 lr:0.003835 t:22.7s +tttg: c70/213 lr:0.003803 t:23.0s +tttg: c71/213 lr:0.003771 t:23.2s +tttg: c72/213 lr:0.003739 t:23.4s +tttg: c73/213 lr:0.003707 t:23.7s +tttg: c74/213 lr:0.003674 t:23.9s +tttg: c75/213 lr:0.003642 t:24.2s +tttg: c76/213 lr:0.003608 t:24.4s +tttg: c77/213 lr:0.003575 t:24.7s +tttg: c78/213 lr:0.003542 t:24.9s +tttg: c79/213 lr:0.003508 t:25.2s +tttg: c80/213 lr:0.003474 t:25.4s +tttg: c81/213 lr:0.003440 t:25.6s +tttg: c82/213 lr:0.003405 t:25.9s +tttg: c83/213 lr:0.003371 t:26.1s +tttg: c84/213 lr:0.003336 t:26.4s +tttg: c85/213 lr:0.003301 t:26.6s +tttg: c86/213 lr:0.003265 t:26.9s +tttg: c87/213 lr:0.003230 t:27.1s +tttg: c88/213 lr:0.003195 t:27.3s +tttg: c89/213 lr:0.003159 t:27.6s +tttg: c90/213 lr:0.003123 t:27.8s +tttg: c91/213 lr:0.003087 t:28.0s +tttg: c92/213 lr:0.003051 t:28.2s +tttg: c93/213 lr:0.003015 t:28.5s +tttg: c94/213 lr:0.002979 t:28.7s +tttg: c95/213 lr:0.002942 t:28.9s +tttg: c96/213 lr:0.002906 t:29.1s +tttg: c97/213 lr:0.002869 t:29.4s +tttg: c98/213 lr:0.002832 t:29.6s +tttg: c99/213 lr:0.002796 t:29.8s +tttg: c100/213 lr:0.002759 t:30.1s +tttg: c101/213 lr:0.002722 t:30.3s +tttg: c102/213 lr:0.002685 t:30.5s +tttg: c103/213 lr:0.002648 t:30.8s +tttg: c104/213 lr:0.002611 t:31.0s +tttg: c105/213 lr:0.002574 t:31.2s +tttg: c106/213 lr:0.002537 t:31.4s +tttg: c107/213 lr:0.002500 t:31.7s +tttg: c108/213 lr:0.002463 t:31.9s +tttg: c109/213 lr:0.002426 t:32.1s +tttg: c110/213 lr:0.002389 t:32.4s +tttg: c111/213 lr:0.002352 t:32.6s +tttg: c112/213 lr:0.002315 t:32.9s +tttg: c113/213 lr:0.002278 t:33.1s +tttg: c114/213 lr:0.002241 t:33.3s +tttg: c115/213 lr:0.002204 t:33.6s +tttg: c116/213 lr:0.002168 t:33.8s +tttg: c117/213 lr:0.002131 t:34.0s +tttg: c118/213 lr:0.002094 t:34.3s +tttg: c119/213 lr:0.002058 t:34.5s +tttg: c120/213 lr:0.002021 t:34.7s +tttg: c121/213 lr:0.001985 t:35.0s +tttg: c122/213 lr:0.001949 t:35.2s +tttg: c123/213 lr:0.001913 t:35.4s +tttg: c124/213 lr:0.001877 t:35.7s +tttg: c125/213 lr:0.001841 t:35.9s +tttg: c126/213 lr:0.001805 t:36.1s +tttg: c127/213 lr:0.001770 t:36.4s +tttg: c128/213 lr:0.001735 t:36.6s +tttg: c129/213 lr:0.001699 t:36.8s +tttg: c130/213 lr:0.001664 t:37.0s +tttg: c131/213 lr:0.001629 t:37.3s +tttg: c132/213 lr:0.001595 t:37.5s +tttg: c133/213 lr:0.001560 t:37.7s +tttg: c134/213 lr:0.001526 t:37.9s +tttg: c135/213 lr:0.001492 t:38.2s +tttg: c136/213 lr:0.001458 t:38.4s +tttg: c137/213 lr:0.001425 t:38.6s +tttg: c138/213 lr:0.001392 t:38.9s +tttg: c139/213 lr:0.001358 t:39.1s +tttg: c140/213 lr:0.001326 t:39.3s +tttg: c141/213 lr:0.001293 t:39.6s +tttg: c142/213 lr:0.001261 t:39.8s +tttg: c143/213 lr:0.001229 t:40.0s +tttg: c144/213 lr:0.001197 t:40.2s +tttg: c145/213 lr:0.001165 t:40.5s +tttg: c146/213 lr:0.001134 t:40.7s +tttg: c147/213 lr:0.001103 t:40.9s +tttg: c148/213 lr:0.001073 t:41.2s +tttg: c149/213 lr:0.001043 t:41.4s +tttg: c150/213 lr:0.001013 t:41.6s +tttg: c151/213 lr:0.000983 t:41.8s +tttg: c152/213 lr:0.000954 t:42.1s +tttg: c153/213 lr:0.000925 t:42.3s +tttg: c154/213 lr:0.000896 t:42.5s +tttg: c155/213 lr:0.000868 t:42.8s +tttg: c156/213 lr:0.000840 t:43.0s +tttg: c157/213 lr:0.000813 t:43.2s +tttg: c158/213 lr:0.000785 t:43.5s +tttg: 
c159/213 lr:0.000759 t:43.7s +tttg: c160/213 lr:0.000732 t:43.9s +tttg: c161/213 lr:0.000706 t:44.1s +tttg: c162/213 lr:0.000681 t:44.4s +tttg: c163/213 lr:0.000655 t:44.6s +tttg: c164/213 lr:0.000631 t:44.8s +tttg: c165/213 lr:0.000606 t:45.1s +tttg: c166/213 lr:0.000582 t:45.3s +tttg: c167/213 lr:0.000559 t:45.5s +tttg: c168/213 lr:0.000536 t:45.7s +tttg: c169/213 lr:0.000513 t:46.0s +tttg: c170/213 lr:0.000491 t:46.2s +tttg: c171/213 lr:0.000469 t:46.4s +tttg: c172/213 lr:0.000447 t:46.6s +tttg: c173/213 lr:0.000426 t:46.9s +tttg: c174/213 lr:0.000406 t:47.1s +tttg: c175/213 lr:0.000386 t:47.3s +tttg: c176/213 lr:0.000366 t:47.5s +tttg: c177/213 lr:0.000347 t:47.8s +tttg: c178/213 lr:0.000329 t:48.0s +tttg: c179/213 lr:0.000311 t:48.2s +tttg: c180/213 lr:0.000293 t:48.4s +tttg: c181/213 lr:0.000276 t:49.5s +tttg: c182/213 lr:0.000259 t:49.8s +tttg: c183/213 lr:0.000243 t:50.0s +tttg: c184/213 lr:0.000227 t:50.2s +tttg: c185/213 lr:0.000212 t:50.5s +tttg: c186/213 lr:0.000197 t:50.7s +tttg: c187/213 lr:0.000183 t:50.9s +tttg: c188/213 lr:0.000170 t:51.1s +tttg: c189/213 lr:0.000156 t:51.4s +tttg: c190/213 lr:0.000144 t:51.6s +tttg: c191/213 lr:0.000132 t:51.8s +tttg: c192/213 lr:0.000120 t:52.8s +tttg: c193/213 lr:0.000109 t:53.1s +tttg: c194/213 lr:0.000098 t:53.3s +tttg: c195/213 lr:0.000088 t:53.5s +tttg: c196/213 lr:0.000079 t:53.7s +tttg: c197/213 lr:0.000070 t:54.0s +tttg: c198/213 lr:0.000062 t:54.2s +tttg: c199/213 lr:0.000054 t:54.4s +tttg: c200/213 lr:0.000046 t:54.6s +tttg: c201/213 lr:0.000039 t:54.9s +tttg: c202/213 lr:0.000033 t:55.1s +tttg: c203/213 lr:0.000027 t:55.3s +tttg: c204/213 lr:0.000022 t:55.5s +tttg: c205/213 lr:0.000018 t:55.8s +tttg: c206/213 lr:0.000013 t:56.0s +tttg: c207/213 lr:0.000010 t:56.2s +tttg: c208/213 lr:0.000007 t:56.4s +tttg: c209/213 lr:0.000004 t:56.7s +tttg: c210/213 lr:0.000002 t:56.9s +tttg: c211/213 lr:0.000001 t:57.1s +tttg: c212/213 lr:0.000000 t:57.3s +ttpr: t:267.9s +ttp: b736/782 bl:2.6862 bb:1.0470 rl:2.7649 rb:1.0911 dl:2140-2165 gd:1 +ttp: b728/782 bl:2.7799 bb:1.0766 rl:2.7659 rb:1.0901 dl:1960-1977 gd:1 +ttp: b720/782 bl:2.8420 bb:1.0856 rl:2.7703 rb:1.0898 dl:1816-1832 gd:1 +ttp: b712/782 bl:2.8588 bb:1.0884 rl:2.7748 rb:1.0897 dl:1684-1697 gd:1 +ttp: b708/782 bl:2.7495 bb:1.0566 rl:2.7736 rb:1.0881 dl:1639-1649 gd:1 +ttp: b693/782 bl:2.8490 bb:1.1176 rl:2.7767 rb:1.0894 dl:1485-1494 gd:1 +ttp: b686/782 bl:2.8349 bb:1.0651 rl:2.7789 rb:1.0884 dl:1422-1432 gd:1 +ttp: b679/782 bl:2.8860 bb:1.0995 rl:2.7826 rb:1.0888 dl:1368-1374 gd:1 +ttp: b671/782 bl:2.9159 bb:1.1301 rl:2.7870 rb:1.0902 dl:1316-1321 gd:1 +ttp: b665/782 bl:2.7753 bb:1.0459 rl:2.7866 rb:1.0888 dl:1275-1282 gd:1 +ttp: b655/782 bl:2.7216 bb:1.0353 rl:2.7848 rb:1.0872 dl:1215-1220 gd:1 +ttp: b648/782 bl:2.7854 bb:1.0558 rl:2.7848 rb:1.0863 dl:1177-1182 gd:1 +ttp: b641/782 bl:2.8016 bb:1.0548 rl:2.7852 rb:1.0855 dl:1140-1144 gd:1 +ttp: b634/782 bl:2.7399 bb:1.0577 rl:2.7841 rb:1.0848 dl:1105-1111 gd:1 +ttp: b626/782 bl:2.8525 bb:1.0599 rl:2.7857 rb:1.0843 dl:1068-1073 gd:1 +ttp: b618/782 bl:2.7819 bb:1.0665 rl:2.7856 rb:1.0839 dl:1031-1037 gd:1 +ttp: b610/782 bl:2.8787 bb:1.0808 rl:2.7875 rb:1.0838 dl:999-1004 gd:1 +ttp: b602/782 bl:2.8247 bb:1.0566 rl:2.7882 rb:1.0833 dl:966-971 gd:1 +ttp: b589/782 bl:2.7959 bb:1.0703 rl:2.7883 rb:1.0830 dl:921-924 gd:1 +ttp: b581/782 bl:2.7654 bb:1.0316 rl:2.7879 rb:1.0821 dl:894-897 gd:1 +ttp: b573/782 bl:2.9746 bb:1.0887 rl:2.7910 rb:1.0822 dl:868-871 gd:1 +ttp: b565/782 bl:2.8187 bb:1.0801 rl:2.7914 rb:1.0822 dl:843-846 gd:1 +ttp: 
b557/782 bl:2.8435 bb:1.0602 rl:2.7922 rb:1.0818 dl:818-821 gd:1 +ttp: b549/782 bl:2.8098 bb:1.0810 rl:2.7925 rb:1.0818 dl:795-798 gd:1 +ttp: b542/782 bl:2.8820 bb:1.0917 rl:2.7937 rb:1.0820 dl:777-779 gd:1 +ttp: b535/782 bl:2.8350 bb:1.0749 rl:2.7942 rb:1.0819 dl:759-762 gd:1 +ttp: b527/782 bl:2.7889 bb:1.0598 rl:2.7942 rb:1.0816 dl:739-742 gd:1 +ttp: b520/782 bl:2.8436 bb:1.0777 rl:2.7948 rb:1.0815 dl:723-725 gd:1 +ttp: b513/782 bl:2.7855 bb:1.0310 rl:2.7947 rb:1.0809 dl:705-707 gd:1 +ttp: b505/782 bl:2.8256 bb:1.0796 rl:2.7950 rb:1.0809 dl:686-688 gd:1 +ttp: b498/782 bl:2.7339 bb:1.0584 rl:2.7943 rb:1.0807 dl:671-673 gd:1 +ttp: b489/782 bl:2.8480 bb:1.1012 rl:2.7949 rb:1.0809 dl:651-653 gd:1 +ttp: b481/782 bl:2.8548 bb:1.1224 rl:2.7955 rb:1.0813 dl:635-637 gd:1 +ttp: b472/782 bl:2.8569 bb:1.0921 rl:2.7961 rb:1.0814 dl:616-618 gd:1 +ttp: b463/782 bl:2.8546 bb:1.0958 rl:2.7967 rb:1.0815 dl:599-600 gd:1 +ttp: b455/782 bl:2.8586 bb:1.0962 rl:2.7973 rb:1.0817 dl:584-586 gd:1 +ttp: b447/782 bl:2.8928 bb:1.1124 rl:2.7981 rb:1.0819 dl:569-571 gd:1 +ttp: b439/782 bl:2.8156 bb:1.0671 rl:2.7983 rb:1.0818 dl:555-556 gd:1 +ttp: b432/782 bl:2.8171 bb:1.0718 rl:2.7984 rb:1.0817 dl:542-544 gd:1 +ttp: b424/782 bl:2.8618 bb:1.1061 rl:2.7989 rb:1.0819 dl:528-530 gd:1 +ttp: b411/782 bl:2.8733 bb:1.0955 rl:2.7995 rb:1.0820 dl:507-508 gd:1 +ttp: b402/782 bl:2.8174 bb:1.0615 rl:2.7996 rb:1.0819 dl:492-493 gd:1 +ttp: b394/782 bl:2.9578 bb:1.1407 rl:2.8007 rb:1.0823 dl:479-481 gd:1 +ttp: b387/782 bl:2.8916 bb:1.0946 rl:2.8014 rb:1.0824 dl:468-470 gd:1 +ttp: b380/782 bl:2.9083 bb:1.1012 rl:2.8021 rb:1.0825 dl:459-460 gd:1 +ttp: b372/782 bl:2.9023 bb:1.0941 rl:2.8027 rb:1.0826 dl:447-449 gd:1 +ttp: b364/782 bl:2.7950 bb:1.0902 rl:2.8027 rb:1.0826 dl:436-437 gd:1 +ttp: b357/782 bl:2.9301 bb:1.1087 rl:2.8035 rb:1.0828 dl:426-427 gd:1 +ttp: b348/782 bl:2.8771 bb:1.0933 rl:2.8039 rb:1.0829 dl:414-415 gd:1 +ttp: b340/782 bl:2.8867 bb:1.1166 rl:2.8044 rb:1.0831 dl:403-404 gd:1 +ttp: b333/782 bl:2.9704 bb:1.1568 rl:2.8053 rb:1.0835 dl:394-395 gd:1 +ttp: b327/782 bl:2.8505 bb:1.1066 rl:2.8055 rb:1.0836 dl:387-388 gd:1 +ttp: b319/782 bl:2.8937 bb:1.1352 rl:2.8060 rb:1.0839 dl:376-377 gd:1 +ttp: b311/782 bl:2.9294 bb:1.1222 rl:2.8066 rb:1.0840 dl:365-367 gd:1 +ttp: b304/782 bl:2.9891 bb:1.1641 rl:2.8075 rb:1.0844 dl:357-358 gd:1 +ttp: b296/782 bl:2.8854 bb:1.1159 rl:2.8079 rb:1.0846 dl:347-348 gd:1 +ttp: b289/782 bl:2.9093 bb:1.1517 rl:2.8084 rb:1.0849 dl:339-340 gd:1 +ttp: b279/782 bl:2.9318 bb:1.1206 rl:2.8089 rb:1.0851 dl:327-329 gd:1 +ttp: b273/782 bl:2.8380 bb:1.0871 rl:2.8090 rb:1.0851 dl:321-322 gd:1 +ttp: b265/782 bl:2.8956 bb:1.1146 rl:2.8094 rb:1.0852 dl:312-313 gd:1 +ttp: b259/782 bl:2.9338 bb:1.1700 rl:2.8099 rb:1.0855 dl:305-306 gd:1 +ttp: b252/782 bl:2.9788 bb:1.1607 rl:2.8106 rb:1.0858 dl:297-298 gd:1 +ttp: b245/782 bl:2.9445 bb:1.1307 rl:2.8111 rb:1.0860 dl:290-291 gd:1 +ttp: b238/782 bl:2.9641 bb:1.1758 rl:2.8117 rb:1.0863 dl:283-284 gd:1 +ttp: b230/782 bl:2.9908 bb:1.1445 rl:2.8123 rb:1.0865 dl:275-276 gd:1 +ttp: b223/782 bl:2.9088 bb:1.1199 rl:2.8127 rb:1.0867 dl:268-269 gd:1 +ttp: b215/782 bl:2.9227 bb:1.1728 rl:2.8131 rb:1.0869 dl:260-261 gd:1 +ttp: b206/782 bl:2.9633 bb:1.1470 rl:2.8136 rb:1.0871 dl:252-253 gd:1 +ttp: b198/782 bl:3.0558 bb:1.1819 rl:2.8143 rb:1.0874 dl:245-246 gd:1 +ttp: b191/782 bl:3.0358 bb:1.1855 rl:2.8150 rb:1.0877 dl:238-239 gd:1 +ttp: b182/782 bl:2.9349 bb:1.1675 rl:2.8154 rb:1.0880 dl:230-231 gd:1 +ttp: b173/782 bl:3.0465 bb:1.1843 rl:2.8160 rb:1.0882 dl:223-224 gd:1 +ttp: 
b167/782 bl:3.0557 bb:1.2215 rl:2.8167 rb:1.0886 dl:218-218 gd:1 +ttp: b158/782 bl:2.9647 bb:1.1736 rl:2.8171 rb:1.0888 dl:210-211 gd:1 +ttp: b149/782 bl:3.0704 bb:1.2105 rl:2.8178 rb:1.0891 dl:203-204 gd:1 +ttp: b142/782 bl:3.0518 bb:1.1964 rl:2.8184 rb:1.0894 dl:197-198 gd:1 +ttp: b134/782 bl:3.1235 bb:1.2493 rl:2.8191 rb:1.0898 dl:190-191 gd:1 +ttp: b125/782 bl:3.1046 bb:1.2303 rl:2.8198 rb:1.0901 dl:184-185 gd:1 +ttp: b119/782 bl:2.8888 bb:1.1186 rl:2.8199 rb:1.0902 dl:179-180 gd:1 +ttp: b111/782 bl:3.0789 bb:1.2285 rl:2.8205 rb:1.0905 dl:173-174 gd:1 +ttp: b103/782 bl:2.9764 bb:1.1520 rl:2.8208 rb:1.0906 dl:168-168 gd:1 +ttp: b94/782 bl:3.0762 bb:1.2132 rl:2.8213 rb:1.0908 dl:160-161 gd:1 +ttp: b87/782 bl:3.1105 bb:1.2432 rl:2.8219 rb:1.0911 dl:155-156 gd:1 +ttp: b79/782 bl:3.1324 bb:1.2436 rl:2.8225 rb:1.0914 dl:149-150 gd:1 +ttp: b68/782 bl:3.2100 bb:1.2470 rl:2.8232 rb:1.0917 dl:141-142 gd:1 +ttp: b60/782 bl:3.1706 bb:1.2724 rl:2.8238 rb:1.0920 dl:134-135 gd:1 +ttp: b52/782 bl:3.1649 bb:1.2384 rl:2.8243 rb:1.0922 dl:128-129 gd:1 +ttp: b45/782 bl:3.1775 bb:1.2712 rl:2.8248 rb:1.0925 dl:122-123 gd:1 +ttp: b34/782 bl:3.1680 bb:1.2825 rl:2.8253 rb:1.0927 dl:114-115 gd:1 +ttp: b27/782 bl:3.1958 bb:1.2758 rl:2.8258 rb:1.0930 dl:107-108 gd:1 +ttp: b20/782 bl:3.2362 bb:1.3088 rl:2.8263 rb:1.0932 dl:101-102 gd:1 +ttp: b10/782 bl:3.2363 bb:1.2788 rl:2.8268 rb:1.0934 dl:89-90 gd:1 +ttp: b1/782 bl:3.4634 bb:1.2853 rl:2.8273 rb:1.0936 dl:45-70 gd:1 +quantized_ttt_phased val_loss:2.81238776 val_bpb:1.08876294 eval_time:462840ms +total_eval_time:462.8s diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_1b.log b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_1b.log new file mode 100644 index 0000000000..34d6065562 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_1b.log @@ -0,0 +1,436 @@ +NCCL version 2.27.5+cuda12.9 +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: /workspace/parameter-golf/runs/ablation_1b + beta1: 0.9 + beta2: 0.95 + compressor: brotli + corrector_alpha: 0.3 + corrector_orders: 5,8,12 + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_only_path: + eval_only_quantized_path: /workspace/checkpoints/seed0/final_model.int6.ptz + eval_seq_len: 2048 + eval_stride: 64 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_doc_limit: 0 + global_ttt_epochs: 3 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.005 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 13.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/runs/ablation_1b/2a4446e2-3000-4080-a044-f0e155c10bb2.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/runs/ablation_1b/final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 
1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_enabled: True + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: /workspace/checkpoints/seed0/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: 2a4446e2-3000-4080-a044-f0e155c10bb2 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 32 + ttt_doc_limit: 0 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_heartbeat_seconds: 15.0 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_output_dir: + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +eval_only:using quantized checkpoint from /workspace/checkpoints/seed0/final_model.int6.ptz +eval_only: skipping serialize (already have quantized model) +diagnostic quantized val_loss:2.79886366 val_bpb:1.08349242 eval_time:12009ms +ttt_lora:warming up compile +ttt_lora:compile warmup done (177.0s) + +beginning TTT eval timer +corrector: alpha=0.3 orders=[5, 8, 12] +ttt_phased: total_docs:50000 prefix_docs:2000 suffix_docs:48000 +ttp: b778/782 bl:2.7965 bb:1.1187 rl:2.7965 rb:1.1187 dl:7961-8997 gd:0 +ttp: b771/782 bl:2.7812 bb:1.0874 rl:2.7909 rb:1.1072 dl:4701-4937 gd:0 +ttp: b766/782 bl:2.5856 bb:1.0123 rl:2.7443 rb:1.0854 dl:3846-3962 gd:0 +ttp: b760/782 bl:2.8734 bb:1.1284 rl:2.7651 rb:1.0924 dl:3255-3334 gd:0 +ttp: b754/782 bl:2.7180 bb:1.0665 rl:2.7593 rb:1.0892 dl:2839-2899 gd:0 +ttp: b748/782 bl:2.8379 bb:1.0867 rl:2.7671 rb:1.0889 dl:2539-2578 gd:0 +ttpp: pd:2448 gd:2000 t:213.5s +tttg: c1/213 lr:0.005000 t:0.6s +tttg: c2/213 lr:0.005000 t:0.9s +tttg: c3/213 lr:0.004999 t:1.2s +tttg: c4/213 lr:0.004998 t:1.5s +tttg: c5/213 lr:0.004996 t:1.8s +tttg: c6/213 lr:0.004993 t:2.0s +tttg: c7/213 lr:0.004990 t:2.3s +tttg: c8/213 lr:0.004987 t:2.6s +tttg: c9/213 lr:0.004982 t:2.9s +tttg: c10/213 lr:0.004978 t:3.2s +tttg: c11/213 lr:0.004973 t:3.5s +tttg: c12/213 lr:0.004967 t:3.8s +tttg: c13/213 lr:0.004961 t:4.1s +tttg: c14/213 lr:0.004954 t:4.3s +tttg: c15/213 lr:0.004946 t:4.7s +tttg: c16/213 lr:0.004938 t:5.0s +tttg: c17/213 lr:0.004930 t:5.3s +tttg: c18/213 lr:0.004921 t:5.6s +tttg: c19/213 lr:0.004912 t:5.9s +tttg: c20/213 lr:0.004902 t:6.2s +tttg: c21/213 lr:0.004891 t:6.4s +tttg: c22/213 lr:0.004880 t:6.7s +tttg: c23/213 lr:0.004868 t:7.0s +tttg: c24/213 lr:0.004856 t:7.3s +tttg: c25/213 lr:0.004844 t:7.6s +tttg: c26/213 lr:0.004830 t:7.8s +tttg: c27/213 lr:0.004817 t:8.1s +tttg: c28/213 lr:0.004803 t:8.4s +tttg: c29/213 lr:0.004788 t:8.7s +tttg: c30/213 lr:0.004773 t:9.0s +tttg: c31/213 lr:0.004757 t:9.2s +tttg: c32/213 lr:0.004741 t:9.5s +tttg: c33/213 lr:0.004724 t:9.8s +tttg: c34/213 lr:0.004707 t:10.1s +tttg: c35/213 lr:0.004689 t:10.4s +tttg: c36/213 lr:0.004671 
t:10.7s +tttg: c37/213 lr:0.004653 t:11.0s +tttg: c38/213 lr:0.004634 t:11.2s +tttg: c39/213 lr:0.004614 t:11.5s +tttg: c40/213 lr:0.004594 t:11.8s +tttg: c41/213 lr:0.004574 t:12.1s +tttg: c42/213 lr:0.004553 t:12.4s +tttg: c43/213 lr:0.004531 t:12.6s +tttg: c44/213 lr:0.004509 t:12.9s +tttg: c45/213 lr:0.004487 t:13.2s +tttg: c46/213 lr:0.004464 t:13.5s +tttg: c47/213 lr:0.004441 t:13.8s +tttg: c48/213 lr:0.004418 t:14.0s +tttg: c49/213 lr:0.004394 t:14.3s +tttg: c50/213 lr:0.004369 t:14.6s +tttg: c51/213 lr:0.004345 t:14.9s +tttg: c52/213 lr:0.004319 t:15.2s +tttg: c53/213 lr:0.004294 t:15.4s +tttg: c54/213 lr:0.004268 t:15.7s +tttg: c55/213 lr:0.004241 t:16.0s +tttg: c56/213 lr:0.004215 t:16.3s +tttg: c57/213 lr:0.004187 t:16.5s +tttg: c58/213 lr:0.004160 t:16.8s +tttg: c59/213 lr:0.004132 t:17.1s +tttg: c60/213 lr:0.004104 t:17.4s +tttg: c61/213 lr:0.004075 t:17.7s +tttg: c62/213 lr:0.004046 t:17.9s +tttg: c63/213 lr:0.004017 t:18.2s +tttg: c64/213 lr:0.003987 t:18.5s +tttg: c65/213 lr:0.003957 t:18.8s +tttg: c66/213 lr:0.003927 t:19.0s +tttg: c67/213 lr:0.003897 t:19.3s +tttg: c68/213 lr:0.003866 t:19.6s +tttg: c69/213 lr:0.003835 t:19.8s +tttg: c70/213 lr:0.003803 t:20.0s +tttg: c71/213 lr:0.003771 t:20.3s +tttg: c72/213 lr:0.003739 t:20.5s +tttg: c73/213 lr:0.003707 t:20.7s +tttg: c74/213 lr:0.003674 t:21.0s +tttg: c75/213 lr:0.003642 t:21.2s +tttg: c76/213 lr:0.003608 t:21.5s +tttg: c77/213 lr:0.003575 t:21.7s +tttg: c78/213 lr:0.003542 t:22.0s +tttg: c79/213 lr:0.003508 t:22.2s +tttg: c80/213 lr:0.003474 t:22.5s +tttg: c81/213 lr:0.003440 t:22.7s +tttg: c82/213 lr:0.003405 t:23.0s +tttg: c83/213 lr:0.003371 t:23.2s +tttg: c84/213 lr:0.003336 t:23.5s +tttg: c85/213 lr:0.003301 t:23.7s +tttg: c86/213 lr:0.003265 t:23.9s +tttg: c87/213 lr:0.003230 t:24.2s +tttg: c88/213 lr:0.003195 t:24.4s +tttg: c89/213 lr:0.003159 t:24.7s +tttg: c90/213 lr:0.003123 t:24.9s +tttg: c91/213 lr:0.003087 t:25.2s +tttg: c92/213 lr:0.003051 t:25.4s +tttg: c93/213 lr:0.003015 t:25.7s +tttg: c94/213 lr:0.002979 t:25.9s +tttg: c95/213 lr:0.002942 t:26.2s +tttg: c96/213 lr:0.002906 t:26.4s +tttg: c97/213 lr:0.002869 t:26.6s +tttg: c98/213 lr:0.002832 t:26.9s +tttg: c99/213 lr:0.002796 t:27.1s +tttg: c100/213 lr:0.002759 t:27.4s +tttg: c101/213 lr:0.002722 t:27.6s +tttg: c102/213 lr:0.002685 t:27.9s +tttg: c103/213 lr:0.002648 t:28.1s +tttg: c104/213 lr:0.002611 t:28.4s +tttg: c105/213 lr:0.002574 t:28.6s +tttg: c106/213 lr:0.002537 t:28.9s +tttg: c107/213 lr:0.002500 t:29.1s +tttg: c108/213 lr:0.002463 t:29.4s +tttg: c109/213 lr:0.002426 t:29.6s +tttg: c110/213 lr:0.002389 t:29.8s +tttg: c111/213 lr:0.002352 t:30.1s +tttg: c112/213 lr:0.002315 t:30.3s +tttg: c113/213 lr:0.002278 t:30.6s +tttg: c114/213 lr:0.002241 t:30.8s +tttg: c115/213 lr:0.002204 t:31.1s +tttg: c116/213 lr:0.002168 t:31.3s +tttg: c117/213 lr:0.002131 t:31.6s +tttg: c118/213 lr:0.002094 t:31.8s +tttg: c119/213 lr:0.002058 t:32.1s +tttg: c120/213 lr:0.002021 t:32.3s +tttg: c121/213 lr:0.001985 t:32.5s +tttg: c122/213 lr:0.001949 t:32.8s +tttg: c123/213 lr:0.001913 t:33.0s +tttg: c124/213 lr:0.001877 t:33.3s +tttg: c125/213 lr:0.001841 t:33.5s +tttg: c126/213 lr:0.001805 t:33.8s +tttg: c127/213 lr:0.001770 t:34.0s +tttg: c128/213 lr:0.001735 t:34.3s +tttg: c129/213 lr:0.001699 t:34.5s +tttg: c130/213 lr:0.001664 t:34.7s +tttg: c131/213 lr:0.001629 t:35.0s +tttg: c132/213 lr:0.001595 t:35.2s +tttg: c133/213 lr:0.001560 t:35.5s +tttg: c134/213 lr:0.001526 t:35.7s +tttg: c135/213 lr:0.001492 t:36.0s +tttg: c136/213 lr:0.001458 t:36.2s +tttg: 
c137/213 lr:0.001425 t:36.4s +tttg: c138/213 lr:0.001392 t:36.7s +tttg: c139/213 lr:0.001358 t:36.9s +tttg: c140/213 lr:0.001326 t:37.2s +tttg: c141/213 lr:0.001293 t:37.4s +tttg: c142/213 lr:0.001261 t:37.7s +tttg: c143/213 lr:0.001229 t:37.9s +tttg: c144/213 lr:0.001197 t:38.2s +tttg: c145/213 lr:0.001165 t:38.4s +tttg: c146/213 lr:0.001134 t:38.6s +tttg: c147/213 lr:0.001103 t:38.9s +tttg: c148/213 lr:0.001073 t:39.1s +tttg: c149/213 lr:0.001043 t:39.4s +tttg: c150/213 lr:0.001013 t:39.6s +tttg: c151/213 lr:0.000983 t:39.9s +tttg: c152/213 lr:0.000954 t:40.1s +tttg: c153/213 lr:0.000925 t:40.4s +tttg: c154/213 lr:0.000896 t:40.6s +tttg: c155/213 lr:0.000868 t:40.8s +tttg: c156/213 lr:0.000840 t:41.1s +tttg: c157/213 lr:0.000813 t:41.3s +tttg: c158/213 lr:0.000785 t:41.6s +tttg: c159/213 lr:0.000759 t:41.8s +tttg: c160/213 lr:0.000732 t:42.1s +tttg: c161/213 lr:0.000706 t:42.3s +tttg: c162/213 lr:0.000681 t:42.6s +tttg: c163/213 lr:0.000655 t:42.8s +tttg: c164/213 lr:0.000631 t:43.0s +tttg: c165/213 lr:0.000606 t:43.3s +tttg: c166/213 lr:0.000582 t:43.5s +tttg: c167/213 lr:0.000559 t:43.8s +tttg: c168/213 lr:0.000536 t:44.0s +tttg: c169/213 lr:0.000513 t:44.3s +tttg: c170/213 lr:0.000491 t:44.5s +tttg: c171/213 lr:0.000469 t:44.8s +tttg: c172/213 lr:0.000447 t:45.0s +tttg: c173/213 lr:0.000426 t:45.2s +tttg: c174/213 lr:0.000406 t:45.5s +tttg: c175/213 lr:0.000386 t:45.7s +tttg: c176/213 lr:0.000366 t:46.0s +tttg: c177/213 lr:0.000347 t:46.2s +tttg: c178/213 lr:0.000329 t:46.4s +tttg: c179/213 lr:0.000311 t:46.7s +tttg: c180/213 lr:0.000293 t:46.9s +tttg: c181/213 lr:0.000276 t:47.2s +tttg: c182/213 lr:0.000259 t:47.4s +tttg: c183/213 lr:0.000243 t:47.7s +tttg: c184/213 lr:0.000227 t:47.9s +tttg: c185/213 lr:0.000212 t:48.1s +tttg: c186/213 lr:0.000197 t:48.4s +tttg: c187/213 lr:0.000183 t:48.6s +tttg: c188/213 lr:0.000170 t:48.9s +tttg: c189/213 lr:0.000156 t:49.1s +tttg: c190/213 lr:0.000144 t:49.3s +tttg: c191/213 lr:0.000132 t:49.5s +tttg: c192/213 lr:0.000120 t:49.7s +tttg: c193/213 lr:0.000109 t:49.9s +tttg: c194/213 lr:0.000098 t:50.1s +tttg: c195/213 lr:0.000088 t:50.3s +tttg: c196/213 lr:0.000079 t:50.6s +tttg: c197/213 lr:0.000070 t:50.8s +tttg: c198/213 lr:0.000062 t:51.0s +tttg: c199/213 lr:0.000054 t:51.2s +tttg: c200/213 lr:0.000046 t:51.4s +tttg: c201/213 lr:0.000039 t:51.6s +tttg: c202/213 lr:0.000033 t:51.8s +tttg: c203/213 lr:0.000027 t:52.1s +tttg: c204/213 lr:0.000022 t:52.3s +tttg: c205/213 lr:0.000018 t:52.6s +tttg: c206/213 lr:0.000013 t:52.8s +tttg: c207/213 lr:0.000010 t:53.1s +tttg: c208/213 lr:0.000007 t:53.3s +tttg: c209/213 lr:0.000004 t:53.5s +tttg: c210/213 lr:0.000002 t:53.7s +tttg: c211/213 lr:0.000001 t:54.0s +tttg: c212/213 lr:0.000000 t:54.2s +ttpr: t:271.4s +ttp: b736/782 bl:2.6865 bb:1.0472 rl:2.7609 rb:1.0857 dl:2140-2165 gd:1 +ttp: b731/782 bl:2.7961 bb:1.0672 rl:2.7633 rb:1.0844 dl:2017-2041 gd:1 +ttp: b725/782 bl:2.7823 bb:1.0784 rl:2.7644 rb:1.0841 dl:1900-1915 gd:1 +ttp: b720/782 bl:2.8450 bb:1.0867 rl:2.7688 rb:1.0842 dl:1816-1832 gd:1 +ttp: b710/782 bl:2.7830 bb:1.0790 rl:2.7694 rb:1.0840 dl:1661-1673 gd:1 +ttp: b701/782 bl:2.7838 bb:1.0587 rl:2.7700 rb:1.0829 dl:1562-1572 gd:1 +ttp: b696/782 bl:2.8437 bb:1.0870 rl:2.7729 rb:1.0830 dl:1513-1522 gd:1 +ttp: b686/782 bl:2.8361 bb:1.0656 rl:2.7752 rb:1.0824 dl:1422-1432 gd:1 +ttp: b680/782 bl:2.8409 bb:1.0687 rl:2.7774 rb:1.0819 dl:1375-1383 gd:1 +ttp: b670/782 bl:2.8635 bb:1.0707 rl:2.7800 rb:1.0815 dl:1308-1315 gd:1 +ttp: b664/782 bl:2.7393 bb:1.0561 rl:2.7789 rb:1.0808 dl:1270-1275 gd:1 +ttp: 
b658/782 bl:2.8496 bb:1.0907 rl:2.7808 rb:1.0811 dl:1234-1239 gd:1 +ttp: b644/782 bl:2.7753 bb:1.0472 rl:2.7806 rb:1.0802 dl:1155-1160 gd:1 +ttp: b636/782 bl:2.8004 bb:1.0863 rl:2.7811 rb:1.0803 dl:1116-1120 gd:1 +ttp: b628/782 bl:2.8107 bb:1.0629 rl:2.7818 rb:1.0799 dl:1078-1082 gd:1 +ttp: b620/782 bl:2.8218 bb:1.0575 rl:2.7826 rb:1.0795 dl:1041-1046 gd:1 +ttp: b613/782 bl:2.8654 bb:1.0785 rl:2.7843 rb:1.0794 dl:1012-1016 gd:1 +ttp: b607/782 bl:2.7374 bb:1.0550 rl:2.7834 rb:1.0790 dl:986-990 gd:1 +ttp: b599/782 bl:2.7871 bb:1.0705 rl:2.7834 rb:1.0788 dl:954-958 gd:1 +ttp: b592/782 bl:2.8228 bb:1.0627 rl:2.7841 rb:1.0785 dl:930-933 gd:1 +ttp: b580/782 bl:2.7759 bb:1.0547 rl:2.7840 rb:1.0781 dl:891-894 gd:1 +ttp: b572/782 bl:2.9910 bb:1.1383 rl:2.7872 rb:1.0791 dl:865-868 gd:1 +ttp: b564/782 bl:2.9136 bb:1.1273 rl:2.7891 rb:1.0798 dl:840-843 gd:1 +ttp: b556/782 bl:2.8823 bb:1.1019 rl:2.7905 rb:1.0801 dl:815-818 gd:1 +ttp: b548/782 bl:2.8120 bb:1.0663 rl:2.7908 rb:1.0799 dl:793-795 gd:1 +ttp: b540/782 bl:2.7410 bb:1.0341 rl:2.7901 rb:1.0793 dl:771-774 gd:1 +ttp: b532/782 bl:2.8688 bb:1.0772 rl:2.7911 rb:1.0793 dl:752-754 gd:1 +ttp: b524/782 bl:2.8664 bb:1.0712 rl:2.7920 rb:1.0792 dl:732-735 gd:1 +ttp: b516/782 bl:2.9056 bb:1.0939 rl:2.7934 rb:1.0794 dl:713-715 gd:1 +ttp: b509/782 bl:2.8049 bb:1.0916 rl:2.7935 rb:1.0795 dl:695-698 gd:1 +ttp: b501/782 bl:2.8454 bb:1.0599 rl:2.7941 rb:1.0793 dl:677-680 gd:1 +ttp: b495/782 bl:2.8229 bb:1.0779 rl:2.7944 rb:1.0793 dl:664-666 gd:1 +ttp: b482/782 bl:2.8098 bb:1.1027 rl:2.7945 rb:1.0795 dl:637-639 gd:1 +ttp: b473/782 bl:2.8851 bb:1.0977 rl:2.7954 rb:1.0797 dl:618-620 gd:1 +ttp: b465/782 bl:2.8693 bb:1.0823 rl:2.7961 rb:1.0797 dl:602-604 gd:1 +ttp: b455/782 bl:2.8596 bb:1.0966 rl:2.7967 rb:1.0799 dl:584-586 gd:1 +ttp: b446/782 bl:2.8767 bb:1.1103 rl:2.7974 rb:1.0801 dl:568-569 gd:1 +ttp: b438/782 bl:2.7731 bb:1.0789 rl:2.7972 rb:1.0801 dl:553-555 gd:1 +ttp: b430/782 bl:2.8236 bb:1.0719 rl:2.7974 rb:1.0800 dl:539-540 gd:1 +ttp: b423/782 bl:2.8038 bb:1.0531 rl:2.7974 rb:1.0798 dl:526-528 gd:1 +ttp: b415/782 bl:2.9075 bb:1.1046 rl:2.7982 rb:1.0800 dl:513-514 gd:1 +ttp: b408/782 bl:2.8967 bb:1.1080 rl:2.7990 rb:1.0802 dl:501-503 gd:1 +ttp: b400/782 bl:2.8603 bb:1.0911 rl:2.7994 rb:1.0803 dl:489-490 gd:1 +ttp: b392/782 bl:2.8563 bb:1.1029 rl:2.7998 rb:1.0805 dl:476-478 gd:1 +ttp: b384/782 bl:2.9142 bb:1.1180 rl:2.8006 rb:1.0807 dl:464-466 gd:1 +ttp: b377/782 bl:2.8753 bb:1.1149 rl:2.8010 rb:1.0809 dl:454-455 gd:1 +ttp: b374/782 bl:2.8220 bb:1.0965 rl:2.8012 rb:1.0810 dl:450-452 gd:1 +ttp: b368/782 bl:2.9200 bb:1.1141 rl:2.8019 rb:1.0812 dl:441-443 gd:1 +ttp: b349/782 bl:2.9867 bb:1.1349 rl:2.8030 rb:1.0815 dl:415-417 gd:1 +ttp: b341/782 bl:2.9390 bb:1.1252 rl:2.8038 rb:1.0818 dl:404-406 gd:1 +ttp: b333/782 bl:2.9699 bb:1.1566 rl:2.8047 rb:1.0822 dl:394-395 gd:1 +ttp: b325/782 bl:2.9095 bb:1.1176 rl:2.8052 rb:1.0824 dl:384-385 gd:1 +ttp: b317/782 bl:2.9484 bb:1.1398 rl:2.8059 rb:1.0827 dl:373-374 gd:1 +ttp: b309/782 bl:2.9077 bb:1.1344 rl:2.8064 rb:1.0829 dl:363-364 gd:1 +ttp: b301/782 bl:2.8616 bb:1.1128 rl:2.8067 rb:1.0831 dl:353-354 gd:1 +ttp: b291/782 bl:3.0254 bb:1.1422 rl:2.8077 rb:1.0833 dl:341-342 gd:1 +ttp: b283/782 bl:2.8750 bb:1.1029 rl:2.8080 rb:1.0834 dl:332-333 gd:1 +ttp: b275/782 bl:2.8271 bb:1.0934 rl:2.8081 rb:1.0835 dl:323-324 gd:1 +ttp: b267/782 bl:2.9433 bb:1.1286 rl:2.8087 rb:1.0837 dl:314-315 gd:1 +ttp: b260/782 bl:2.9132 bb:1.1363 rl:2.8091 rb:1.0839 dl:306-307 gd:1 +ttp: b253/782 bl:2.8228 bb:1.1085 rl:2.8091 rb:1.0840 dl:298-299 
gd:1 +ttp: b245/782 bl:2.9463 bb:1.1314 rl:2.8097 rb:1.0842 dl:290-291 gd:1 +ttp: b236/782 bl:2.9292 bb:1.1385 rl:2.8101 rb:1.0844 dl:281-282 gd:1 +ttp: b229/782 bl:2.9743 bb:1.1700 rl:2.8107 rb:1.0847 dl:274-275 gd:1 +ttp: b222/782 bl:2.9567 bb:1.1487 rl:2.8112 rb:1.0849 dl:267-268 gd:1 +ttp: b213/782 bl:3.0845 bb:1.2035 rl:2.8121 rb:1.0853 dl:258-259 gd:1 +ttp: b204/782 bl:2.9952 bb:1.1651 rl:2.8127 rb:1.0855 dl:250-251 gd:1 +ttp: b196/782 bl:2.9860 bb:1.1965 rl:2.8132 rb:1.0859 dl:243-244 gd:1 +ttp: b189/782 bl:3.0487 bb:1.2373 rl:2.8140 rb:1.0863 dl:237-237 gd:1 +ttp: b181/782 bl:2.9594 bb:1.1891 rl:2.8144 rb:1.0866 dl:230-230 gd:1 +ttp: b172/782 bl:3.0943 bb:1.2169 rl:2.8152 rb:1.0870 dl:222-223 gd:1 +ttp: b165/782 bl:3.0315 bb:1.1996 rl:2.8158 rb:1.0873 dl:216-217 gd:1 +ttp: b159/782 bl:3.0815 bb:1.2140 rl:2.8165 rb:1.0876 dl:211-212 gd:1 +ttp: b156/782 bl:2.9751 bb:1.1408 rl:2.8169 rb:1.0877 dl:208-209 gd:1 +ttp: b151/782 bl:2.8849 bb:1.1369 rl:2.8171 rb:1.0879 dl:204-205 gd:1 +ttp: b144/782 bl:2.9062 bb:1.1560 rl:2.8173 rb:1.0880 dl:199-200 gd:1 +ttp: b136/782 bl:3.0473 bb:1.2157 rl:2.8179 rb:1.0883 dl:192-193 gd:1 +ttp: b128/782 bl:2.9395 bb:1.1285 rl:2.8181 rb:1.0884 dl:186-187 gd:1 +ttp: b119/782 bl:2.8901 bb:1.1192 rl:2.8183 rb:1.0885 dl:179-180 gd:1 +ttp: b79/782 bl:3.1346 bb:1.2445 rl:2.8189 rb:1.0888 dl:149-150 gd:1 +ttp: b70/782 bl:3.1711 bb:1.2050 rl:2.8195 rb:1.0890 dl:142-143 gd:1 +ttp: b63/782 bl:3.1011 bb:1.2507 rl:2.8200 rb:1.0892 dl:137-138 gd:1 +ttp: b53/782 bl:3.2239 bb:1.2710 rl:2.8206 rb:1.0895 dl:129-130 gd:1 +ttp: b46/782 bl:3.2376 bb:1.2660 rl:2.8213 rb:1.0898 dl:123-124 gd:1 +ttp: b37/782 bl:3.1780 bb:1.2476 rl:2.8218 rb:1.0900 dl:116-117 gd:1 +ttp: b27/782 bl:3.1977 bb:1.2766 rl:2.8223 rb:1.0902 dl:107-108 gd:1 +ttp: b17/782 bl:3.2459 bb:1.2866 rl:2.8228 rb:1.0905 dl:98-99 gd:1 +ttp: b7/782 bl:3.3014 bb:1.2667 rl:2.8233 rb:1.0907 dl:84-86 gd:1 +quantized_ttt_phased val_loss:2.81277424 val_bpb:1.08891256 eval_time:472412ms +total_eval_time:472.4s diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_1c.log b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_1c.log new file mode 100644 index 0000000000..d7c1d2be27 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_1c.log @@ -0,0 +1,432 @@ +NCCL version 2.27.5+cuda12.9 +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: /workspace/parameter-golf/runs/ablation_1c + beta1: 0.9 + beta2: 0.95 + compressor: brotli + corrector_alpha: 0.1 + corrector_orders: 5,8,12 + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_only_path: + eval_only_quantized_path: /workspace/checkpoints/seed0/final_model.int6.ptz + eval_seq_len: 2048 + eval_stride: 64 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_doc_limit: 0 + global_ttt_epochs: 3 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.005 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 13.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: 
/workspace/parameter-golf/runs/ablation_1c/e0045366-3971-4746-935c-f1580f257abb.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/runs/ablation_1c/final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_enabled: True + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: /workspace/checkpoints/seed0/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: e0045366-3971-4746-935c-f1580f257abb + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 32 + ttt_doc_limit: 0 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_heartbeat_seconds: 15.0 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_output_dir: + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +eval_only:using quantized checkpoint from /workspace/checkpoints/seed0/final_model.int6.ptz +eval_only: skipping serialize (already have quantized model) +diagnostic quantized val_loss:2.79886366 val_bpb:1.08349242 eval_time:12051ms +ttt_lora:warming up compile +ttt_lora:compile warmup done (182.0s) + +beginning TTT eval timer +corrector: alpha=0.1 orders=[5, 8, 12] +ttt_phased: total_docs:50000 prefix_docs:2000 suffix_docs:48000 +ttp: b779/782 bl:2.6478 bb:1.0785 rl:2.6478 rb:1.0785 dl:9037-11049 gd:0 +ttp: b770/782 bl:2.6617 bb:1.0533 rl:2.6522 rb:1.0703 dl:4479-4698 gd:0 +ttp: b764/782 bl:2.7699 bb:1.1006 rl:2.6761 rb:1.0766 dl:3639-3742 gd:0 +ttp: b758/782 bl:2.8820 bb:1.0874 rl:2.7064 rb:1.0782 dl:3108-3187 gd:0 +ttp: b752/782 bl:2.7701 bb:1.0628 rl:2.7137 rb:1.0764 dl:2740-2793 gd:0 +ttp: b746/782 bl:2.6826 bb:1.0563 rl:2.7108 rb:1.0745 dl:2459-2501 gd:0 +ttpp: pd:2448 gd:2000 t:202.2s +tttg: c1/213 lr:0.005000 t:0.6s +tttg: c2/213 lr:0.005000 t:0.9s +tttg: c3/213 lr:0.004999 t:1.1s +tttg: c4/213 lr:0.004998 t:1.4s +tttg: c5/213 lr:0.004996 t:1.7s +tttg: c6/213 lr:0.004993 t:2.0s +tttg: c7/213 lr:0.004990 t:2.3s +tttg: c8/213 lr:0.004987 t:2.6s +tttg: c9/213 lr:0.004982 t:2.9s +tttg: c10/213 lr:0.004978 t:3.2s +tttg: c11/213 lr:0.004973 t:3.4s +tttg: c12/213 lr:0.004967 t:3.7s +tttg: c13/213 lr:0.004961 t:4.0s +tttg: c14/213 lr:0.004954 t:4.3s +tttg: c15/213 lr:0.004946 t:4.6s +tttg: c16/213 lr:0.004938 t:4.9s +tttg: c17/213 lr:0.004930 t:5.2s +tttg: c18/213 lr:0.004921 t:5.5s +tttg: c19/213 lr:0.004912 t:5.8s +tttg: c20/213 lr:0.004902 t:6.1s +tttg: c21/213 lr:0.004891 t:6.4s +tttg: c22/213 lr:0.004880 
t:6.7s +tttg: c23/213 lr:0.004868 t:7.0s +tttg: c24/213 lr:0.004856 t:7.3s +tttg: c25/213 lr:0.004844 t:7.6s +tttg: c26/213 lr:0.004830 t:7.8s +tttg: c27/213 lr:0.004817 t:8.1s +tttg: c28/213 lr:0.004803 t:8.5s +tttg: c29/213 lr:0.004788 t:8.7s +tttg: c30/213 lr:0.004773 t:9.0s +tttg: c31/213 lr:0.004757 t:9.3s +tttg: c32/213 lr:0.004741 t:9.6s +tttg: c33/213 lr:0.004724 t:9.9s +tttg: c34/213 lr:0.004707 t:10.2s +tttg: c35/213 lr:0.004689 t:10.6s +tttg: c36/213 lr:0.004671 t:10.9s +tttg: c37/213 lr:0.004653 t:11.2s +tttg: c38/213 lr:0.004634 t:11.4s +tttg: c39/213 lr:0.004614 t:11.7s +tttg: c40/213 lr:0.004594 t:12.0s +tttg: c41/213 lr:0.004574 t:12.3s +tttg: c42/213 lr:0.004553 t:12.6s +tttg: c43/213 lr:0.004531 t:12.9s +tttg: c44/213 lr:0.004509 t:13.2s +tttg: c45/213 lr:0.004487 t:13.6s +tttg: c46/213 lr:0.004464 t:13.9s +tttg: c47/213 lr:0.004441 t:14.1s +tttg: c48/213 lr:0.004418 t:14.5s +tttg: c49/213 lr:0.004394 t:14.7s +tttg: c50/213 lr:0.004369 t:15.0s +tttg: c51/213 lr:0.004345 t:15.4s +tttg: c52/213 lr:0.004319 t:15.7s +tttg: c53/213 lr:0.004294 t:15.9s +tttg: c54/213 lr:0.004268 t:16.2s +tttg: c55/213 lr:0.004241 t:16.5s +tttg: c56/213 lr:0.004215 t:16.9s +tttg: c57/213 lr:0.004187 t:17.1s +tttg: c58/213 lr:0.004160 t:17.4s +tttg: c59/213 lr:0.004132 t:17.7s +tttg: c60/213 lr:0.004104 t:18.1s +tttg: c61/213 lr:0.004075 t:18.4s +tttg: c62/213 lr:0.004046 t:18.6s +tttg: c63/213 lr:0.004017 t:18.9s +tttg: c64/213 lr:0.003987 t:19.2s +tttg: c65/213 lr:0.003957 t:19.5s +tttg: c66/213 lr:0.003927 t:19.8s +tttg: c67/213 lr:0.003897 t:20.1s +tttg: c68/213 lr:0.003866 t:20.5s +tttg: c69/213 lr:0.003835 t:20.7s +tttg: c70/213 lr:0.003803 t:21.1s +tttg: c71/213 lr:0.003771 t:21.3s +tttg: c72/213 lr:0.003739 t:21.6s +tttg: c73/213 lr:0.003707 t:21.9s +tttg: c74/213 lr:0.003674 t:22.2s +tttg: c75/213 lr:0.003642 t:22.5s +tttg: c76/213 lr:0.003608 t:22.9s +tttg: c77/213 lr:0.003575 t:23.1s +tttg: c78/213 lr:0.003542 t:23.4s +tttg: c79/213 lr:0.003508 t:23.7s +tttg: c80/213 lr:0.003474 t:24.0s +tttg: c81/213 lr:0.003440 t:24.3s +tttg: c82/213 lr:0.003405 t:24.6s +tttg: c83/213 lr:0.003371 t:24.9s +tttg: c84/213 lr:0.003336 t:25.2s +tttg: c85/213 lr:0.003301 t:25.5s +tttg: c86/213 lr:0.003265 t:25.8s +tttg: c87/213 lr:0.003230 t:26.1s +tttg: c88/213 lr:0.003195 t:26.4s +tttg: c89/213 lr:0.003159 t:26.7s +tttg: c90/213 lr:0.003123 t:27.0s +tttg: c91/213 lr:0.003087 t:27.3s +tttg: c92/213 lr:0.003051 t:27.7s +tttg: c93/213 lr:0.003015 t:27.9s +tttg: c94/213 lr:0.002979 t:28.3s +tttg: c95/213 lr:0.002942 t:28.5s +tttg: c96/213 lr:0.002906 t:28.9s +tttg: c97/213 lr:0.002869 t:29.1s +tttg: c98/213 lr:0.002832 t:29.4s +tttg: c99/213 lr:0.002796 t:29.7s +tttg: c100/213 lr:0.002759 t:30.1s +tttg: c101/213 lr:0.002722 t:30.3s +tttg: c102/213 lr:0.002685 t:30.7s +tttg: c103/213 lr:0.002648 t:30.9s +tttg: c104/213 lr:0.002611 t:31.3s +tttg: c105/213 lr:0.002574 t:31.5s +tttg: c106/213 lr:0.002537 t:31.8s +tttg: c107/213 lr:0.002500 t:32.1s +tttg: c108/213 lr:0.002463 t:32.5s +tttg: c109/213 lr:0.002426 t:32.7s +tttg: c110/213 lr:0.002389 t:33.1s +tttg: c111/213 lr:0.002352 t:33.3s +tttg: c112/213 lr:0.002315 t:33.7s +tttg: c113/213 lr:0.002278 t:33.9s +tttg: c114/213 lr:0.002241 t:34.2s +tttg: c115/213 lr:0.002204 t:34.5s +tttg: c116/213 lr:0.002168 t:34.9s +tttg: c117/213 lr:0.002131 t:35.1s +tttg: c118/213 lr:0.002094 t:35.4s +tttg: c119/213 lr:0.002058 t:35.7s +tttg: c120/213 lr:0.002021 t:36.1s +tttg: c121/213 lr:0.001985 t:36.3s +tttg: c122/213 lr:0.001949 t:36.6s +tttg: c123/213 lr:0.001913 t:36.9s 
+tttg: c124/213 lr:0.001877 t:37.3s +tttg: c125/213 lr:0.001841 t:37.5s +tttg: c126/213 lr:0.001805 t:37.8s +tttg: c127/213 lr:0.001770 t:38.1s +tttg: c128/213 lr:0.001735 t:38.5s +tttg: c129/213 lr:0.001699 t:38.7s +tttg: c130/213 lr:0.001664 t:39.1s +tttg: c131/213 lr:0.001629 t:39.3s +tttg: c132/213 lr:0.001595 t:39.7s +tttg: c133/213 lr:0.001560 t:39.9s +tttg: c134/213 lr:0.001526 t:40.2s +tttg: c135/213 lr:0.001492 t:40.5s +tttg: c136/213 lr:0.001458 t:40.8s +tttg: c137/213 lr:0.001425 t:41.1s +tttg: c138/213 lr:0.001392 t:41.5s +tttg: c139/213 lr:0.001358 t:41.7s +tttg: c140/213 lr:0.001326 t:42.1s +tttg: c141/213 lr:0.001293 t:42.4s +tttg: c142/213 lr:0.001261 t:42.7s +tttg: c143/213 lr:0.001229 t:42.9s +tttg: c144/213 lr:0.001197 t:43.3s +tttg: c145/213 lr:0.001165 t:43.5s +tttg: c146/213 lr:0.001134 t:43.8s +tttg: c147/213 lr:0.001103 t:44.1s +tttg: c148/213 lr:0.001073 t:44.5s +tttg: c149/213 lr:0.001043 t:44.7s +tttg: c150/213 lr:0.001013 t:45.0s +tttg: c151/213 lr:0.000983 t:45.3s +tttg: c152/213 lr:0.000954 t:45.7s +tttg: c153/213 lr:0.000925 t:45.9s +tttg: c154/213 lr:0.000896 t:46.3s +tttg: c155/213 lr:0.000868 t:46.5s +tttg: c156/213 lr:0.000840 t:46.9s +tttg: c157/213 lr:0.000813 t:47.1s +tttg: c158/213 lr:0.000785 t:47.4s +tttg: c159/213 lr:0.000759 t:47.7s +tttg: c160/213 lr:0.000732 t:48.1s +tttg: c161/213 lr:0.000706 t:48.5s +tttg: c162/213 lr:0.000681 t:48.7s +tttg: c163/213 lr:0.000655 t:49.1s +tttg: c164/213 lr:0.000631 t:49.3s +tttg: c165/213 lr:0.000606 t:49.6s +tttg: c166/213 lr:0.000582 t:49.9s +tttg: c167/213 lr:0.000559 t:50.4s +tttg: c168/213 lr:0.000536 t:50.6s +tttg: c169/213 lr:0.000513 t:51.0s +tttg: c170/213 lr:0.000491 t:51.3s +tttg: c171/213 lr:0.000469 t:51.8s +tttg: c172/213 lr:0.000447 t:52.1s +tttg: c173/213 lr:0.000426 t:52.4s +tttg: c174/213 lr:0.000406 t:52.7s +tttg: c175/213 lr:0.000386 t:52.9s +tttg: c176/213 lr:0.000366 t:53.2s +tttg: c177/213 lr:0.000347 t:53.5s +tttg: c178/213 lr:0.000329 t:53.7s +tttg: c179/213 lr:0.000311 t:54.0s +tttg: c180/213 lr:0.000293 t:54.2s +tttg: c181/213 lr:0.000276 t:54.5s +tttg: c182/213 lr:0.000259 t:54.8s +tttg: c183/213 lr:0.000243 t:55.0s +tttg: c184/213 lr:0.000227 t:55.6s +tttg: c185/213 lr:0.000212 t:55.8s +tttg: c186/213 lr:0.000197 t:56.1s +tttg: c187/213 lr:0.000183 t:56.3s +tttg: c188/213 lr:0.000170 t:56.6s +tttg: c189/213 lr:0.000156 t:56.9s +tttg: c190/213 lr:0.000144 t:57.1s +tttg: c191/213 lr:0.000132 t:57.4s +tttg: c192/213 lr:0.000120 t:57.6s +tttg: c193/213 lr:0.000109 t:57.9s +tttg: c194/213 lr:0.000098 t:58.2s +tttg: c195/213 lr:0.000088 t:58.4s +tttg: c196/213 lr:0.000079 t:58.7s +tttg: c197/213 lr:0.000070 t:58.9s +tttg: c198/213 lr:0.000062 t:59.2s +tttg: c199/213 lr:0.000054 t:59.5s +tttg: c200/213 lr:0.000046 t:59.7s +tttg: c201/213 lr:0.000039 t:60.0s +tttg: c202/213 lr:0.000033 t:60.3s +tttg: c203/213 lr:0.000027 t:60.5s +tttg: c204/213 lr:0.000022 t:60.8s +tttg: c205/213 lr:0.000018 t:61.0s +tttg: c206/213 lr:0.000013 t:61.3s +tttg: c207/213 lr:0.000010 t:61.5s +tttg: c208/213 lr:0.000007 t:61.8s +tttg: c209/213 lr:0.000004 t:62.1s +tttg: c210/213 lr:0.000002 t:62.3s +tttg: c211/213 lr:0.000001 t:62.6s +tttg: c212/213 lr:0.000000 t:62.8s +ttpr: t:268.3s +ttp: b736/782 bl:2.6732 bb:1.0420 rl:2.7080 rb:1.0720 dl:2140-2165 gd:1 +ttp: b734/782 bl:2.7684 bb:1.0557 rl:2.7121 rb:1.0709 dl:2091-2115 gd:1 +ttp: b726/782 bl:2.7989 bb:1.0651 rl:2.7172 rb:1.0705 dl:1915-1936 gd:1 +ttp: b718/782 bl:2.7774 bb:1.0706 rl:2.7203 rb:1.0705 dl:1773-1792 gd:1 +ttp: b708/782 bl:2.7214 bb:1.0458 
rl:2.7204 rb:1.0694 dl:1639-1649 gd:1 +ttp: b700/782 bl:2.6770 bb:1.0448 rl:2.7186 rb:1.0684 dl:1552-1562 gd:1 +ttp: b695/782 bl:2.7847 bb:1.0796 rl:2.7211 rb:1.0688 dl:1504-1513 gd:1 +ttp: b681/782 bl:2.8191 bb:1.0703 rl:2.7245 rb:1.0689 dl:1383-1393 gd:1 +ttp: b676/782 bl:2.7925 bb:1.0670 rl:2.7267 rb:1.0688 dl:1347-1353 gd:1 +ttp: b666/782 bl:2.8205 bb:1.0601 rl:2.7294 rb:1.0685 dl:1282-1288 gd:1 +ttp: b658/782 bl:2.8128 bb:1.0766 rl:2.7318 rb:1.0688 dl:1234-1239 gd:1 +ttp: b654/782 bl:2.7329 bb:1.0374 rl:2.7318 rb:1.0679 dl:1209-1215 gd:1 +ttp: b646/782 bl:2.7755 bb:1.0747 rl:2.7329 rb:1.0681 dl:1166-1171 gd:1 +ttp: b634/782 bl:2.7029 bb:1.0434 rl:2.7322 rb:1.0675 dl:1105-1111 gd:1 +ttp: b625/782 bl:2.6726 bb:1.0041 rl:2.7309 rb:1.0661 dl:1064-1068 gd:1 +ttp: b616/782 bl:2.8516 bb:1.0874 rl:2.7334 rb:1.0665 dl:1024-1027 gd:1 +ttp: b608/782 bl:2.7345 bb:1.0320 rl:2.7334 rb:1.0658 dl:990-994 gd:1 +ttp: b600/782 bl:2.7891 bb:1.0588 rl:2.7344 rb:1.0657 dl:958-963 gd:1 +ttp: b592/782 bl:2.7809 bb:1.0470 rl:2.7352 rb:1.0654 dl:930-933 gd:1 +ttp: b583/782 bl:2.8040 bb:1.0937 rl:2.7364 rb:1.0658 dl:901-904 gd:1 +ttp: b576/782 bl:2.7837 bb:1.0484 rl:2.7371 rb:1.0655 dl:877-880 gd:1 +ttp: b561/782 bl:2.7131 bb:1.0640 rl:2.7368 rb:1.0655 dl:831-834 gd:1 +ttp: b547/782 bl:2.7345 bb:1.0327 rl:2.7367 rb:1.0650 dl:790-793 gd:1 +ttp: b538/782 bl:2.6910 bb:1.0408 rl:2.7361 rb:1.0647 dl:767-769 gd:1 +ttp: b530/782 bl:2.8103 bb:1.0402 rl:2.7371 rb:1.0644 dl:747-750 gd:1 +ttp: b522/782 bl:2.8278 bb:1.0870 rl:2.7382 rb:1.0647 dl:727-730 gd:1 +ttp: b509/782 bl:2.7539 bb:1.0717 rl:2.7384 rb:1.0647 dl:695-698 gd:1 +ttp: b502/782 bl:2.8383 bb:1.0658 rl:2.7395 rb:1.0648 dl:680-682 gd:1 +ttp: b493/782 bl:2.8467 bb:1.1163 rl:2.7407 rb:1.0653 dl:659-661 gd:1 +ttp: b484/782 bl:2.8023 bb:1.0695 rl:2.7413 rb:1.0654 dl:641-643 gd:1 +ttp: b475/782 bl:2.7325 bb:1.0244 rl:2.7413 rb:1.0649 dl:622-623 gd:1 +ttp: b465/782 bl:2.8160 bb:1.0622 rl:2.7420 rb:1.0649 dl:602-604 gd:1 +ttp: b453/782 bl:2.7669 bb:1.0618 rl:2.7422 rb:1.0649 dl:580-582 gd:1 +ttp: b436/782 bl:2.8516 bb:1.0698 rl:2.7431 rb:1.0649 dl:549-551 gd:1 +ttp: b423/782 bl:2.7451 bb:1.0311 rl:2.7432 rb:1.0646 dl:526-528 gd:1 +ttp: b414/782 bl:2.8277 bb:1.0901 rl:2.7438 rb:1.0648 dl:511-513 gd:1 +ttp: b403/782 bl:2.8178 bb:1.0530 rl:2.7444 rb:1.0647 dl:493-495 gd:1 +ttp: b394/782 bl:2.8985 bb:1.1178 rl:2.7455 rb:1.0651 dl:479-481 gd:1 +ttp: b386/782 bl:2.7311 bb:1.0669 rl:2.7454 rb:1.0651 dl:467-468 gd:1 +ttp: b380/782 bl:2.8471 bb:1.0781 rl:2.7461 rb:1.0652 dl:459-460 gd:1 +ttp: b376/782 bl:2.7207 bb:1.0448 rl:2.7459 rb:1.0651 dl:453-454 gd:1 +ttp: b370/782 bl:2.6852 bb:1.0447 rl:2.7455 rb:1.0650 dl:444-446 gd:1 +ttp: b364/782 bl:2.7361 bb:1.0672 rl:2.7455 rb:1.0650 dl:436-437 gd:1 +ttp: b357/782 bl:2.8666 bb:1.0846 rl:2.7462 rb:1.0651 dl:426-427 gd:1 +ttp: b349/782 bl:2.9265 bb:1.1120 rl:2.7473 rb:1.0654 dl:415-417 gd:1 +ttp: b340/782 bl:2.8251 bb:1.0928 rl:2.7478 rb:1.0656 dl:403-404 gd:1 +ttp: b332/782 bl:2.8340 bb:1.1004 rl:2.7483 rb:1.0658 dl:393-394 gd:1 +ttp: b324/782 bl:2.7773 bb:1.0593 rl:2.7484 rb:1.0657 dl:382-384 gd:1 +ttp: b315/782 bl:2.7347 bb:1.0753 rl:2.7483 rb:1.0658 dl:370-371 gd:1 +ttp: b306/782 bl:2.8858 bb:1.1417 rl:2.7490 rb:1.0661 dl:359-361 gd:1 +ttp: b297/782 bl:2.8043 bb:1.0624 rl:2.7493 rb:1.0661 dl:348-349 gd:1 +ttp: b289/782 bl:2.8420 bb:1.1250 rl:2.7498 rb:1.0664 dl:339-340 gd:1 +ttp: b280/782 bl:2.8310 bb:1.0987 rl:2.7501 rb:1.0665 dl:329-329 gd:1 +ttp: b271/782 bl:2.7868 bb:1.0740 rl:2.7503 rb:1.0666 dl:319-320 gd:1 +ttp: b249/782 
bl:2.9086 bb:1.1585 rl:2.7509 rb:1.0669 dl:294-295 gd:1 +ttp: b241/782 bl:2.9112 bb:1.1277 rl:2.7516 rb:1.0672 dl:286-287 gd:1 +ttp: b233/782 bl:2.8714 bb:1.1282 rl:2.7520 rb:1.0674 dl:278-279 gd:1 +ttp: b226/782 bl:2.9464 bb:1.1462 rl:2.7527 rb:1.0677 dl:271-272 gd:1 +ttp: b219/782 bl:2.9050 bb:1.1334 rl:2.7533 rb:1.0679 dl:264-265 gd:1 +ttp: b210/782 bl:2.8564 bb:1.1242 rl:2.7536 rb:1.0681 dl:255-256 gd:1 +ttp: b202/782 bl:2.8758 bb:1.1368 rl:2.7541 rb:1.0684 dl:248-249 gd:1 +ttp: b195/782 bl:2.8455 bb:1.1137 rl:2.7544 rb:1.0685 dl:242-243 gd:1 +ttp: b186/782 bl:2.9492 bb:1.1783 rl:2.7550 rb:1.0688 dl:234-235 gd:1 +ttp: b178/782 bl:2.8614 bb:1.1412 rl:2.7553 rb:1.0691 dl:227-228 gd:1 +ttp: b172/782 bl:3.0143 bb:1.1854 rl:2.7561 rb:1.0694 dl:222-223 gd:1 +ttp: b166/782 bl:2.9846 bb:1.1507 rl:2.7567 rb:1.0696 dl:217-218 gd:1 +ttp: b162/782 bl:2.9601 bb:1.1486 rl:2.7573 rb:1.0698 dl:213-214 gd:1 +ttp: b157/782 bl:2.8290 bb:1.1151 rl:2.7575 rb:1.0700 dl:209-210 gd:1 +ttp: b149/782 bl:2.9876 bb:1.1778 rl:2.7581 rb:1.0703 dl:203-204 gd:1 +ttp: b143/782 bl:3.0231 bb:1.1974 rl:2.7588 rb:1.0706 dl:198-199 gd:1 +ttp: b135/782 bl:2.9358 bb:1.1438 rl:2.7592 rb:1.0708 dl:191-192 gd:1 +ttp: b126/782 bl:2.9370 bb:1.1933 rl:2.7597 rb:1.0710 dl:185-185 gd:1 +ttp: b117/782 bl:2.8706 bb:1.1506 rl:2.7599 rb:1.0712 dl:178-178 gd:1 +ttp: b110/782 bl:3.0384 bb:1.1797 rl:2.7605 rb:1.0715 dl:173-173 gd:1 +ttp: b103/782 bl:2.8946 bb:1.1203 rl:2.7608 rb:1.0716 dl:168-168 gd:1 +ttp: b95/782 bl:3.0176 bb:1.2288 rl:2.7614 rb:1.0719 dl:161-162 gd:1 +ttp: b84/782 bl:3.0391 bb:1.2245 rl:2.7619 rb:1.0722 dl:153-154 gd:1 +ttp: b76/782 bl:3.0669 bb:1.2304 rl:2.7625 rb:1.0725 dl:147-148 gd:1 +ttp: b68/782 bl:3.1166 bb:1.2108 rl:2.7631 rb:1.0727 dl:141-142 gd:1 +ttp: b60/782 bl:3.0838 bb:1.2376 rl:2.7637 rb:1.0730 dl:134-135 gd:1 +ttp: b52/782 bl:3.0705 bb:1.2014 rl:2.7642 rb:1.0732 dl:128-129 gd:1 +ttp: b43/782 bl:3.0123 bb:1.1979 rl:2.7646 rb:1.0734 dl:121-122 gd:1 +ttp: b15/782 bl:3.2498 bb:1.2433 rl:2.7652 rb:1.0736 dl:95-97 gd:1 +ttp: b5/782 bl:3.3456 bb:1.3050 rl:2.7658 rb:1.0738 dl:80-82 gd:1 +quantized_ttt_phased val_loss:2.77503778 val_bpb:1.07430360 eval_time:465764ms +total_eval_time:465.8s diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_summary.json b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_summary.json new file mode 100644 index 0000000000..061c54bdf9 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/ablation_summary.json @@ -0,0 +1,45 @@ +{ + "baseline_bpb": 1.07218477, + "runs": { + "1a": { + "config": { + "CORRECTOR_ALPHA": "0.3", + "CORRECTOR_ORDERS": "8" + }, + "bpb": 1.08876294, + "delta": -0.01657817000000006, + "eval_ms": 462840, + "artifact_bytes": 0, + "log": "/workspace/parameter-golf/runs/ablation_1a_log.txt" + }, + "1b": { + "config": { + "CORRECTOR_ALPHA": "0.3", + "CORRECTOR_ORDERS": "5,8,12" + }, + "bpb": 1.08891256, + "delta": -0.01672779000000002, + "eval_ms": 472412, + "artifact_bytes": 0, + "log": "/workspace/parameter-golf/runs/ablation_1b_log.txt" + }, + "1c": { + "config": { + "CORRECTOR_ALPHA": "0.1", + "CORRECTOR_ORDERS": "5,8,12" + }, + "bpb": 1.0743036, + "delta": -0.002118829999999905, + "eval_ms": 465764, + "artifact_bytes": 0, + "log": "/workspace/parameter-golf/runs/ablation_1c_log.txt" + } + }, + "best_config": "1c", + "best_config_details": { + "CORRECTOR_ALPHA": "0.1", + "CORRECTOR_ORDERS": "5,8,12" + }, + "best_delta": 
-0.002118829999999905,
+  "recommended_path": "fallback"
+}
\ No newline at end of file
diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/provenance/commit_sha.txt b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/provenance/commit_sha.txt
new file mode 100644
index 0000000000..4295a9ed57
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/provenance/commit_sha.txt
@@ -0,0 +1,4 @@
+1765afc7d62ce03a1219ca81cc92eea4fabdf343
+1765afc7d62ce03a1219ca81cc92eea4fabdf343 pipeline(#1610): auto-capture provenance, repo-type=model, exact-SHA pin, align README + run_all
+ author: amay
+ date: 2026-04-18 16:36:25 +0200
diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/provenance/env_fingerprint.txt b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/provenance/env_fingerprint.txt
new file mode 100644
index 0000000000..6c062c5e57
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/provenance/env_fingerprint.txt
@@ -0,0 +1,4 @@
+torch 2.9.1+cu128
+cuda (torch-reported) 12.8
+flash_attn_interface unknown
+python 3.10.12
diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/provenance/hardware_info.txt b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/provenance/hardware_info.txt
new file mode 100644
index 0000000000..9303f5bf61
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/provenance/hardware_info.txt
@@ -0,0 +1,48 @@
+Sat Apr 18 22:04:38 2026
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. |
+|=========================================+========================+======================|
+| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 |
+| N/A 24C P0 68W / 700W | 0MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 |
+| N/A 27C P0 68W / 700W | 0MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 |
+| N/A 28C P0 68W / 700W | 0MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 |
+| N/A 25C P0 69W / 700W | 0MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 |
+| N/A 25C P0 69W / 700W | 0MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 |
+| N/A 28C P0 69W / 700W | 0MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 |
+| N/A 28C P0 68W / 700W | 0MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 |
+| N/A 25C P0 69W / 700W | 0MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes: |
+| GPU GI CI PID Type Process name GPU Memory |
+| ID ID Usage |
+|=========================================================================================|
+| No running processes found |
++-----------------------------------------------------------------------------------------+
diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/reproduction_summary.json b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/reproduction_summary.json
new file mode 100644
index 0000000000..68f11e4a3a
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/reproduction_summary.json
@@ -0,0 +1,12 @@
+{
+  "stage": "gate_a",
+  "seed": 0,
+  "status": "PASS_TECHNICAL",
+  "note": "BPB 1.07218477 reproduces published 1.07216564 within 2e-5; artifact 15999394 bytes under 16000000 competition rule (exceeded internal 15997520 safety buffer only)",
+  "bpb": 1.07218477,
+  "published_bpb": 1.07216564,
+  "bpb_delta_vs_published": 0.00001913,
+  "eval_time_ms": 455945,
+  "artifact_bytes": 15999394,
+  "commit_sha": "1765afc7d62ce03a1219ca81cc92eea4fabdf343"
+}
diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/requirements.txt b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/requirements.txt
new file mode 100644
index 0000000000..fb8c0d1628
--- /dev/null
+++ 
b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/requirements.txt
@@ -0,0 +1,20 @@
+# Dependencies for this non-record package. The reproduction pins to PR #1610
+# at commit ca1919539dc6e328ea890cb03ad3ca1c5a84da55; the record-track
+# submission upstream is the authoritative source for exact versions.
+#
+# FlashAttention 3 must be installed separately; see README.md.
+# PyTorch 2.9.1+cu128 required; install via:
+# uv pip install torch==2.9.1+cu128 --extra-index-url https://download.pytorch.org/whl/cu128
+numpy
+tqdm
+torch
+huggingface-hub
+kernels
+setuptools
+typing-extensions==4.15.0
+datasets
+tiktoken
+sentencepiece
+zstandard
+brotli
+python-minifier
diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/submission.json b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/submission.json
new file mode 100644
index 0000000000..87a815098e
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/submission.json
@@ -0,0 +1,32 @@
+{
+  "name": "Ammer Ayach",
+  "github_id": "amrayach",
+  "track": "non_record_16mb",
+  "val_bpb": 1.07218477,
+  "reference_comparison": {
+    "note": "Single-seed reproduction; comparison is against PR #1610's published seed-0 number, not the 3-seed mean.",
+    "published_pr": "#1610",
+    "published_commit_sha": "ca1919539dc6e328ea890cb03ad3ca1c5a84da55",
+    "published_seed0_bpb": 1.07216564,
+    "our_seed0_bpb": 1.07218477,
+    "delta_vs_published_seed0": 0.00001913
+  },
+  "this_branch_commit_sha": "1765afc7d62ce03a1219ca81cc92eea4fabdf343",
+  "seeds_run": [0],
+  "hardware": "8x NVIDIA H100 80GB HBM3 (SXM5), CUDA 12.8, driver 570.211.01",
+  "wallclock": {
+    "train_seconds": "truncated at MAX_WALLCLOCK_SECONDS=600; stopped at step 4879 (by design in #1610)",
+    "eval_seconds": 455.9
+  },
+  "artifact_bytes": 15999394,
+  "contributions": [
+    "Faithful reproduction of PR #1610 to within 2e-5 BPB on independent infrastructure.",
+    "Bounded negative result for a score-first posterior-corrector layered on phased LoRA TTT; all three tested (alpha, orders) configurations degrade BPB monotonically with blend weight.",
+    "Bug fix in the quantized-eval-only path of train_gpt.py enabling eval-only ablations against a preserved checkpoint."
+  ],
+  "external_supplementary": {
+    "url": "https://huggingface.co/amay01/parameter-golf-pr1610-reproduction-artifacts",
+    "description": "Optional external archive with preserved checkpoints (final_model.int6.ptz, final_model.pt), raw ablation intermediate artifacts, and environment manifest. Not required to read this PR.",
+    "primary_tarball_md5": "caf8adf63d8c80965f6671beba95d7aa"
+  }
+}
diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/train_gpt.py b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/train_gpt.py
new file mode 100644
index 0000000000..96f1eac23a
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/train_gpt.py
@@ -0,0 +1,3470 @@
+import base64, brotli, collections, copy, fcntl, glob, io, json, lzma, math, os, shutil
+from pathlib import Path
+import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F
+from torch import nn
+from flash_attn_interface import (
+    flash_attn_func as flash_attn_3_func,
+    flash_attn_varlen_func,
+)
+from concurrent.futures import ThreadPoolExecutor
+import triton
+import triton.language as tl
+from triton.tools.tensor_descriptor import TensorDescriptor
+
+
+class PrefixNgramCorrector:
+    """Prefix-only posterior corrector. State: scored prefix [0,t) only.
+    Call get_logit_bias() BEFORE scoring position t; call update(x_t) AFTER.
+    Laplace base guarantees full-vocab support (LEGALITY_SPEC.md Condition 2)."""
+    def __init__(self, V, alpha, orders):
+        self.V, self.alpha, self.orders = V, alpha, sorted(orders)
+        self.reset()
+    def reset(self):
+        self.uni = torch.ones(self.V, dtype=torch.int32) # Laplace: count >= 1
+        self.ng = {n: {} for n in self.orders} # {n:{ctx_hash:{tok:cnt}}}
+        self.hist = []
+        self._lu = self._lz = None
+    def _cache(self):
+        if self._lu is None:
+            self._lu = torch.log(self.uni.float())
+            self._lz = torch.logsumexp(self._lu, 0).item()
+    def get_logit_bias(self):
+        """Return [V] float32 logit bias from prefix [0,t). Call BEFORE update."""
+        self._cache()
+        b = self.alpha * (self._lu - self._lz)
+        if self.hist and self.orders:
+            delta = torch.zeros(self.V)
+            lu, lz = self._lu.tolist(), self._lz
+            for n in self.orders:
+                ctx = tuple(self.hist[-(n-1):]) if n > 1 else ()
+                tbl = self.ng[n].get(ctx)
+                if tbl:
+                    tot = sum(tbl.values())
+                    for tok, cnt in tbl.items():
+                        delta[tok] += self.alpha * (
+                            math.log((cnt+1)/(tot+self.V)) - (lu[tok]-lz)
+                        )
+            b = b + delta
+        return b
+    def update(self, token_id):
+        """Record scored token x_t. 
Call AFTER scoring position t.""" + t = int(token_id) + self.uni[t] += 1 + self._lu = self._lz = None + for n in self.orders: + ctx = tuple(self.hist[-(n-1):]) if n > 1 else () + d = self.ng[n].setdefault(ctx, {}) + d[t] = d.get(t, 0) + 1 + self.hist.append(t) + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.72)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + sliding_window_enabled = bool(int(os.environ.get("SLIDING_WINDOW_ENABLED", "0"))) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + embedding_dim = int(os.environ.get("EMBEDDING_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.022)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = 
int(os.environ.get("EVAL_STRIDE", 64)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 96)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 32)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 0.5)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.999)) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + ttt_output_dir = os.environ.get("TTT_OUTPUT_DIR", "") + ttt_heartbeat_seconds = float(os.environ.get("TTT_HEARTBEAT_SECONDS", 15.0)) + global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.005)) + global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 3)) + global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9)) + global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768)) + global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32)) + global_ttt_warmup_start_lr = float( + os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0) + ) + global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0)) + global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0)) + global_ttt_doc_limit = int(os.environ.get("GLOBAL_TTT_DOC_LIMIT", 0)) + global_ttt_respect_doc_boundaries = bool( + int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1")) + ) + phased_ttt_enabled = bool(int(os.environ.get("PHASED_TTT_ENABLED", "0"))) + phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 10000)) + corrector_alpha = float(os.environ.get("CORRECTOR_ALPHA", "0.0")) + corrector_orders = os.environ.get("CORRECTOR_ORDERS", "8") + ttt_doc_limit = int(os.environ.get("TTT_DOC_LIMIT", 0)) + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "brotli") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 64)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 13.0)) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 8)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 2e1)) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + datasets_dir = os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}") + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + tokenizer_path = os.path.join( + 
data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model" + ) + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + eval_only_path = os.environ.get("EVAL_ONLY_PATH", "") + eval_only_quantized_path = os.environ.get("EVAL_ONLY_QUANTIZED_PATH", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + eval_only_quantized_path + if eval_only_quantized_path + else ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, 
dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._next_batch = None + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + if self.bos_idx.size == 0: + self.bos_idx = np.array([0], dtype=np.int64) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = buf[:-1].to(dtype=torch.int64).pin_memory() + targets = buf[1:].to(dtype=torch.int64).pin_memory() + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + if self._next_batch is not None: + inputs, targets, cu_seqlens, max_seqlen = self._next_batch.result() + else: + inputs, targets, cu_seqlens, max_seqlen = self._prepare_batch( + num_tokens_local, self.max_seq_len + ) + self._next_batch = self._batch_pool.submit( + self._prepare_batch, num_tokens_local, self.max_seq_len + ) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise 
FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.5 * pre0) + c1 = 
c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.5 * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, 0.5 * c0) + aux1 = tl.where(c1 > 0, c1, 0.5 * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 256, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, 
:] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if cu_seqlens is not None: + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1), + )[None] + else: + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + ): + super().__init__() + self.attn_norm 
= RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + h.qk_gain_init, + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + 
self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj 
is not None: + x = self.embed_proj(x) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + logits = self.forward_logits( + input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + target_ids.reshape(-1), + reduction="mean", + ) + + def forward_ttt(self, input_ids, target_ids, lora, logit_bias=None): + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = 
torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + if logit_bias is not None: + logits = logits + logit_bias # [B,1,V] broadcast — no dense [B,S,V] alloc + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q = (F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n)).reshape( + bsz, seqlen, attn.num_heads, attn.head_dim + ) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q = (F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n)).reshape( + bsz, seqlen, attn.num_heads, attn.head_dim + ) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + 
lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + def reset(self): + with torch.no_grad(): + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) + + +class BatchedTTTLoRA(nn.Module): + def __init__(self, bsz, model, rank, k_lora=True, mlp_lora=True, o_lora=True): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + embed_dim = model.tok_emb.embedding_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + self.v_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + 
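+ # Each per-slot adapter re-randomizes A and zeros B, so the LoRA delta (x @ A.T) @ B.T is
+ # exactly zero until the next document batch takes its first TTT gradient step.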
lora.reset() + + +def classify_param(name): + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or ".proj." in name and ".mlp." not in name: + return "attn" + return "other" + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + a, b, c = 3.4445, -4.775, 2.0315 + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad.bfloat16()) + if pg.shape[0] > m["B"]: + pg[m["B"] :].zero_() + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if 
"momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [ + { + "params": [base_model.lm_head.weight], + "lr": h.head_lr, + "base_lr": h.head_lr, + } + ], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + if base_model.lm_head is not None: + 
self.replicated_params.append(base_model.lm_head.weight) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self.optimizer_tok.step() + self.optimizer_scalar.step() + if self.optimizer_head is not None: + self.optimizer_head.step() + self.optimizer_muon.step() + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank"): + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + h_act = 
module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + if model.tie_embeddings: + hook_module = ( + model.head_proj if model.head_proj is not None else model.final_norm + ) + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + return hessians + + +def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s + + +def gptq_mixed_quantize(state_dict, hessians, h): + result = {} + meta = {} + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + cs = h.embed_clip_sigmas if "tok_emb" in name else h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + q, s = gptq_quantize_weight( + t, hessians[name], clip_sigmas=cs, clip_range=2 ** (bits - 1) - 1 + ) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + categories = 
collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = ( + q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + ).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data, stride=2): + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off : dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data): + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off : src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data, compressor): + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data, compressor): + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + raw = _byte_unshuffle(raw) + return raw + + +def _unbank_state_dict(state_dict, num_layers): + sd = {} + n = num_layers + for k, v in state_dict.items(): + t = v.detach().cpu() + if k == "qo_bank": + for i in range(n): + sd[f"blocks.{i}.attn.c_q.weight"] = t[i] + sd[f"blocks.{i}.attn.proj.weight"] = t[n + i] + elif k == "kv_bank": + for i in range(n): + sd[f"blocks.{i}.attn.c_k.weight"] = t[i] + sd[f"blocks.{i}.attn.c_v.weight"] = t[n + i] + elif k == "mlp_up_bank": + for i in range(n): + sd[f"blocks.{i}.mlp.fc.weight"] = t[i] + elif k == "mlp_down_bank": + for i in range(n): + sd[f"blocks.{i}.mlp.proj.weight"] = t[i] + else: + sd[k] = t + return sd + + +def _rebank_state_dict(flat_sd, num_layers, model_dim, kv_dim, hidden_dim): + sd = {} + n = num_layers + sd["qo_bank"] = torch.zeros(2 * n, model_dim, model_dim) + sd["kv_bank"] = torch.zeros(2 * n, kv_dim, model_dim) + sd["mlp_up_bank"] = torch.zeros(n, hidden_dim, model_dim) + sd["mlp_down_bank"] = torch.zeros(n, model_dim, hidden_dim) + for i in range(n): + sd["qo_bank"][i] = flat_sd[f"blocks.{i}.attn.c_q.weight"] + sd["qo_bank"][n + i] = flat_sd[f"blocks.{i}.attn.proj.weight"] + sd["kv_bank"][i] = 
flat_sd[f"blocks.{i}.attn.c_k.weight"] + sd["kv_bank"][n + i] = flat_sd[f"blocks.{i}.attn.c_v.weight"] + sd["mlp_up_bank"][i] = flat_sd[f"blocks.{i}.mlp.fc.weight"] + sd["mlp_down_bank"][i] = flat_sd[f"blocks.{i}.mlp.proj.weight"] + for k, v in flat_sd.items(): + if not ( + k.startswith("blocks.") + and any( + p in k + for p in [ + ".attn.c_q.", ".attn.c_k.", ".attn.c_v.", + ".attn.proj.", ".mlp.fc.", ".mlp.proj.", + ] + ) + ): + sd[k] = v + return sd + + +def _compressed_code_size(code): + code_raw = code.encode("utf-8") + pyminify_exe = shutil.which("pyminify") + if pyminify_exe is not None: + minified = subprocess.run( + [ + pyminify_exe, + "--no-rename-locals", + "--no-hoist-literals", + "--remove-literal-statements", + "-", + ], + input=code_raw, capture_output=True, check=True, + ).stdout + else: + try: + import python_minifier + + minified = python_minifier.minify( + code, + remove_literal_statements=True, + rename_globals=False, + rename_locals=False, + hoist_literals=False, + ).encode("utf-8") + except ImportError: + minified = ast.unparse(ast.parse(code)).encode("utf-8") + compressed = brotli.compress(minified, quality=11) + encoded = base64.b85encode(compressed) + wrapper = b'import brotli as B,base64 as A\nexec(B.decompress(A.b85decode("' + encoded + b'")))\n' + return len(code_raw), len(wrapper) + + +def serialize(h, base_model, code): + code_bytes_uncompressed, code_bytes = _compressed_code_size(code) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size (uncompressed): {code_bytes_uncompressed} bytes") + log(f"Code size (compressed): {code_bytes} bytes") + sd_cpu = _unbank_state_dict(base_model.state_dict(), h.num_layers) + device = torch.device("cuda", h.local_rank) + log("GPTQ:collecting Hessians from calibration data...") + t0 = time.perf_counter() + calib_loader = ShuffledSequenceLoader(h, device) + hessians = collect_hessians( + base_model, + calib_loader, + h, + device, + n_calibration_batches=h.gptq_calibration_batches, + ) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter()-t0:.1f}s") + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes") + return bytes_total, quant_file_bytes + + +def deserialize(h, device): + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + flat_template = _unbank_state_dict(eval_model.state_dict(), h.num_layers) + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), map_location="cpu" + ) + deq_flat = dequantize_mixed(quant_state["w"], quant_state["m"], flat_template) + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + deq_state = _rebank_state_dict(deq_flat, h.num_layers, h.model_dim, kv_dim, hidden_dim) + eval_model.load_state_dict(deq_state, strict=True) + return eval_model + 
+ +def _loss_bpb(loss_sum, token_count, byte_count): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model, forward_logits_fn=None): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + + # TODO: Don't truncate this. + seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + run_forward_logits = ( + (model.module.forward_logits if hasattr(model, "module") else model.forward_logits) + if forward_logits_fn is None + else forward_logits_fn + ) + model.eval() + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True + ) + x = local[:-1] + y = local[1:] + bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x.numel(), x.device, h.eval_seq_len, 64 + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = run_forward_logits( + x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ).detach() + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ) + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(y.numel()) + prev_ids = x + tgt_ids = y + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += ( + val_data.has_leading_space_lut[tgt_ids] + & ~val_data.is_boundary_token_lut[prev_ids] + ).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding(h, device, val_data, base_model, forward_logits_fn=None, batch_seqs=32): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + run_forward_logits = base_model.forward_logits if forward_logits_fn is None else forward_logits_fn + seq_len = h.eval_seq_len + stride = h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + context_size = seq_len - stride + window_starts = [ws for ws in range(0, total_tokens, stride) + if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e 
= (total_windows * (h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + total_batches = (len(my_windows) + batch_seqs - 1) // batch_seqs + is_master = h.rank == 0 + cu_bucket = 64 + t_sw_start = time.perf_counter() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_idx = bi // batch_seqs + if is_master and (batch_idx % 50 == 0 or batch_idx == total_batches - 1): + elapsed = time.perf_counter() - t_sw_start + rl = float(loss_sum.item() / token_count.item()) if token_count.item() > 0 else 0.0 + rb = float((rl / math.log(2.0)) * token_count.item() / byte_count.item()) if byte_count.item() > 0 else 0.0 + log(f"sliding_progress: batch {batch_idx+1}/{total_batches} " + f"tokens:{int(token_count.item())} running_loss:{rl:.4f} running_bpb:{rb:.4f} " + f"elapsed:{elapsed:.1f}s") + batch_ws = my_windows[bi:bi + batch_seqs] + x_parts = [] + y_parts = [] + cu_starts = [] + score_ranges = [] + offset = 0 + for ws in batch_ws: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + chunk_cpu = val_data.val_tokens[ws:end + 1] + bos_pos = (chunk_cpu[:-1] == BOS_ID).nonzero(as_tuple=True)[0].tolist() + if not bos_pos or bos_pos[0] != 0: + bos_pos = [0] + bos_pos + cu_starts.extend(offset + pos for pos in bos_pos) + chunk = chunk_cpu.to(dtype=torch.int64, device=device) + x_parts.append(chunk[:-1]) + y_parts.append(chunk[1:]) + score_ranges.append((offset, wlen, ws)) + offset += wlen + x_cat = torch.cat(x_parts, dim=0)[None] + y_cat = torch.cat(y_parts, dim=0) + boundaries = cu_starts + [offset] + padded_len = get_next_multiple_of_n(len(boundaries), cu_bucket) + cu_seqlens = torch.full((padded_len,), offset, dtype=torch.int32, device=device) + cu_seqlens[:len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = run_forward_logits(x_cat, cu_seqlens=cu_seqlens, max_seqlen=seq_len) + flat_nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_cat, + reduction="none", + ) + flat_x = x_cat.reshape(-1) + for off, wlen, ws in score_ranges: + s = 0 if ws == 0 else context_size + lo = off + s + hi = off + wlen + scored_nll = flat_nll[lo:hi].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(hi - lo) + tgt = y_cat[lo:hi] + prev = flat_x[lo:hi] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = ( + int(bos_positions[i + 1]) + if i + 1 < len(bos_positions) + else all_tokens.numel() + ) + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _get_prefix_doc_tokens(h, val_data): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + all_tokens = val_data.val_tokens + 
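+ # Prefix selection for the global-SGD phase: keep the first h.global_ttt_doc_limit
+ # documents (or the leading h.val_doc_fraction of them), always cutting on a document
+ # boundary so no partial document leaks into the adaptation stream.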
doc_limit = h.global_ttt_doc_limit + if doc_limit <= 0 and h.val_doc_fraction >= 1.0: + return all_tokens + docs = _find_docs(all_tokens) + if doc_limit > 0: + keep_docs = min(len(docs), doc_limit) + else: + keep_docs = max(1, int(round(len(docs) * h.val_doc_fraction))) + end = docs[keep_docs - 1][0] + docs[keep_docs - 1][1] + return all_tokens[:end] + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + return _build_ttt_global_batches_from_sorted(global_doc_entries, h, ascending=ascending) + + +def _build_ttt_global_batches_from_sorted(global_doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.ttt_doc_limit > 0: + return doc_entries[: min(len(doc_entries), h.ttt_doc_limit)] + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def _split_doc_entries_for_phased(doc_entries, prefix_docs): + prefix_docs = max(0, min(len(doc_entries), int(prefix_docs))) + return doc_entries[:prefix_docs], doc_entries[prefix_docs:] + + +def _doc_entries_end_token(doc_entries): + if not doc_entries: + return 0 + doc_start, doc_len = doc_entries[-1][1] + return doc_start + doc_len + + +def _copy_module_state_(dst, src): + # Keep the per-doc model in lockstep with the persistent global model + # without rebuilding or recompiling the doc-scoring graph. 
+ with torch.no_grad(): + for dst_p, src_p in zip(dst.parameters(), src.parameters()): + dst_p.copy_(src_p) + for dst_b, src_b in zip(dst.buffers(), src.buffers()): + dst_b.copy_(src_b) + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD( + ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum + ) + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = ( + h.global_ttt_warmup_start_lr + + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + ) + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = 
h.global_ttt_lr * 0.5 * ( + 1.0 + math.cos(math.pi * decay_ci / decay_steps) + ) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + optimizer.zero_grad(set_to_none=True) + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if h.global_ttt_respect_doc_boundaries: + bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64 + ) + loss = base_model( + x_flat[None], + y_flat[None], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + optimizer.step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log( + f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s" + ) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train, correctors=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + log( + "ttt_phased:" + f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + 
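+ # Float64 scalar accumulators; they are all-reduced across ranks at the end and turned
+ # into BPB by _loss_bpb_from_sums (mean loss / ln 2, scaled by tokens per byte).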
token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + + def _build_opt(lora): + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=h.ttt_lora_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + # Corrector state is global within a phase (LEGALITY_SPEC.md §Q3): + # reset happens only at SGD boundary, NOT per document-batch. + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + # Corrector: compute per-doc [V] bias → [B,1,V] before forward pass + _logit_bias = None + if correctors is not None: + _cb = [correctors[_b].get_logit_bias() if active[_b] else 
torch.zeros(h.vocab_size) for _b in range(bsz)] + _logit_bias = torch.stack(_cb).unsqueeze(1).to(device=device, dtype=torch.bfloat16) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora, logit_bias=_logit_bias) + with torch.no_grad(): + _accumulate_bpb( + per_tok_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + ) + # Corrector: update state with scored tokens (score-before-update) + if correctors is not None: + _y_cpu = y.cpu() + for _b in range(bsz): + if not active[_b]: + continue + _co, _cl = int(chunk_offsets_cpu[_b]), int(chunk_lens_cpu[_b]) + for _tok in _y_cpu[_b, _co:_co + _cl].tolist(): + correctors[_b].update(_tok) + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + for gi in range(h.ttt_grad_steps): + if gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora, logit_bias=_logit_bias) + per_doc = per_tok_loss[ + :, chunk_offset : chunk_offset + chunk_size + ].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + else: + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}" + ) + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch) + ) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= prefix_doc_limit: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:prefix_doc_limit] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global + ] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + prefix_done = 0 + try: + with open(prefix_counter_path, "rb") as f: + prefix_done = int.from_bytes( + f.read(8), "little", signed=True + 
) + except FileNotFoundError: + pass + log( + f"ttpp: pd:{prefix_done} gd:{len(scored_docs_for_global)} " + f"t:{time.perf_counter() - t_start:.1f}s" + ) + train_val_ttt_global_sgd_distributed( + h, device, val_data, base_model, global_ttt_tokens + ) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + global_ttt_done = True + if correctors is not None: + for _c in correctors: + _c.reset() # SGD changed base model; stale prefix stats discarded + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + +def eval_val_ttt_lora_stats( + h, + base_model, + device, + val_data, + forward_ttt_train, + doc_entries=None, + global_batches_sorted=None, + counter_tag="lora", + progress_filename="progress.jsonl", + progress_tag="ttt_progress", +): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + if doc_entries is None: + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + log( + f"{progress_tag}:docs:{len(doc_entries)} rank:{h.ttt_lora_rank} lr:{h.ttt_lora_lr} chunk:{h.ttt_chunk_size}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + if global_batches_sorted is None: + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{counter_tag}_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path] + dist.broadcast_object_list(path_list, src=0) + counter_path = path_list[0] + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + last_heartbeat = t_start + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + + def _build_opt(lora): + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=h.ttt_lora_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + progress_f = None + if h.ttt_output_dir and h.rank == 0 and progress_filename: + os.makedirs(h.ttt_output_dir, exist_ok=True) + progress_f = 
open(os.path.join(h.ttt_output_dir, progress_filename), "w") + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + now = time.perf_counter() + if now - last_heartbeat >= h.ttt_heartbeat_seconds: + doc_lens = [dl for _, dl in batch] + log( + f"{progress_tag}_heartbeat: batch {orig_batch_idx+1}/{queue_len} " + f"chunk {ci+1}/{max_nc} doc_len:{min(doc_lens)}-{max(doc_lens)} " + f"elapsed:{now - t_start:.1f}s" + ) + last_heartbeat = now + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + with torch.no_grad(): + _accumulate_bpb( + per_tok_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + ) + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + for gi in range(h.ttt_grad_steps): + if gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + per_doc = per_tok_loss[ + :, chunk_offset : chunk_offset + 
chunk_size + ].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + else: + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"{progress_tag}: batch {batch_num}/{queue_len} batch_loss:{b_loss:.4f} " + f"batch_bpb:{b_bpb:.4f} running_loss:{r_loss:.4f} running_bpb:{r_bpb:.4f} " + f"doc_len:{min(doc_lens)}-{max(doc_lens)}" + ) + if progress_f is not None: + progress_f.write( + json.dumps({ + "batch": batch_num, "total_batches": queue_len, + "batch_loss": round(b_loss, 8), "batch_bpb": round(b_bpb, 8), + "running_loss": round(r_loss, 8), "running_bpb": round(r_bpb, 8), + "doc_len_min": min(doc_lens), "doc_len_max": max(doc_lens), + "chunk_size": chunk_size, + "elapsed_s": round(elapsed, 3), + }) + "\n" + ) + progress_f.flush() + if bsz != reusable_lora.bsz: + del cur_lora, cur_opt + finally: + if progress_f is not None: + progress_f.close() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + val_loss, val_bpb = _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + return val_loss, val_bpb, loss_sum, byte_sum, token_count + + +def eval_val_ttt_lora(h, base_model, device, val_data, forward_ttt_train): + val_loss, val_bpb, _, _, _ = eval_val_ttt_lora_stats( + h, base_model, device, val_data, forward_ttt_train + ) + return val_loss, val_bpb + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms" + ) + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - 
h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + frac = ( + min(step / h.muon_momentum_warmup_steps, 1.0) + if h.muon_momentum_warmup_steps > 0 + else 1.0 + ) + muon_momentum = ( + 1 - frac + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + 
optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + ema_state = { + name: t.detach().float().clone() + for (name, t) in base_model.state_dict().items() + } + ema_decay = h.ema_decay + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale) + with torch.no_grad(): + for (name, t) in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_( + t.detach().float(), alpha=1.0 - ema_decay + ) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + ttt_only_eval = bool(int(os.environ.get("TTT_ONLY_EVAL", "0"))) + quantized_eval_only = bool(h.eval_only_quantized_path and not h.eval_only_path) + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + if h.eval_only_path: + log(f"eval_only:loading checkpoint from {h.eval_only_path}") + base_model = GPT(h).to(device).bfloat16() + 
restore_fp32_params(base_model) + base_model.load_state_dict(torch.load(h.eval_only_path, map_location=device)) + if h.num_loops > 0: + base_model.looping_active = True + if ttt_only_eval: + compiled_model = None + compiled_forward_logits = None + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + elif quantized_eval_only: + log(f"eval_only:using quantized checkpoint from {h.eval_only_quantized_path}") + base_model = None + compiled_model = None + compiled_forward_logits = None + else: + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + _skip_training = bool(h.eval_only_path or quantized_eval_only) + torch._dynamo.reset() + if not ttt_only_eval: + if not quantized_eval_only: + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if not _skip_training: + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + else: + log("eval_only: skipping serialize (already have quantized model)") + if not os.path.exists(h.quantized_model_path): + log("eval_only: no quantized model found, running serialize anyway") + if base_model is None: + raise FileNotFoundError( + f"quantized checkpoint not found at {h.quantized_model_path}" + ) + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, dynamic=False, fullgraph=True + ) + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if h.sliding_window_enabled: + timed_eval( + "diagnostic quantized_sliding_window", + eval_val_sliding, + h, + device, + val_data, + eval_model, + forward_logits_fn=compiled_forward_logits, + ) + else: + log( + "ttt_only_eval: skipping pre-quant/quantized diagnostics and using " + "the eval-only checkpoint directly for TTT" + ) + if h.ttt_enabled: + if not ttt_only_eval and not quantized_eval_only: + del eval_model, compiled_model + torch._dynamo.reset() + torch.cuda.empty_cache() + if ttt_only_eval: + if base_model is not None: + ttt_model = base_model + else: + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + else: + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + base_ttt_requires_grad = False + for p in ttt_model.parameters(): + p.requires_grad_(base_ttt_requires_grad) + + def _prepare_ttt_eval_model(model): + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + _prepare_ttt_eval_model(ttt_model) + scoring_model = ttt_model + + def _fwd_ttt_inner(input_ids, target_ids, lora, logit_bias=None): + return 
scoring_model.forward_ttt(input_ids, target_ids, lora=lora, logit_bias=logit_bias) + + _fwd_ttt_compiled_inner = None + + def _fwd_ttt(input_ids, target_ids, lora, logit_bias=None): + nonlocal _fwd_ttt_compiled_inner + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora, logit_bias=logit_bias) + + _ttt_debug_bypass = bool(os.environ.get("TTT_DEBUG_BYPASS")) + if _ttt_debug_bypass: + def _fwd_ttt_bypass(input_ids, target_ids, lora, logit_bias=None): + logits = scoring_model.forward_logits(input_ids) + dummy = lora.q_loras[0].B.sum() * 0 + logits = logits + dummy + if logit_bias is not None: + logits = logits + logit_bias + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + fwd_ttt_compiled = _fwd_ttt_bypass + log("ttt_lora:DEBUG BYPASS active - using forward_logits directly (no compile warmup)") + else: + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_lora:warming up compile") + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + # BEGIN warmup synthetic tokens (LEGALITY_SPEC: no val-token touch pre-eval) + # _warmup_gen is a device-local generator; it does NOT mutate global + # torch RNG state, so downstream training/eval determinism is untouched. + _warmup_gen = torch.Generator(device=device).manual_seed(0) + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, scoring_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len): + row_w = torch.randint( + 0, h.vocab_size, (ctx_len + 1,), + device=device, dtype=torch.int64, generator=_warmup_gen, + ) + xw = row_w[:ctx_len].unsqueeze(0).expand(bsz, -1).contiguous() + yw = row_w[1 : ctx_len + 1].unsqueeze(0).expand(bsz, -1).contiguous() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + if h.corrector_alpha > 0: + dummy_bias = torch.zeros( + bsz, 1, h.vocab_size, device=device, dtype=torch.bfloat16, + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl2 = fwd_ttt_compiled(xw, yw, lora=wl, logit_bias=dummy_bias) + ptl2[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + del wl, wo + # END warmup synthetic tokens + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + _correctors = None + if h.phased_ttt_enabled and h.corrector_alpha > 0.0: + _orders = [int(o) for o in h.corrector_orders.split(",") if o.strip()] + _correctors = [PrefixNgramCorrector(h.vocab_size, h.corrector_alpha, _orders) + for _ in range(h.ttt_batch_size)] + log(f"corrector: alpha={h.corrector_alpha} orders={_orders}") + if h.phased_ttt_enabled: + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled, + correctors=_correctors + ) + else: 
+ ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + if h.phased_ttt_enabled: + log( + "quantized_ttt_phased " + f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} " + f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + else: + log( + f"quantized_ttt_lora val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + del ttt_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.cache_size_limit = 16 + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log( + subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False, + ).stdout, + console=False, + ) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/train_seed0.log b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/train_seed0.log new file mode 100644 index 0000000000..14537a0381 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-19_pr1610_reproduction_corrector_negative/train_seed0.log @@ -0,0 +1,497 @@ +=== Stage 2: Gate A (seed 0) === Sat Apr 18 20:24:23 UTC 2026 +Published BPB: 1.07216564 | Kill if > 1.07516564 (published + 0.003) +Artifact limit: 15997520 bytes +SHA check: 1765afc7d62ce03a1219ca81cc92eea4fabdf343: OK + +Starting train+eval (torchrun 8xH100)... 
+NCCL version 2.27.5+cuda12.9 +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: /workspace/parameter-golf/runs/seed0 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + corrector_alpha: 0.0 + corrector_orders: 8 + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + eval_only_path: + eval_only_quantized_path: + eval_seq_len: 2048 + eval_stride: 64 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_doc_limit: 0 + global_ttt_epochs: 3 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.005 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 13.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/runs/seed0/0df2c6d6-6b99-49f7-b3b0-e3eee51a1d83.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/runs/seed0/final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_enabled: True + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: /workspace/parameter-golf/runs/seed0/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: 0df2c6d6-6b99-49f7-b3b0-e3eee51a1d83 + scalar_lr: 0.02 + seed: 0 + skip_gates_enabled: True + sliding_window_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 32 + ttt_doc_limit: 0 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_heartbeat_seconds: 15.0 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_output_dir: + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40540160 +model_params:35944602 +gptq:reserving 13s, effective=587000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 
+0/20000 val_loss: 9.0086 val_bpb: 3.4874 +1/20000 train_loss: 9.0089 train_time: 0.0m tok/s: 16340826 +2/20000 train_loss: 12.2895 train_time: 0.0m tok/s: 12147844 +3/20000 train_loss: 11.2747 train_time: 0.0m tok/s: 10201309 +4/20000 train_loss: 9.6697 train_time: 0.0m tok/s: 9414717 +5/20000 train_loss: 8.2716 train_time: 0.0m tok/s: 8610708 +500/20000 train_loss: 3.2579 train_time: 0.8m tok/s: 8274271 +1000/20000 train_loss: 3.0137 train_time: 1.6m tok/s: 8236357 +1500/20000 train_loss: 3.0172 train_time: 2.4m tok/s: 8223568 +2000/20000 train_loss: 2.9823 train_time: 3.2m tok/s: 8219490 +layer_loop:enabled step:2147 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0652 train_time: 4.2m tok/s: 7718360 +3000/20000 train_loss: 2.9012 train_time: 5.4m tok/s: 7268579 +3500/20000 train_loss: 2.9743 train_time: 6.6m tok/s: 6979091 +4000/20000 train_loss: 2.9014 train_time: 7.7m tok/s: 6775713 +4000/20000 val_loss: 2.8773 val_bpb: 1.1139 +4500/20000 train_loss: 2.8552 train_time: 8.9m tok/s: 6626801 +4879/20000 val_loss: 2.7712 val_bpb: 1.0728 +stopping_early: wallclock_cap train_time: 587077ms step: 4879/20000 +peak memory allocated: 40029 MiB reserved: 44036 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.76987839 val_bpb:1.07227169 eval_time:7020ms +Serialized model: 135409136 bytes +Code size (uncompressed): 144707 bytes +Code size (compressed): 29462 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 12.7s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights +Serialized model quantized+brotli: 15969932 bytes +Total submission size quantized+brotli: 15999394 bytes +diagnostic quantized val_loss:2.79886366 val_bpb:1.08349242 eval_time:44357ms +ttt_lora:warming up compile +ttt_lora:compile warmup done (120.3s) + +beginning TTT eval timer +ttt_phased: total_docs:50000 prefix_docs:2000 suffix_docs:48000 +ttp: b782/782 bl:2.5613 bb:1.0341 rl:2.5613 rb:1.0341 dl:26524-79464 gd:0 +ttpp: pd:2448 gd:2000 t:219.9s +tttg: c1/213 lr:0.005000 t:2.1s +tttg: c2/213 lr:0.005000 t:2.3s +tttg: c3/213 lr:0.004999 t:2.5s +tttg: c4/213 lr:0.004998 t:2.7s +tttg: c5/213 lr:0.004996 t:3.0s +tttg: c6/213 lr:0.004993 t:3.2s +tttg: c7/213 lr:0.004990 t:3.4s +tttg: c8/213 lr:0.004987 t:3.7s +tttg: c9/213 lr:0.004982 t:3.9s +tttg: c10/213 lr:0.004978 t:4.1s +tttg: c11/213 lr:0.004973 t:4.4s +tttg: c12/213 lr:0.004967 t:4.6s +tttg: c13/213 lr:0.004961 t:4.8s +tttg: c14/213 lr:0.004954 t:5.0s +tttg: c15/213 lr:0.004946 t:5.3s +tttg: c16/213 lr:0.004938 t:5.5s +tttg: c17/213 lr:0.004930 t:5.7s +tttg: c18/213 lr:0.004921 t:6.0s +tttg: c19/213 lr:0.004912 t:6.2s +tttg: c20/213 lr:0.004902 t:6.4s +tttg: c21/213 lr:0.004891 t:6.7s +tttg: c22/213 lr:0.004880 t:6.9s +tttg: c23/213 lr:0.004868 t:7.1s +tttg: c24/213 lr:0.004856 t:7.3s +tttg: c25/213 lr:0.004844 t:7.6s +tttg: c26/213 lr:0.004830 t:7.8s +tttg: c27/213 lr:0.004817 t:8.0s +tttg: c28/213 lr:0.004803 t:8.2s +tttg: c29/213 lr:0.004788 t:8.5s +tttg: c30/213 lr:0.004773 t:8.7s +tttg: c31/213 lr:0.004757 t:8.9s +tttg: c32/213 lr:0.004741 t:9.2s +tttg: c33/213 lr:0.004724 t:9.5s +tttg: c34/213 lr:0.004707 t:9.7s +tttg: c35/213 
lr:0.004689 t:9.9s +tttg: c36/213 lr:0.004671 t:10.1s +tttg: c37/213 lr:0.004653 t:10.4s +tttg: c38/213 lr:0.004634 t:10.6s +tttg: c39/213 lr:0.004614 t:10.8s +tttg: c40/213 lr:0.004594 t:11.1s +tttg: c41/213 lr:0.004574 t:11.3s +tttg: c42/213 lr:0.004553 t:11.5s +tttg: c43/213 lr:0.004531 t:11.8s +tttg: c44/213 lr:0.004509 t:12.0s +tttg: c45/213 lr:0.004487 t:12.2s +tttg: c46/213 lr:0.004464 t:12.5s +tttg: c47/213 lr:0.004441 t:12.7s +tttg: c48/213 lr:0.004418 t:13.0s +tttg: c49/213 lr:0.004394 t:13.2s +tttg: c50/213 lr:0.004369 t:13.4s +tttg: c51/213 lr:0.004345 t:13.6s +tttg: c52/213 lr:0.004319 t:13.9s +tttg: c53/213 lr:0.004294 t:14.1s +tttg: c54/213 lr:0.004268 t:14.3s +tttg: c55/213 lr:0.004241 t:14.6s +tttg: c56/213 lr:0.004215 t:14.8s +tttg: c57/213 lr:0.004187 t:15.0s +tttg: c58/213 lr:0.004160 t:15.3s +tttg: c59/213 lr:0.004132 t:15.5s +tttg: c60/213 lr:0.004104 t:15.7s +tttg: c61/213 lr:0.004075 t:15.9s +tttg: c62/213 lr:0.004046 t:16.2s +tttg: c63/213 lr:0.004017 t:16.4s +tttg: c64/213 lr:0.003987 t:16.6s +tttg: c65/213 lr:0.003957 t:16.9s +tttg: c66/213 lr:0.003927 t:17.1s +tttg: c67/213 lr:0.003897 t:17.3s +tttg: c68/213 lr:0.003866 t:17.5s +tttg: c69/213 lr:0.003835 t:17.8s +tttg: c70/213 lr:0.003803 t:18.0s +tttg: c71/213 lr:0.003771 t:18.2s +tttg: c72/213 lr:0.003739 t:18.4s +tttg: c73/213 lr:0.003707 t:18.7s +tttg: c74/213 lr:0.003674 t:18.9s +tttg: c75/213 lr:0.003642 t:19.1s +tttg: c76/213 lr:0.003608 t:19.3s +tttg: c77/213 lr:0.003575 t:19.6s +tttg: c78/213 lr:0.003542 t:19.8s +tttg: c79/213 lr:0.003508 t:20.0s +tttg: c80/213 lr:0.003474 t:20.3s +tttg: c81/213 lr:0.003440 t:20.5s +tttg: c82/213 lr:0.003405 t:20.7s +tttg: c83/213 lr:0.003371 t:21.0s +tttg: c84/213 lr:0.003336 t:21.2s +tttg: c85/213 lr:0.003301 t:21.4s +tttg: c86/213 lr:0.003265 t:21.6s +tttg: c87/213 lr:0.003230 t:21.9s +tttg: c88/213 lr:0.003195 t:22.1s +tttg: c89/213 lr:0.003159 t:22.3s +tttg: c90/213 lr:0.003123 t:22.6s +tttg: c91/213 lr:0.003087 t:22.8s +tttg: c92/213 lr:0.003051 t:23.1s +tttg: c93/213 lr:0.003015 t:23.3s +tttg: c94/213 lr:0.002979 t:23.5s +tttg: c95/213 lr:0.002942 t:23.8s +tttg: c96/213 lr:0.002906 t:24.0s +tttg: c97/213 lr:0.002869 t:24.2s +tttg: c98/213 lr:0.002832 t:24.4s +tttg: c99/213 lr:0.002796 t:24.7s +tttg: c100/213 lr:0.002759 t:24.9s +tttg: c101/213 lr:0.002722 t:25.1s +tttg: c102/213 lr:0.002685 t:25.4s +tttg: c103/213 lr:0.002648 t:25.6s +tttg: c104/213 lr:0.002611 t:25.8s +tttg: c105/213 lr:0.002574 t:26.1s +tttg: c106/213 lr:0.002537 t:26.3s +tttg: c107/213 lr:0.002500 t:26.5s +tttg: c108/213 lr:0.002463 t:26.7s +tttg: c109/213 lr:0.002426 t:27.0s +tttg: c110/213 lr:0.002389 t:27.2s +tttg: c111/213 lr:0.002352 t:27.4s +tttg: c112/213 lr:0.002315 t:27.7s +tttg: c113/213 lr:0.002278 t:28.0s +tttg: c114/213 lr:0.002241 t:28.2s +tttg: c115/213 lr:0.002204 t:28.4s +tttg: c116/213 lr:0.002168 t:28.7s +tttg: c117/213 lr:0.002131 t:28.9s +tttg: c118/213 lr:0.002094 t:29.1s +tttg: c119/213 lr:0.002058 t:29.4s +tttg: c120/213 lr:0.002021 t:29.6s +tttg: c121/213 lr:0.001985 t:29.9s +tttg: c122/213 lr:0.001949 t:30.1s +tttg: c123/213 lr:0.001913 t:30.3s +tttg: c124/213 lr:0.001877 t:30.6s +tttg: c125/213 lr:0.001841 t:30.8s +tttg: c126/213 lr:0.001805 t:31.0s +tttg: c127/213 lr:0.001770 t:31.3s +tttg: c128/213 lr:0.001735 t:31.5s +tttg: c129/213 lr:0.001699 t:31.7s +tttg: c130/213 lr:0.001664 t:31.9s +tttg: c131/213 lr:0.001629 t:32.2s +tttg: c132/213 lr:0.001595 t:32.4s +tttg: c133/213 lr:0.001560 t:32.6s +tttg: c134/213 lr:0.001526 t:32.9s +tttg: c135/213 lr:0.001492 t:33.1s 
+tttg: c136/213 lr:0.001458 t:33.3s +tttg: c137/213 lr:0.001425 t:33.6s +tttg: c138/213 lr:0.001392 t:33.8s +tttg: c139/213 lr:0.001358 t:34.0s +tttg: c140/213 lr:0.001326 t:34.2s +tttg: c141/213 lr:0.001293 t:34.5s +tttg: c142/213 lr:0.001261 t:34.7s +tttg: c143/213 lr:0.001229 t:34.9s +tttg: c144/213 lr:0.001197 t:35.2s +tttg: c145/213 lr:0.001165 t:35.4s +tttg: c146/213 lr:0.001134 t:35.6s +tttg: c147/213 lr:0.001103 t:35.8s +tttg: c148/213 lr:0.001073 t:36.1s +tttg: c149/213 lr:0.001043 t:36.3s +tttg: c150/213 lr:0.001013 t:36.5s +tttg: c151/213 lr:0.000983 t:36.8s +tttg: c152/213 lr:0.000954 t:37.0s +tttg: c153/213 lr:0.000925 t:37.2s +tttg: c154/213 lr:0.000896 t:37.5s +tttg: c155/213 lr:0.000868 t:37.7s +tttg: c156/213 lr:0.000840 t:37.9s +tttg: c157/213 lr:0.000813 t:38.2s +tttg: c158/213 lr:0.000785 t:38.4s +tttg: c159/213 lr:0.000759 t:38.6s +tttg: c160/213 lr:0.000732 t:38.9s +tttg: c161/213 lr:0.000706 t:39.1s +tttg: c162/213 lr:0.000681 t:39.3s +tttg: c163/213 lr:0.000655 t:39.6s +tttg: c164/213 lr:0.000631 t:39.8s +tttg: c165/213 lr:0.000606 t:40.0s +tttg: c166/213 lr:0.000582 t:40.2s +tttg: c167/213 lr:0.000559 t:40.5s +tttg: c168/213 lr:0.000536 t:40.7s +tttg: c169/213 lr:0.000513 t:40.9s +tttg: c170/213 lr:0.000491 t:41.2s +tttg: c171/213 lr:0.000469 t:41.4s +tttg: c172/213 lr:0.000447 t:41.6s +tttg: c173/213 lr:0.000426 t:41.8s +tttg: c174/213 lr:0.000406 t:42.1s +tttg: c175/213 lr:0.000386 t:42.3s +tttg: c176/213 lr:0.000366 t:42.6s +tttg: c177/213 lr:0.000347 t:42.8s +tttg: c178/213 lr:0.000329 t:43.0s +tttg: c179/213 lr:0.000311 t:43.3s +tttg: c180/213 lr:0.000293 t:43.5s +tttg: c181/213 lr:0.000276 t:43.8s +tttg: c182/213 lr:0.000259 t:44.0s +tttg: c183/213 lr:0.000243 t:44.2s +tttg: c184/213 lr:0.000227 t:44.4s +tttg: c185/213 lr:0.000212 t:44.7s +tttg: c186/213 lr:0.000197 t:44.9s +tttg: c187/213 lr:0.000183 t:45.1s +tttg: c188/213 lr:0.000170 t:45.4s +tttg: c189/213 lr:0.000156 t:45.6s +tttg: c190/213 lr:0.000144 t:45.8s +tttg: c191/213 lr:0.000132 t:46.0s +tttg: c192/213 lr:0.000120 t:46.3s +tttg: c193/213 lr:0.000109 t:46.5s +tttg: c194/213 lr:0.000098 t:46.7s +tttg: c195/213 lr:0.000088 t:47.0s +tttg: c196/213 lr:0.000079 t:47.2s +tttg: c197/213 lr:0.000070 t:47.4s +tttg: c198/213 lr:0.000062 t:47.6s +tttg: c199/213 lr:0.000054 t:47.8s +tttg: c200/213 lr:0.000046 t:48.1s +tttg: c201/213 lr:0.000039 t:48.3s +tttg: c202/213 lr:0.000033 t:48.5s +tttg: c203/213 lr:0.000027 t:48.8s +tttg: c204/213 lr:0.000022 t:49.0s +tttg: c205/213 lr:0.000018 t:49.2s +tttg: c206/213 lr:0.000013 t:49.4s +tttg: c207/213 lr:0.000010 t:49.7s +tttg: c208/213 lr:0.000007 t:49.9s +tttg: c209/213 lr:0.000004 t:50.1s +tttg: c210/213 lr:0.000002 t:50.4s +tttg: c211/213 lr:0.000001 t:50.6s +tttg: c212/213 lr:0.000000 t:50.8s +ttpr: t:273.2s +ttp: b736/782 bl:2.6717 bb:1.0414 rl:2.5799 rb:1.0353 dl:2140-2165 gd:1 +ttp: b728/782 bl:2.7553 bb:1.0671 rl:2.6033 rb:1.0397 dl:1960-1977 gd:1 +ttp: b720/782 bl:2.8159 bb:1.0756 rl:2.6267 rb:1.0438 dl:1816-1832 gd:1 +ttp: b712/782 bl:2.8281 bb:1.0767 rl:2.6454 rb:1.0470 dl:1684-1697 gd:1 +ttp: b704/782 bl:2.7376 bb:1.0210 rl:2.6528 rb:1.0448 dl:1595-1606 gd:1 +ttp: b696/782 bl:2.8079 bb:1.0733 rl:2.6638 rb:1.0469 dl:1513-1522 gd:1 +ttp: b688/782 bl:2.7454 bb:1.0474 rl:2.6690 rb:1.0469 dl:1441-1450 gd:1 +ttp: b680/782 bl:2.8011 bb:1.0537 rl:2.6765 rb:1.0473 dl:1375-1383 gd:1 +ttp: b672/782 bl:2.8964 bb:1.1050 rl:2.6879 rb:1.0504 dl:1321-1327 gd:1 +ttp: b664/782 bl:2.6993 bb:1.0407 rl:2.6885 rb:1.0499 dl:1270-1275 gd:1 +ttp: b656/782 bl:2.7429 bb:1.0355 
rl:2.6908 rb:1.0493 dl:1220-1227 gd:1 +ttp: b643/782 bl:2.7893 bb:1.0634 rl:2.6947 rb:1.0498 dl:1150-1155 gd:1 +ttp: b633/782 bl:2.8205 bb:1.1005 rl:2.6993 rb:1.0517 dl:1101-1105 gd:1 +ttp: b626/782 bl:2.8070 bb:1.0430 rl:2.7030 rb:1.0514 dl:1068-1073 gd:1 +ttp: b621/782 bl:2.8308 bb:1.0842 rl:2.7071 rb:1.0524 dl:1046-1050 gd:1 +ttp: b615/782 bl:2.8328 bb:1.0636 rl:2.7110 rb:1.0528 dl:1020-1023 gd:1 +ttp: b608/782 bl:2.7284 bb:1.0297 rl:2.7115 rb:1.0521 dl:990-994 gd:1 +ttp: b593/782 bl:2.7877 bb:1.0425 rl:2.7135 rb:1.0518 dl:933-937 gd:1 +ttp: b585/782 bl:2.7605 bb:1.0644 rl:2.7147 rb:1.0522 dl:908-911 gd:1 +ttp: b577/782 bl:2.7526 bb:1.0411 rl:2.7156 rb:1.0519 dl:880-884 gd:1 +ttp: b569/782 bl:2.7594 bb:1.0542 rl:2.7166 rb:1.0519 dl:855-858 gd:1 +ttp: b564/782 bl:2.8626 bb:1.1075 rl:2.7197 rb:1.0532 dl:840-843 gd:1 +ttp: b553/782 bl:2.7653 bb:1.0594 rl:2.7207 rb:1.0533 dl:806-809 gd:1 +ttp: b545/782 bl:2.7814 bb:1.0518 rl:2.7218 rb:1.0532 dl:785-788 gd:1 +ttp: b537/782 bl:2.7066 bb:1.0234 rl:2.7216 rb:1.0527 dl:764-767 gd:1 +ttp: b531/782 bl:2.7650 bb:1.0487 rl:2.7223 rb:1.0526 dl:750-752 gd:1 +ttp: b522/782 bl:2.8217 bb:1.0847 rl:2.7240 rb:1.0532 dl:727-730 gd:1 +ttp: b517/782 bl:2.7711 bb:1.0489 rl:2.7248 rb:1.0531 dl:715-717 gd:1 +ttp: b508/782 bl:2.7581 bb:1.0307 rl:2.7253 rb:1.0527 dl:693-695 gd:1 +ttp: b500/782 bl:2.8299 bb:1.0810 rl:2.7269 rb:1.0532 dl:675-677 gd:1 +ttp: b490/782 bl:2.8458 bb:1.0874 rl:2.7286 rb:1.0537 dl:653-655 gd:1 +ttp: b482/782 bl:2.7509 bb:1.0796 rl:2.7289 rb:1.0540 dl:637-639 gd:1 +ttp: b473/782 bl:2.8256 bb:1.0751 rl:2.7302 rb:1.0543 dl:618-620 gd:1 +ttp: b465/782 bl:2.8092 bb:1.0597 rl:2.7312 rb:1.0544 dl:602-604 gd:1 +ttp: b461/782 bl:2.7723 bb:1.0572 rl:2.7317 rb:1.0544 dl:595-597 gd:1 +ttp: b452/782 bl:2.7353 bb:1.0552 rl:2.7318 rb:1.0544 dl:579-580 gd:1 +ttp: b441/782 bl:2.7112 bb:1.0436 rl:2.7315 rb:1.0543 dl:559-560 gd:1 +ttp: b433/782 bl:2.7751 bb:1.0651 rl:2.7320 rb:1.0544 dl:544-545 gd:1 +ttp: b427/782 bl:2.7458 bb:1.0609 rl:2.7322 rb:1.0545 dl:533-535 gd:1 +ttp: b418/782 bl:2.8066 bb:1.0706 rl:2.7329 rb:1.0546 dl:517-519 gd:1 +ttp: b409/782 bl:2.7057 bb:1.0453 rl:2.7327 rb:1.0545 dl:503-505 gd:1 +ttp: b401/782 bl:2.7351 bb:1.0587 rl:2.7327 rb:1.0546 dl:490-492 gd:1 +ttp: b393/782 bl:2.8499 bb:1.0854 rl:2.7338 rb:1.0549 dl:478-479 gd:1 +ttp: b386/782 bl:2.7235 bb:1.0640 rl:2.7337 rb:1.0550 dl:467-468 gd:1 +ttp: b378/782 bl:2.8194 bb:1.0969 rl:2.7344 rb:1.0553 dl:456-457 gd:1 +ttp: b370/782 bl:2.6760 bb:1.0411 rl:2.7339 rb:1.0552 dl:444-446 gd:1 +ttp: b362/782 bl:2.8186 bb:1.0657 rl:2.7346 rb:1.0553 dl:433-434 gd:1 +ttp: b354/782 bl:2.7948 bb:1.0843 rl:2.7351 rb:1.0555 dl:422-423 gd:1 +ttp: b346/782 bl:2.8419 bb:1.0845 rl:2.7359 rb:1.0557 dl:412-413 gd:1 +ttp: b338/782 bl:2.8542 bb:1.1132 rl:2.7367 rb:1.0561 dl:400-402 gd:1 +ttp: b329/782 bl:2.8268 bb:1.1026 rl:2.7374 rb:1.0565 dl:389-390 gd:1 +ttp: b321/782 bl:2.8039 bb:1.1018 rl:2.7378 rb:1.0568 dl:378-380 gd:1 +ttp: b313/782 bl:2.8251 bb:1.0880 rl:2.7384 rb:1.0570 dl:368-369 gd:1 +ttp: b305/782 bl:2.8575 bb:1.0841 rl:2.7391 rb:1.0571 dl:358-359 gd:1 +ttp: b297/782 bl:2.7950 bb:1.0589 rl:2.7395 rb:1.0572 dl:348-349 gd:1 +ttp: b289/782 bl:2.8319 bb:1.1210 rl:2.7400 rb:1.0575 dl:339-340 gd:1 +ttp: b279/782 bl:2.8494 bb:1.0891 rl:2.7406 rb:1.0577 dl:327-329 gd:1 +ttp: b270/782 bl:2.7815 bb:1.0916 rl:2.7409 rb:1.0579 dl:318-319 gd:1 +ttp: b259/782 bl:2.8563 bb:1.1391 rl:2.7415 rb:1.0583 dl:305-306 gd:1 +ttp: b251/782 bl:2.8680 bb:1.1062 rl:2.7421 rb:1.0585 dl:296-297 gd:1 +ttp: b243/782 bl:2.8286 
bb:1.1032 rl:2.7425 rb:1.0588 dl:288-289 gd:1 +ttp: b235/782 bl:2.9277 bb:1.1128 rl:2.7434 rb:1.0590 dl:280-281 gd:1 +ttp: b226/782 bl:2.9358 bb:1.1421 rl:2.7443 rb:1.0594 dl:271-272 gd:1 +ttp: b218/782 bl:2.7409 bb:1.1021 rl:2.7443 rb:1.0596 dl:263-264 gd:1 +ttp: b204/782 bl:2.9059 bb:1.1304 rl:2.7449 rb:1.0599 dl:250-251 gd:1 +ttp: b196/782 bl:2.9000 bb:1.1620 rl:2.7456 rb:1.0603 dl:243-244 gd:1 +ttp: b187/782 bl:2.8956 bb:1.1165 rl:2.7462 rb:1.0605 dl:235-236 gd:1 +ttp: b178/782 bl:2.8510 bb:1.1370 rl:2.7465 rb:1.0608 dl:227-228 gd:1 +ttp: b165/782 bl:2.9411 bb:1.1638 rl:2.7472 rb:1.0611 dl:216-217 gd:1 +ttp: b157/782 bl:2.8213 bb:1.1120 rl:2.7475 rb:1.0613 dl:209-210 gd:1 +ttp: b148/782 bl:2.9737 bb:1.1558 rl:2.7482 rb:1.0616 dl:202-203 gd:1 +ttp: b137/782 bl:2.9332 bb:1.1820 rl:2.7488 rb:1.0620 dl:193-194 gd:1 +ttp: b125/782 bl:3.0096 bb:1.1927 rl:2.7496 rb:1.0623 dl:184-185 gd:1 +ttp: b116/782 bl:3.0131 bb:1.1915 rl:2.7503 rb:1.0627 dl:177-178 gd:1 +ttp: b108/782 bl:2.8686 bb:1.1017 rl:2.7507 rb:1.0628 dl:171-172 gd:1 +ttp: b95/782 bl:3.0039 bb:1.2232 rl:2.7513 rb:1.0632 dl:161-162 gd:1 +ttp: b87/782 bl:3.0217 bb:1.2078 rl:2.7520 rb:1.0635 dl:155-156 gd:1 +ttp: b76/782 bl:3.0561 bb:1.2261 rl:2.7527 rb:1.0639 dl:147-148 gd:1 +ttp: b64/782 bl:2.9896 bb:1.2391 rl:2.7532 rb:1.0643 dl:138-139 gd:1 +ttp: b56/782 bl:3.0573 bb:1.2049 rl:2.7539 rb:1.0646 dl:131-132 gd:1 +ttp: b43/782 bl:3.0018 bb:1.1937 rl:2.7543 rb:1.0648 dl:121-122 gd:1 +ttp: b33/782 bl:3.0960 bb:1.2120 rl:2.7550 rb:1.0651 dl:113-114 gd:1 +ttp: b26/782 bl:3.0853 bb:1.2579 rl:2.7555 rb:1.0654 dl:107-107 gd:1 +ttp: b12/782 bl:3.1914 bb:1.2439 rl:2.7561 rb:1.0656 dl:92-93 gd:1 +ttp: b2/782 bl:3.1528 bb:1.1693 rl:2.7566 rb:1.0658 dl:70-75 gd:1 +quantized_ttt_phased val_loss:2.76956462 val_bpb:1.07218477 eval_time:455945ms +total_eval_time:455.9s + +=== Persisting seed-0 checkpoint (before log parse) === +Checkpoints saved: 145M + +=== Gate A: Parsing results === + val_bpb: 1.07218477 (published: 1.07216564 ceiling: 1.07516564) + eval_time: 455.9s (limit: 600s) + artifact_size: 15999394 bytes (limit: 15997520) + +GATE_A: FAIL — artifact 15999394B > 15997520B seed-0 headroom limit + Log: /workspace/parameter-golf/runs/seed0_log.txt diff --git a/scripts/pgolf_field_guide_audit.py b/scripts/pgolf_field_guide_audit.py new file mode 100755 index 0000000000..fa946e59a5 --- /dev/null +++ b/scripts/pgolf_field_guide_audit.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +"""Static Field Guide audit helper for Parameter Golf submissions. + +This is intentionally conservative: it does not prove legality by itself. +It decodes packed train_gpt.py wrappers, scans for common Issue #1017 risks, +and emits a small JSON report that can be included next to reproduction logs. 
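+
+Example invocation (the output paths below are illustrative; the positional
+argument is a submission's train_gpt.py, packed or plain):
+
+    python3 scripts/pgolf_field_guide_audit.py path/to/train_gpt.py \
+        --write-decoded /tmp/train_gpt.decoded.py \
+        --output /tmp/field_guide_static_audit.json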
+""" + +from __future__ import annotations + +import argparse +import base64 +import json +import lzma +import re +from pathlib import Path + + +def load_source(path: Path) -> tuple[str, dict[str, object]]: + raw = path.read_text(encoding="utf-8") + meta: dict[str, object] = {"path": str(path), "packed_wrapper": False} + marker = 'B.b85decode("' + if marker not in raw: + return raw, meta + + try: + start = raw.index(marker) + len(marker) + end = raw.index('"),format=L.FORMAT_RAW', start) + payload = raw[start:end] + decoded = lzma.decompress( + base64.b85decode(payload), + format=lzma.FORMAT_RAW, + filters=[{"id": lzma.FILTER_LZMA2}], + ) + meta.update( + { + "packed_wrapper": True, + "packed_bytes": len(raw.encode("utf-8")), + "decoded_bytes": len(decoded), + } + ) + return decoded.decode("utf-8"), meta + except Exception as exc: # pragma: no cover - diagnostic path + meta["decode_error"] = repr(exc) + return raw, meta + + +def extract_function(source: str, name: str) -> str: + m = re.search(rf"^def {re.escape(name)}\(", source, re.M) + if not m: + return "" + start = m.start() + next_def = re.search(r"^def [A-Za-z_][A-Za-z0-9_]*\(", source[m.end() :], re.M) + if not next_def: + return source[start:] + return source[start : m.end() + next_def.start()] + + +def getenv_default(source: str, key: str) -> str | None: + m = re.search(rf'os\.environ\.get\("{re.escape(key)}",\s*([^)]+)\)', source) + return m.group(1).strip() if m else None + + +def audit(source: str, meta: dict[str, object]) -> dict[str, object]: + lower = source.lower() + eval_ttt = extract_function(source, "eval_val_ttt") + byte_luts = extract_function(source, "build_sentencepiece_luts") + + no_grad_pos = eval_ttt.find("torch.no_grad") + score_pos = min( + [p for p in [eval_ttt.find("loss_sum +="), eval_ttt.find("byte_count +=")] if p >= 0], + default=-1, + ) + step_pos = eval_ttt.find("optimizer.step") + + suspicious_terms = { + "slot": "slot" in lower, + "ngram": "ngram" in lower or "n-gram" in lower, + "ppm": "ppm" in lower, + "etlb": "etlb" in lower, + "logit_bias": "logit_bias" in lower, + "caseops_or_casefold": "caseops" in lower or "casefold" in lower, + } + + checks = { + "condition_1_causal_prefix_static": { + "status": "pass" + if "causal=True" in source or "is_causal=True" in source + else "review", + "evidence": "causal attention flag present; static audit cannot prove every data path", + }, + "condition_2_full_distribution_static": { + "status": "pass" + if "F.cross_entropy" in source and not suspicious_terms["logit_bias"] + else "review", + "evidence": "uses F.cross_entropy on full logits; no logit_bias token-only path found" + if not suspicious_terms["logit_bias"] + else "logit_bias found; inspect full-vocab normalization manually", + }, + "condition_3_score_before_update_static": { + "status": "pass" + if no_grad_pos >= 0 and score_pos >= 0 and step_pos >= 0 and no_grad_pos < score_pos < step_pos + else "review", + "evidence": { + "torch_no_grad_pos": no_grad_pos, + "score_accum_pos": score_pos, + "optimizer_step_pos": step_pos, + }, + }, + "condition_4_single_pass_static": { + "status": "pass" + if "for ci in range" in eval_ttt and "optimizer.step" in eval_ttt and "min(" not in eval_ttt[:500] + else "review", + "evidence": "eval_val_ttt iterates chunks once; no obvious min-over-runs pattern in function head", + }, + "byte_accounting_static": { + "status": "pass" + if "base_bytes_np" in byte_luts + and "has_leading_space" in byte_luts + and "is_boundary_token" in byte_luts + and "byte_count +=" in source + else 
"review", + "evidence": "SentencePiece byte LUT plus leading-space correction found", + }, + } + + return { + "metadata": meta, + "defaults": { + "TTT_ENABLED": getenv_default(source, "TTT_ENABLED"), + "TTT_LR": getenv_default(source, "TTT_LR"), + "TTT_EPOCHS": getenv_default(source, "TTT_EPOCHS"), + "TTT_CHUNK_TOKENS": getenv_default(source, "TTT_CHUNK_TOKENS"), + "MUON_WD": getenv_default(source, "MUON_WD"), + "MUON_WD_MLP": getenv_default(source, "MUON_WD_MLP"), + "SLIDING_WINDOW_ENABLED": getenv_default(source, "SLIDING_WINDOW_ENABLED"), + }, + "suspicious_terms": suspicious_terms, + "checks": checks, + "summary": { + "pass": sum(1 for item in checks.values() if item["status"] == "pass"), + "review": sum(1 for item in checks.values() if item["status"] == "review"), + }, + } + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("train_gpt", type=Path) + parser.add_argument("--write-decoded", type=Path) + parser.add_argument("--output", type=Path) + args = parser.parse_args() + + source, meta = load_source(args.train_gpt) + if args.write_decoded: + args.write_decoded.parent.mkdir(parents=True, exist_ok=True) + args.write_decoded.write_text(source, encoding="utf-8") + + report = audit(source, meta) + text = json.dumps(report, indent=2, sort_keys=True) + if args.output: + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text(text + "\n", encoding="utf-8") + print(text) + + +if __name__ == "__main__": + main() diff --git a/scripts/runpod_pr1812_audit_repro.sh b/scripts/runpod_pr1812_audit_repro.sh new file mode 100755 index 0000000000..4acef92c1d --- /dev/null +++ b/scripts/runpod_pr1812_audit_repro.sh @@ -0,0 +1,112 @@ +#!/usr/bin/env bash +# Reproduce and audit PR #1812 for a Field Guide aligned non-record package. +# +# Modes: +# audit Download exact PR #1812 files and run static Issue #1017 audit. +# seed42 Run one independent seed-42 reproduction. +# two-seed Run seed 42 and seed 314 for an audit package. +# +# Usage on an 8xH100 RunPod pod from a parameter-golf checkout: +# bash scripts/runpod_pr1812_audit_repro.sh audit +# bash scripts/runpod_pr1812_audit_repro.sh seed42 +# +# Data assumptions: +# data/datasets/fineweb10B_sp8192 and data/tokenizers/fineweb_8192_bpe.model +# already exist in this checkout. Override DATA_DIR if staged elsewhere. + +set -euo pipefail + +MODE="${1:-audit}" +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +WORK_DIR="${ROOT}/runs/pr1812_audit_repro" +PR_DIR="${WORK_DIR}/upstream_pr1812" +DATA_DIR="${DATA_DIR:-${ROOT}/data}" + +TRAIN_URL="${TRAIN_URL:-https://github.com/openai/parameter-golf/raw/1350423f2b26d20b3c384f194e8f66d06a6428c2/records%2Ftrack_10min_16mb%2F2026-04-25_SP8192_3LayerRecur_LegalTTT_4ep%2Ftrain_gpt.py}" +README_URL="${README_URL:-https://github.com/openai/parameter-golf/raw/1350423f2b26d20b3c384f194e8f66d06a6428c2/records%2Ftrack_10min_16mb%2F2026-04-25_SP8192_3LayerRecur_LegalTTT_4ep%2FREADME.md}" +SUBMISSION_URL="${SUBMISSION_URL:-https://github.com/openai/parameter-golf/raw/1350423f2b26d20b3c384f194e8f66d06a6428c2/records%2Ftrack_10min_16mb%2F2026-04-25_SP8192_3LayerRecur_LegalTTT_4ep%2Fsubmission.json}" + +mkdir -p "${PR_DIR}" + +fetch_pr_files() { + if [ ! -f "${PR_DIR}/train_gpt.py" ]; then + echo "[setup] downloading PR #1812 train_gpt.py" + curl -fsSL -o "${PR_DIR}/train_gpt.py" "${TRAIN_URL}" + fi + if [ ! -f "${PR_DIR}/README.md" ]; then + curl -fsSL -o "${PR_DIR}/README.md" "${README_URL}" + fi + if [ ! 
-f "${PR_DIR}/submission.json" ]; then + curl -fsSL -o "${PR_DIR}/submission.json" "${SUBMISSION_URL}" + fi +} + +run_audit() { + fetch_pr_files + python3 "${ROOT}/scripts/pgolf_field_guide_audit.py" \ + "${PR_DIR}/train_gpt.py" \ + --write-decoded "${PR_DIR}/train_gpt.decoded.py" \ + --output "${WORK_DIR}/field_guide_static_audit.json" +} + +check_data() { + local sp_dir="${DATA_DIR}/datasets/fineweb10B_sp8192" + local tok="${DATA_DIR}/tokenizers/fineweb_8192_bpe.model" + if [ ! -d "${sp_dir}" ]; then + echo "FATAL: missing ${sp_dir}" >&2 + echo "Download first: MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192" >&2 + exit 1 + fi + if [ ! -f "${tok}" ]; then + echo "FATAL: missing ${tok}" >&2 + exit 1 + fi +} + +run_seed() { + local seed="$1" + local out_dir="${WORK_DIR}/seed${seed}" + + fetch_pr_files + check_data + mkdir -p "${out_dir}" + + echo "" + echo "=== PR #1812 reproduction: seed=${seed} ===" + echo "Output: ${out_dir}" + + ( + cd "${PR_DIR}" + export DATA_DIR + export SEED="${seed}" + export TTT_ENABLED=1 + export TTT_LR=0.005 + export TTT_EPOCHS=4 + export RUN_ID="pr1812_audit_s${seed}" + export ARTIFACT_DIR="${out_dir}" + export PYTHONUNBUFFERED=1 + torchrun --standalone --nproc_per_node=8 train_gpt.py 2>&1 | tee "${out_dir}/console.log" + ) + + grep -E "val_bpb|eval_time|Serialized model|Total submission size|train_time|stopping_early" \ + "${out_dir}/console.log" | tail -40 || true +} + +case "${MODE}" in + audit) + run_audit + ;; + seed42) + run_audit + run_seed 42 + ;; + two-seed) + run_audit + run_seed 42 + run_seed 314 + ;; + *) + echo "Usage: $0 [audit|seed42|two-seed]" >&2 + exit 2 + ;; +esac